handle empty fused norm result (#9688)

* handle empty fused norm result

* remove fast_rms_norm

* fix style
Xin Qiu authored 2023-12-18 09:56:11 +08:00, committed by GitHub
parent a5c481fedd
commit 320110d158
4 changed files with 51 additions and 69 deletions
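All four files below apply the same guard. As a rough sketch of the pattern (illustrative names only; fused_or_fallback is not part of the codebase):

    def fused_or_fallback(fused_op, fallback_op, *args):
        # The fused kernel signals failure by returning an empty tensor;
        # the caller then reruns the pure-Python implementation instead.
        result = fused_op(*args)
        if result.nelement() != 0:  # nelement() == 0 means the fused op gave up
            return result
        return fallback_op(*args)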


@@ -49,21 +49,15 @@ KV_CACHE_ALLOC_BLOCK_LENGTH = 256
 def baichuan_13b_rms_norm_forward(self, hidden_states):
     if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
-        if get_ipex_version() <= "2.0.110+xpu":
         import linear_q4_0
-            hidden_states = linear_q4_0.fused_rms_norm(hidden_states,
+        result = linear_q4_0.fused_rms_norm(hidden_states,
                                             [self.weight.size(0)],
                                             self.weight,
                                             None,
                                             self.epsilon)
-        else:
-            hidden_states = torch.ops.torch_ipex.fast_rms_norm(hidden_states,
-                                                               [self.weight.size(0)],
-                                                               self.weight,
-                                                               None,
-                                                               self.epsilon)
-        return hidden_states
-    else:
+        # if nelement() == 0, the fused norm failed; fall back to the Python implementation
+        if result.nelement() != 0:
+            return result
     input_dtype = hidden_states.dtype
     hidden_states = hidden_states.to(torch.float32)
     variance = hidden_states.pow(2).mean(-1, keepdim=True)
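The trailing context above cuts off midway through the Python fallback; the code it begins is the standard Hugging Face-style RMSNorm, sketched here as a free function for reference (rms_norm_fallback is an illustrative name):

    import torch

    def rms_norm_fallback(hidden_states, weight, epsilon):
        # Pure-Python RMSNorm: upcast, normalize by RMS, restore dtype.
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + epsilon)
        return weight * hidden_states.to(input_dtype)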


@@ -65,13 +65,14 @@ def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor:
 def bloom_layer_norm_forward(self, hidden_states):
     if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
         import linear_q4_0
-        hidden_states = linear_q4_0.fused_layer_norm(hidden_states,
+        result = linear_q4_0.fused_layer_norm(hidden_states,
                                               [self.weight.size(0)],
                                               self.weight,
                                               self.bias,
                                               self.eps)
-        return hidden_states
-    else:
+        # if nelement() == 0, the fused norm failed; fall back to the Python implementation
+        if result.nelement() != 0:
+            return result
     return F.layer_norm(hidden_states, self.normalized_shape, self.weight, self.bias, self.eps)
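The fallback here is a single F.layer_norm call; a quick sanity check (illustrative shapes) that it reproduces the stock eager LayerNorm:

    import torch
    import torch.nn.functional as F

    ln = torch.nn.LayerNorm(64, eps=1e-5)
    x = torch.randn(2, 8, 64)
    # F.layer_norm with the module's own parameters is exactly ln(x)
    out = F.layer_norm(x, ln.normalized_shape, ln.weight, ln.bias, ln.eps)
    assert torch.allclose(out, ln(x))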


@@ -78,22 +78,15 @@ def apply_rotary_pos_emb_chatglm(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
 def chatglm_rms_norm_forward(self, hidden_states):
     if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
-        if get_ipex_version() <= "2.0.110+xpu":
         import linear_q4_0
-            hidden_states = linear_q4_0.fused_rms_norm(hidden_states,
+        result = linear_q4_0.fused_rms_norm(hidden_states,
                                             [self.weight.size(0)],
                                             self.weight,
                                             None,
                                             self.eps)
-        else:
-            # for ipex >= 2.1
-            hidden_states = torch.ops.torch_ipex.fast_rms_norm(hidden_states,
-                                                               [self.weight.size(0)],
-                                                               self.weight,
-                                                               None,  # bias
-                                                               self.eps)
-        return hidden_states
-    else:
+        # if nelement() == 0, the fused norm failed; fall back to the Python implementation
+        if result.nelement() != 0:
+            return result
     input_dtype = hidden_states.dtype
     hidden_states = hidden_states.to(torch.float32)
     variance = hidden_states.pow(2).mean(-1, keepdim=True)
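The float32 upcast in the trailing context (shared by all three RMS-norm fallbacks) matters because squaring in half precision overflows easily; a quick illustration:

    import torch

    x = torch.full((8,), 300.0, dtype=torch.float16)
    print(x.pow(2).mean())                    # inf: 300**2 = 90000 exceeds the fp16 max (65504)
    print(x.to(torch.float32).pow(2).mean())  # tensor(90000.)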


@@ -75,21 +75,15 @@ def get_ipex_version():
 def llama_rms_norm_forward(self, hidden_states):
     if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
-        if get_ipex_version() <= "2.0.110+xpu":
         import linear_q4_0
-            hidden_states = linear_q4_0.fused_rms_norm(hidden_states,
+        result = linear_q4_0.fused_rms_norm(hidden_states,
                                             [self.weight.size(0)],
                                             self.weight,
                                             None,
                                             self.variance_epsilon)
-        else:
-            hidden_states = torch.ops.torch_ipex.fast_rms_norm(hidden_states,
-                                                               [self.weight.size(0)],
-                                                               self.weight,
-                                                               None,
-                                                               self.variance_epsilon)
-        return hidden_states
-    else:
+        # if nelement() == 0, the fused norm failed; fall back to the Python implementation
+        if result.nelement() != 0:
+            return result
     input_dtype = hidden_states.dtype
     hidden_states = hidden_states.to(torch.float32)
     variance = hidden_states.pow(2).mean(-1, keepdim=True)
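A minimal smoke test of the new control flow, assuming the BigDL-LLM import path of this period: on a CPU tensor the device check fails, the fused branch is skipped, and the patched forward should match stock Hugging Face output.

    import torch
    from transformers.models.llama.modeling_llama import LlamaRMSNorm
    # import path is an assumption based on the repository layout at this commit
    from bigdl.llm.transformers.models.llama import llama_rms_norm_forward

    norm = LlamaRMSNorm(hidden_size=64, eps=1e-6)
    x = torch.randn(1, 5, 64)
    # CPU tensor -> the xpu branch is skipped, only the Python path runs
    assert torch.allclose(llama_rms_norm_forward(norm, x), norm(x))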