diff --git a/python/llm/src/bigdl/llm/transformers/models/baichuan2.py b/python/llm/src/bigdl/llm/transformers/models/baichuan2.py
index cf3a0d2c..8b42cb3f 100644
--- a/python/llm/src/bigdl/llm/transformers/models/baichuan2.py
+++ b/python/llm/src/bigdl/llm/transformers/models/baichuan2.py
@@ -47,18 +47,22 @@ KV_CACHE_ALLOC_BLOCK_LENGTH = 256
 
 
 def baichuan_13b_rms_norm_forward(self, hidden_states):
+    optimized_rms_norm = False
     if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
         if get_ipex_version() <= "2.0.110+xpu":
-            hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states,
-                                                             [self.weight.size(0)],
-                                                             self.weight)
+            if self.epsilon == 1e-6:
+                hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states,
+                                                                 [self.weight.size(0)],
+                                                                 self.weight)
+                optimized_rms_norm = True
         else:
             hidden_states = torch.ops.torch_ipex.fast_rms_norm(hidden_states,
                                                                [self.weight.size(0)],
                                                                self.weight,
                                                                None,
                                                                self.epsilon)
-    else:
+            optimized_rms_norm = True
+    if not optimized_rms_norm:
         input_dtype = hidden_states.dtype
         hidden_states = hidden_states.to(torch.float32)
         variance = hidden_states.pow(2).mean(-1, keepdim=True)
diff --git a/python/llm/src/bigdl/llm/transformers/models/chatglm2.py b/python/llm/src/bigdl/llm/transformers/models/chatglm2.py
index 9e04147c..41df1038 100644
--- a/python/llm/src/bigdl/llm/transformers/models/chatglm2.py
+++ b/python/llm/src/bigdl/llm/transformers/models/chatglm2.py
@@ -77,10 +77,14 @@ def apply_rotary_pos_emb_chatglm(x: torch.Tensor, rope_cache: torch.Tensor) -> t
 
 
 def chatglm_rms_norm_forward(self, hidden_states):
+    optimized_rms_norm = False
     if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
         if get_ipex_version() <= "2.0.110+xpu":
-            hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states,
-                                                             [self.weight.size(0)], self.weight)
+            if self.eps == 1e-6:
+                hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states,
+                                                                 [self.weight.size(0)],
+                                                                 self.weight)
+                optimized_rms_norm = True
         else:
             # for ipex >= 2.1
             hidden_states = torch.ops.torch_ipex.fast_rms_norm(hidden_states,
@@ -88,7 +92,8 @@ def chatglm_rms_norm_forward(self, hidden_states):
                                                                self.weight,
                                                                None,  # bias
                                                                self.eps)
-    else:
+            optimized_rms_norm = True
+    if not optimized_rms_norm:
         input_dtype = hidden_states.dtype
         hidden_states = hidden_states.to(torch.float32)
         variance = hidden_states.pow(2).mean(-1, keepdim=True)
diff --git a/python/llm/src/bigdl/llm/transformers/models/llama.py b/python/llm/src/bigdl/llm/transformers/models/llama.py
index 1189c21e..3c887892 100644
--- a/python/llm/src/bigdl/llm/transformers/models/llama.py
+++ b/python/llm/src/bigdl/llm/transformers/models/llama.py
@@ -74,17 +74,22 @@ def get_ipex_version():
 
 
 def llama_rms_norm_forward(self, hidden_states):
+    optimized_rms_norm = False
     if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
         if get_ipex_version() <= "2.0.110+xpu":
-            hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states,
-                                                             [self.weight.size(0)], self.weight)
+            if self.variance_epsilon == 1e-6:
+                hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states,
+                                                                 [self.weight.size(0)],
+                                                                 self.weight)
+                optimized_rms_norm = True
         else:
             hidden_states = torch.ops.torch_ipex.fast_rms_norm(hidden_states,
                                                                [self.weight.size(0)],
                                                                self.weight,
                                                                None,
                                                                self.variance_epsilon)
-    else:
+            optimized_rms_norm = True
+    if not optimized_rms_norm:
         input_dtype = hidden_states.dtype
         hidden_states = hidden_states.to(torch.float32)
         variance = hidden_states.pow(2).mean(-1, keepdim=True)
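
For reference, a minimal standalone sketch of the pattern all three hunks introduce: run the fused torch_ipex.rms_norm kernel, whose call in the diff takes no epsilon argument, only when the module's eps is 1e-6, and otherwise fall through to the native fp32 path via the optimized_rms_norm flag (instead of the old else branch, which skipped the fallback whenever the device was xpu). The class name SketchRMSNorm is made up for illustration, the IPEX version branch is collapsed into the single epsilon test, and the fallback tail (rsqrt plus weight scaling) follows the standard Transformers-style RMSNorm that the hunks truncate at the variance line.

import torch
import torch.nn as nn


class SketchRMSNorm(nn.Module):
    # Illustrative sketch, not part of the patch.
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        optimized_rms_norm = False
        # Fast path: only on XPU, outside training, and only when eps matches
        # the 1e-6 value the fused kernel is assumed to use (the op is called
        # without an eps argument in the patch); otherwise fall through so
        # models with a different eps stay numerically correct.
        if hidden_states.device.type == "xpu" and not (
                self.training and hidden_states.requires_grad):
            if self.variance_epsilon == 1e-6:
                hidden_states, _ = torch.ops.torch_ipex.rms_norm(
                    hidden_states, [self.weight.size(0)], self.weight)
                optimized_rms_norm = True
        if not optimized_rms_norm:
            # Reference RMSNorm: compute in fp32, then cast back to input dtype.
            input_dtype = hidden_states.dtype
            hidden_states = hidden_states.to(torch.float32)
            variance = hidden_states.pow(2).mean(-1, keepdim=True)
            hidden_states = hidden_states * torch.rsqrt(
                variance + self.variance_epsilon)
            hidden_states = self.weight * hidden_states.to(input_dtype)
        return hidden_states

The flag-based structure is the point of the fix: with the original else, an XPU tensor whose module used a non-default epsilon still hit the fused kernel with the wrong epsilon, whereas here every path that does not apply the fused op lands in the reference implementation.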