handle empty fused norm result (#9688)
* handle empty fused norm result * remove fast_rms_norm * fix style
This commit is contained in:
parent
a5c481fedd
commit
320110d158
4 changed files with 51 additions and 69 deletions
|
|
@ -49,26 +49,20 @@ KV_CACHE_ALLOC_BLOCK_LENGTH = 256
|
||||||
|
|
||||||
def baichuan_13b_rms_norm_forward(self, hidden_states):
|
def baichuan_13b_rms_norm_forward(self, hidden_states):
|
||||||
if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
|
if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
|
||||||
if get_ipex_version() <= "2.0.110+xpu":
|
import linear_q4_0
|
||||||
import linear_q4_0
|
result = linear_q4_0.fused_rms_norm(hidden_states,
|
||||||
hidden_states = linear_q4_0.fused_rms_norm(hidden_states,
|
[self.weight.size(0)],
|
||||||
[self.weight.size(0)],
|
self.weight,
|
||||||
self.weight,
|
None,
|
||||||
None,
|
self.epsilon)
|
||||||
self.epsilon)
|
# if nelement == 0, means fused norm failed, go back to python implement.
|
||||||
else:
|
if result.nelement != 0:
|
||||||
hidden_states = torch.ops.torch_ipex.fast_rms_norm(hidden_states,
|
return result
|
||||||
[self.weight.size(0)],
|
input_dtype = hidden_states.dtype
|
||||||
self.weight,
|
hidden_states = hidden_states.to(torch.float32)
|
||||||
None,
|
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||||
self.epsilon)
|
hidden_states = hidden_states * torch.rsqrt(variance + self.epsilon)
|
||||||
return hidden_states
|
return self.weight * hidden_states.to(input_dtype)
|
||||||
else:
|
|
||||||
input_dtype = hidden_states.dtype
|
|
||||||
hidden_states = hidden_states.to(torch.float32)
|
|
||||||
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
|
||||||
hidden_states = hidden_states * torch.rsqrt(variance + self.epsilon)
|
|
||||||
return self.weight * hidden_states.to(input_dtype)
|
|
||||||
|
|
||||||
|
|
||||||
def baichuan_mlp_forward(
|
def baichuan_mlp_forward(
|
||||||
|
|
|
||||||
|
|
@ -65,14 +65,15 @@ def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training:
|
||||||
def bloom_layer_norm_forward(self, hidden_states):
|
def bloom_layer_norm_forward(self, hidden_states):
|
||||||
if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
|
if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
|
||||||
import linear_q4_0
|
import linear_q4_0
|
||||||
hidden_states = linear_q4_0.fused_layer_norm(hidden_states,
|
result = linear_q4_0.fused_layer_norm(hidden_states,
|
||||||
[self.weight.size(0)],
|
[self.weight.size(0)],
|
||||||
self.weight,
|
self.weight,
|
||||||
self.bias,
|
self.bias,
|
||||||
self.eps)
|
self.eps)
|
||||||
return hidden_states
|
# if nelement == 0, means fused norm failed, go back to python implement.
|
||||||
else:
|
if result.nelement != 0:
|
||||||
return F.layer_norm(hidden_states, self.normalized_shape, self.weight, self.bias, self.eps)
|
return result
|
||||||
|
return F.layer_norm(hidden_states, self.normalized_shape, self.weight, self.bias, self.eps)
|
||||||
|
|
||||||
|
|
||||||
def bloom_attention_forward(
|
def bloom_attention_forward(
|
||||||
|
|
|
||||||
|
|
@ -78,27 +78,20 @@ def apply_rotary_pos_emb_chatglm(x: torch.Tensor, rope_cache: torch.Tensor) -> t
|
||||||
|
|
||||||
def chatglm_rms_norm_forward(self, hidden_states):
|
def chatglm_rms_norm_forward(self, hidden_states):
|
||||||
if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
|
if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
|
||||||
if get_ipex_version() <= "2.0.110+xpu":
|
import linear_q4_0
|
||||||
import linear_q4_0
|
result = linear_q4_0.fused_rms_norm(hidden_states,
|
||||||
hidden_states = linear_q4_0.fused_rms_norm(hidden_states,
|
[self.weight.size(0)],
|
||||||
[self.weight.size(0)],
|
self.weight,
|
||||||
self.weight,
|
None,
|
||||||
None,
|
self.eps)
|
||||||
self.eps)
|
# if nelement == 0, means fused norm failed, go back to python implement.
|
||||||
else:
|
if result.nelement != 0:
|
||||||
# for ipex >= 2.1
|
return result
|
||||||
hidden_states = torch.ops.torch_ipex.fast_rms_norm(hidden_states,
|
input_dtype = hidden_states.dtype
|
||||||
[self.weight.size(0)],
|
hidden_states = hidden_states.to(torch.float32)
|
||||||
self.weight,
|
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||||
None, # bias
|
hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
|
||||||
self.eps)
|
return self.weight * hidden_states.to(input_dtype)
|
||||||
return hidden_states
|
|
||||||
else:
|
|
||||||
input_dtype = hidden_states.dtype
|
|
||||||
hidden_states = hidden_states.to(torch.float32)
|
|
||||||
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
|
||||||
hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
|
|
||||||
return self.weight * hidden_states.to(input_dtype)
|
|
||||||
|
|
||||||
|
|
||||||
def chatglm2_model_forward(
|
def chatglm2_model_forward(
|
||||||
|
|
|
||||||
|
|
@ -75,26 +75,20 @@ def get_ipex_version():
|
||||||
|
|
||||||
def llama_rms_norm_forward(self, hidden_states):
|
def llama_rms_norm_forward(self, hidden_states):
|
||||||
if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
|
if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
|
||||||
if get_ipex_version() <= "2.0.110+xpu":
|
import linear_q4_0
|
||||||
import linear_q4_0
|
result = linear_q4_0.fused_rms_norm(hidden_states,
|
||||||
hidden_states = linear_q4_0.fused_rms_norm(hidden_states,
|
[self.weight.size(0)],
|
||||||
[self.weight.size(0)],
|
self.weight,
|
||||||
self.weight,
|
None,
|
||||||
None,
|
self.variance_epsilon)
|
||||||
self.variance_epsilon)
|
# if nelement == 0, means fused norm failed, go back to python implement.
|
||||||
else:
|
if result.nelement != 0:
|
||||||
hidden_states = torch.ops.torch_ipex.fast_rms_norm(hidden_states,
|
return result
|
||||||
[self.weight.size(0)],
|
input_dtype = hidden_states.dtype
|
||||||
self.weight,
|
hidden_states = hidden_states.to(torch.float32)
|
||||||
None,
|
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||||
self.variance_epsilon)
|
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
||||||
return hidden_states
|
return self.weight * hidden_states.to(input_dtype)
|
||||||
else:
|
|
||||||
input_dtype = hidden_states.dtype
|
|
||||||
hidden_states = hidden_states.to(torch.float32)
|
|
||||||
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
|
||||||
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
|
||||||
return self.weight * hidden_states.to(input_dtype)
|
|
||||||
|
|
||||||
|
|
||||||
def llama_attention_forward_4_31(
|
def llama_attention_forward_4_31(
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue