diff --git a/python/llm/src/bigdl/llm/transformers/models/llama.py b/python/llm/src/bigdl/llm/transformers/models/llama.py
index 3c887892..7e8a77be 100644
--- a/python/llm/src/bigdl/llm/transformers/models/llama.py
+++ b/python/llm/src/bigdl/llm/transformers/models/llama.py
@@ -74,22 +74,26 @@ def get_ipex_version():
 
 
 def llama_rms_norm_forward(self, hidden_states):
-    optimized_rms_norm = False
     if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
         if get_ipex_version() <= "2.0.110+xpu":
             if self.variance_epsilon == 1e-6:
                 hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states,
                                                                  [self.weight.size(0)],
                                                                  self.weight)
-                optimized_rms_norm = True
+            else:
+                import linear_q4_0
+                hidden_states = linear_q4_0.fused_rms_norm(hidden_states,
+                                                           [self.weight.size(0)],
+                                                           self.weight,
+                                                           None,
+                                                           self.variance_epsilon)
         else:
             hidden_states = torch.ops.torch_ipex.fast_rms_norm(hidden_states,
                                                                [self.weight.size(0)],
                                                                self.weight,
                                                                None,
                                                                self.variance_epsilon)
-            optimized_rms_norm = True
-    if not optimized_rms_norm:
+    else:
         input_dtype = hidden_states.dtype
         hidden_states = hidden_states.to(torch.float32)
         variance = hidden_states.pow(2).mean(-1, keepdim=True)
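
Note on the patch: the `optimized_rms_norm` flag is replaced by explicit `else` branches, so the XPU path on IPEX <= 2.0.110+xpu with a non-default epsilon now dispatches to `linear_q4_0.fused_rms_norm` instead of falling through to the eager path. The hunk is cut off inside that eager fallback; for reference, here is a minimal sketch of what a standard eager RMSNorm computes (mirroring Hugging Face's `LlamaRMSNorm.forward`). This sketch is not taken from the patch, and the function name is illustrative only.

```python
import torch

def eager_rms_norm(hidden_states: torch.Tensor,
                   weight: torch.Tensor,
                   variance_epsilon: float = 1e-6) -> torch.Tensor:
    # Standard (unfused) RMSNorm: accumulate in float32 for stability,
    # normalize by the root-mean-square, then cast back to the input dtype.
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon)
    return weight * hidden_states.to(input_dtype)
```

With that fallback in mind, the new control flow reads as: the IPEX fused kernel when the epsilon matches its default, the `linear_q4_0` fused kernel otherwise on older IPEX, `fast_rms_norm` on newer IPEX, and the eager float32 computation only on non-XPU devices or during training.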