handle empty fused norm result (#9688)
* handle empty fused norm result * remove fast_rms_norm * fix style
This commit is contained in:
parent
a5c481fedd
commit
320110d158
4 changed files with 51 additions and 69 deletions
|
|
@ -49,26 +49,20 @@ KV_CACHE_ALLOC_BLOCK_LENGTH = 256
|
|||
|
||||
def baichuan_13b_rms_norm_forward(self, hidden_states):
|
||||
if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
|
||||
if get_ipex_version() <= "2.0.110+xpu":
|
||||
import linear_q4_0
|
||||
hidden_states = linear_q4_0.fused_rms_norm(hidden_states,
|
||||
[self.weight.size(0)],
|
||||
self.weight,
|
||||
None,
|
||||
self.epsilon)
|
||||
else:
|
||||
hidden_states = torch.ops.torch_ipex.fast_rms_norm(hidden_states,
|
||||
[self.weight.size(0)],
|
||||
self.weight,
|
||||
None,
|
||||
self.epsilon)
|
||||
return hidden_states
|
||||
else:
|
||||
input_dtype = hidden_states.dtype
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.epsilon)
|
||||
return self.weight * hidden_states.to(input_dtype)
|
||||
import linear_q4_0
|
||||
result = linear_q4_0.fused_rms_norm(hidden_states,
|
||||
[self.weight.size(0)],
|
||||
self.weight,
|
||||
None,
|
||||
self.epsilon)
|
||||
# if nelement == 0, means fused norm failed, go back to python implement.
|
||||
if result.nelement != 0:
|
||||
return result
|
||||
input_dtype = hidden_states.dtype
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.epsilon)
|
||||
return self.weight * hidden_states.to(input_dtype)
|
||||
|
||||
|
||||
def baichuan_mlp_forward(
|
||||
|
|
|
|||
|
|
@ -65,14 +65,15 @@ def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training:
|
|||
def bloom_layer_norm_forward(self, hidden_states):
|
||||
if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
|
||||
import linear_q4_0
|
||||
hidden_states = linear_q4_0.fused_layer_norm(hidden_states,
|
||||
[self.weight.size(0)],
|
||||
self.weight,
|
||||
self.bias,
|
||||
self.eps)
|
||||
return hidden_states
|
||||
else:
|
||||
return F.layer_norm(hidden_states, self.normalized_shape, self.weight, self.bias, self.eps)
|
||||
result = linear_q4_0.fused_layer_norm(hidden_states,
|
||||
[self.weight.size(0)],
|
||||
self.weight,
|
||||
self.bias,
|
||||
self.eps)
|
||||
# if nelement == 0, means fused norm failed, go back to python implement.
|
||||
if result.nelement != 0:
|
||||
return result
|
||||
return F.layer_norm(hidden_states, self.normalized_shape, self.weight, self.bias, self.eps)
|
||||
|
||||
|
||||
def bloom_attention_forward(
|
||||
|
|
|
|||
|
|
@ -78,27 +78,20 @@ def apply_rotary_pos_emb_chatglm(x: torch.Tensor, rope_cache: torch.Tensor) -> t
|
|||
|
||||
def chatglm_rms_norm_forward(self, hidden_states):
|
||||
if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
|
||||
if get_ipex_version() <= "2.0.110+xpu":
|
||||
import linear_q4_0
|
||||
hidden_states = linear_q4_0.fused_rms_norm(hidden_states,
|
||||
[self.weight.size(0)],
|
||||
self.weight,
|
||||
None,
|
||||
self.eps)
|
||||
else:
|
||||
# for ipex >= 2.1
|
||||
hidden_states = torch.ops.torch_ipex.fast_rms_norm(hidden_states,
|
||||
[self.weight.size(0)],
|
||||
self.weight,
|
||||
None, # bias
|
||||
self.eps)
|
||||
return hidden_states
|
||||
else:
|
||||
input_dtype = hidden_states.dtype
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
|
||||
return self.weight * hidden_states.to(input_dtype)
|
||||
import linear_q4_0
|
||||
result = linear_q4_0.fused_rms_norm(hidden_states,
|
||||
[self.weight.size(0)],
|
||||
self.weight,
|
||||
None,
|
||||
self.eps)
|
||||
# if nelement == 0, means fused norm failed, go back to python implement.
|
||||
if result.nelement != 0:
|
||||
return result
|
||||
input_dtype = hidden_states.dtype
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
|
||||
return self.weight * hidden_states.to(input_dtype)
|
||||
|
||||
|
||||
def chatglm2_model_forward(
|
||||
|
|
|
|||
|
|
@ -75,26 +75,20 @@ def get_ipex_version():
|
|||
|
||||
def llama_rms_norm_forward(self, hidden_states):
|
||||
if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
|
||||
if get_ipex_version() <= "2.0.110+xpu":
|
||||
import linear_q4_0
|
||||
hidden_states = linear_q4_0.fused_rms_norm(hidden_states,
|
||||
[self.weight.size(0)],
|
||||
self.weight,
|
||||
None,
|
||||
self.variance_epsilon)
|
||||
else:
|
||||
hidden_states = torch.ops.torch_ipex.fast_rms_norm(hidden_states,
|
||||
[self.weight.size(0)],
|
||||
self.weight,
|
||||
None,
|
||||
self.variance_epsilon)
|
||||
return hidden_states
|
||||
else:
|
||||
input_dtype = hidden_states.dtype
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
||||
return self.weight * hidden_states.to(input_dtype)
|
||||
import linear_q4_0
|
||||
result = linear_q4_0.fused_rms_norm(hidden_states,
|
||||
[self.weight.size(0)],
|
||||
self.weight,
|
||||
None,
|
||||
self.variance_epsilon)
|
||||
# if nelement == 0, means fused norm failed, go back to python implement.
|
||||
if result.nelement != 0:
|
||||
return result
|
||||
input_dtype = hidden_states.dtype
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
||||
return self.weight * hidden_states.to(input_dtype)
|
||||
|
||||
|
||||
def llama_attention_forward_4_31(
|
||||
|
|
|
|||
Loading…
Reference in a new issue