diff --git a/python/llm/src/bigdl/llm/transformers/models/baichuan2.py b/python/llm/src/bigdl/llm/transformers/models/baichuan2.py
index 7dfbca8e..f5b7dcfb 100644
--- a/python/llm/src/bigdl/llm/transformers/models/baichuan2.py
+++ b/python/llm/src/bigdl/llm/transformers/models/baichuan2.py
@@ -49,26 +49,20 @@ KV_CACHE_ALLOC_BLOCK_LENGTH = 256
 
 def baichuan_13b_rms_norm_forward(self, hidden_states):
     if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
-        if get_ipex_version() <= "2.0.110+xpu":
-            import linear_q4_0
-            hidden_states = linear_q4_0.fused_rms_norm(hidden_states,
-                                                       [self.weight.size(0)],
-                                                       self.weight,
-                                                       None,
-                                                       self.epsilon)
-        else:
-            hidden_states = torch.ops.torch_ipex.fast_rms_norm(hidden_states,
-                                                               [self.weight.size(0)],
-                                                               self.weight,
-                                                               None,
-                                                               self.epsilon)
-        return hidden_states
-    else:
-        input_dtype = hidden_states.dtype
-        hidden_states = hidden_states.to(torch.float32)
-        variance = hidden_states.pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance + self.epsilon)
-        return self.weight * hidden_states.to(input_dtype)
+        import linear_q4_0
+        result = linear_q4_0.fused_rms_norm(hidden_states,
+                                            [self.weight.size(0)],
+                                            self.weight,
+                                            None,
+                                            self.epsilon)
+        # if nelement == 0, the fused norm failed; fall back to the python implementation.
+        if result.nelement() != 0:
+            return result
+    input_dtype = hidden_states.dtype
+    hidden_states = hidden_states.to(torch.float32)
+    variance = hidden_states.pow(2).mean(-1, keepdim=True)
+    hidden_states = hidden_states * torch.rsqrt(variance + self.epsilon)
+    return self.weight * hidden_states.to(input_dtype)
 
 
 def baichuan_mlp_forward(
diff --git a/python/llm/src/bigdl/llm/transformers/models/bloom.py b/python/llm/src/bigdl/llm/transformers/models/bloom.py
index 7daaba78..88c02fc2 100644
--- a/python/llm/src/bigdl/llm/transformers/models/bloom.py
+++ b/python/llm/src/bigdl/llm/transformers/models/bloom.py
@@ -65,14 +65,15 @@ def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training:
 def bloom_layer_norm_forward(self, hidden_states):
     if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
         import linear_q4_0
-        hidden_states = linear_q4_0.fused_layer_norm(hidden_states,
-                                                     [self.weight.size(0)],
-                                                     self.weight,
-                                                     self.bias,
-                                                     self.eps)
-        return hidden_states
-    else:
-        return F.layer_norm(hidden_states, self.normalized_shape, self.weight, self.bias, self.eps)
+        result = linear_q4_0.fused_layer_norm(hidden_states,
+                                              [self.weight.size(0)],
+                                              self.weight,
+                                              self.bias,
+                                              self.eps)
+        # if nelement == 0, the fused norm failed; fall back to the python implementation.
+        if result.nelement() != 0:
+            return result
+    return F.layer_norm(hidden_states, self.normalized_shape, self.weight, self.bias, self.eps)
 
 
 def bloom_attention_forward(
diff --git a/python/llm/src/bigdl/llm/transformers/models/chatglm2.py b/python/llm/src/bigdl/llm/transformers/models/chatglm2.py
index e99cea5a..8276b967 100644
--- a/python/llm/src/bigdl/llm/transformers/models/chatglm2.py
+++ b/python/llm/src/bigdl/llm/transformers/models/chatglm2.py
@@ -78,27 +78,20 @@ def apply_rotary_pos_emb_chatglm(x: torch.Tensor, rope_cache: torch.Tensor) -> t
 
 def chatglm_rms_norm_forward(self, hidden_states):
     if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
-        if get_ipex_version() <= "2.0.110+xpu":
-            import linear_q4_0
-            hidden_states = linear_q4_0.fused_rms_norm(hidden_states,
-                                                       [self.weight.size(0)],
-                                                       self.weight,
-                                                       None,
-                                                       self.eps)
-        else:
-            # for ipex >= 2.1
-            hidden_states = torch.ops.torch_ipex.fast_rms_norm(hidden_states,
-                                                               [self.weight.size(0)],
-                                                               self.weight,
-                                                               None,  # bias
-                                                               self.eps)
-        return hidden_states
-    else:
-        input_dtype = hidden_states.dtype
-        hidden_states = hidden_states.to(torch.float32)
-        variance = hidden_states.pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
-        return self.weight * hidden_states.to(input_dtype)
+        import linear_q4_0
+        result = linear_q4_0.fused_rms_norm(hidden_states,
+                                            [self.weight.size(0)],
+                                            self.weight,
+                                            None,
+                                            self.eps)
+        # if nelement == 0, the fused norm failed; fall back to the python implementation.
+        if result.nelement() != 0:
+            return result
+    input_dtype = hidden_states.dtype
+    hidden_states = hidden_states.to(torch.float32)
+    variance = hidden_states.pow(2).mean(-1, keepdim=True)
+    hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
+    return self.weight * hidden_states.to(input_dtype)
 
 
 def chatglm2_model_forward(
diff --git a/python/llm/src/bigdl/llm/transformers/models/llama.py b/python/llm/src/bigdl/llm/transformers/models/llama.py
index 63d0d851..a951169b 100644
--- a/python/llm/src/bigdl/llm/transformers/models/llama.py
+++ b/python/llm/src/bigdl/llm/transformers/models/llama.py
@@ -75,26 +75,20 @@ def get_ipex_version():
 
 def llama_rms_norm_forward(self, hidden_states):
     if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
-        if get_ipex_version() <= "2.0.110+xpu":
-            import linear_q4_0
-            hidden_states = linear_q4_0.fused_rms_norm(hidden_states,
-                                                       [self.weight.size(0)],
-                                                       self.weight,
-                                                       None,
-                                                       self.variance_epsilon)
-        else:
-            hidden_states = torch.ops.torch_ipex.fast_rms_norm(hidden_states,
-                                                               [self.weight.size(0)],
-                                                               self.weight,
-                                                               None,
-                                                               self.variance_epsilon)
-        return hidden_states
-    else:
-        input_dtype = hidden_states.dtype
-        hidden_states = hidden_states.to(torch.float32)
-        variance = hidden_states.pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
-        return self.weight * hidden_states.to(input_dtype)
+        import linear_q4_0
+        result = linear_q4_0.fused_rms_norm(hidden_states,
+                                            [self.weight.size(0)],
+                                            self.weight,
+                                            None,
+                                            self.variance_epsilon)
+        # if nelement == 0, the fused norm failed; fall back to the python implementation.
+        if result.nelement() != 0:
+            return result
+    input_dtype = hidden_states.dtype
+    hidden_states = hidden_states.to(torch.float32)
+    variance = hidden_states.pow(2).mean(-1, keepdim=True)
+    hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+    return self.weight * hidden_states.to(input_dtype)
 
 
 def llama_attention_forward_4_31(
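Every hunk applies the same pattern: call the native `linear_q4_0` norm kernel first, and treat an empty result tensor (`nelement() == 0`) as a signal that the fused path declined the input, in which case the plain PyTorch implementation runs instead. The sketch below is a minimal, stand-alone illustration of that pattern, not code from this patch: `rms_norm_reference` and `fused_rms_norm_or_fallback` are hypothetical names, and the only assumption made about `linear_q4_0.fused_rms_norm` is the one the diff itself relies on, namely that it returns an empty tensor when the fused kernel cannot handle the call.

```python
import torch


def rms_norm_reference(hidden_states: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # Plain-PyTorch RMSNorm, identical to the fallback path added in each hunk.
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + eps)
    return weight * hidden_states.to(input_dtype)


def fused_rms_norm_or_fallback(hidden_states: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # Hypothetical wrapper illustrating the fallback pattern: try the native kernel
    # on XPU and treat an empty result (nelement() == 0) as "fused norm failed".
    if hidden_states.device.type == "xpu" and not hidden_states.requires_grad:
        import linear_q4_0  # BigDL native extension, only available in xpu builds
        result = linear_q4_0.fused_rms_norm(hidden_states, [weight.size(0)], weight, None, eps)
        if result.nelement() != 0:
            return result
    return rms_norm_reference(hidden_states, weight, eps)
```

On a non-XPU device this reduces to the reference implementation, which is also how the patched forwards behave whenever the fused kernel reports failure.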