diff --git a/python/llm/src/bigdl/llm/transformers/models/baichuan2.py b/python/llm/src/bigdl/llm/transformers/models/baichuan2.py
index cdc49f5d..7827c1fb 100644
--- a/python/llm/src/bigdl/llm/transformers/models/baichuan2.py
+++ b/python/llm/src/bigdl/llm/transformers/models/baichuan2.py
@@ -30,6 +30,7 @@ from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache,
 from bigdl.llm.transformers.models.utils import rotate_half, apply_rotary_pos_emb
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
 from transformers.utils import logging, ContextManagers
+from bigdl.llm.transformers.models.llama import get_ipex_version

 logger = logging.get_logger(__name__)

 try:
@@ -47,8 +48,16 @@ KV_CACHE_ALLOC_BLOCK_LENGTH = 256

 def baichuan_13b_rms_norm_forward(self, hidden_states):
     if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
-        hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states,
-                                                         [self.weight.size(0)], self.weight)
+        if get_ipex_version() <= "2.0.110+xpu":
+            hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states,
+                                                             [self.weight.size(0)],
+                                                             self.weight)
+        else:
+            hidden_states, _ = torch.ops.torch_ipex.fast_rms_norm(hidden_states,
+                                                                  [self.weight.size(0)],
+                                                                  self.weight,
+                                                                  None,
+                                                                  self.epsilon)
     else:
         input_dtype = hidden_states.dtype
         hidden_states = hidden_states.to(torch.float32)
diff --git a/python/llm/src/bigdl/llm/transformers/models/chatglm2.py b/python/llm/src/bigdl/llm/transformers/models/chatglm2.py
index 0373f2aa..9e04147c 100644
--- a/python/llm/src/bigdl/llm/transformers/models/chatglm2.py
+++ b/python/llm/src/bigdl/llm/transformers/models/chatglm2.py
@@ -22,6 +22,7 @@ from typing import Optional, Tuple, List
 import torch.nn.functional as F
 from transformers.modeling_outputs import BaseModelOutputWithPast
 from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
+from bigdl.llm.transformers.models.llama import get_ipex_version


 KV_CACHE_ALLOC_BLOCK_LENGTH = 256
@@ -77,8 +78,16 @@ def apply_rotary_pos_emb_chatglm(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:

 def chatglm_rms_norm_forward(self, hidden_states):
     if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
-        hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states,
-                                                         [self.weight.size(0)], self.weight)
+        if get_ipex_version() <= "2.0.110+xpu":
+            hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states,
+                                                             [self.weight.size(0)], self.weight)
+        else:
+            # for ipex >= 2.1
+            hidden_states = torch.ops.torch_ipex.fast_rms_norm(hidden_states,
+                                                               [self.weight.size(0)],
+                                                               self.weight,
+                                                               None,  # bias
+                                                               self.eps)
     else:
         input_dtype = hidden_states.dtype
         hidden_states = hidden_states.to(torch.float32)
diff --git a/python/llm/src/bigdl/llm/transformers/models/llama.py b/python/llm/src/bigdl/llm/transformers/models/llama.py
index 798facec..64a6dd33 100644
--- a/python/llm/src/bigdl/llm/transformers/models/llama.py
+++ b/python/llm/src/bigdl/llm/transformers/models/llama.py
@@ -75,13 +75,15 @@ def get_ipex_version():

 def llama_rms_norm_forward(self, hidden_states):
     if hidden_states.device.type == "xpu" and not (self.training and hidden_states.requires_grad):
-        if get_ipex_version() == "2.0.110+xpu":
+        if get_ipex_version() <= "2.0.110+xpu":
             hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states,
                                                              [self.weight.size(0)], self.weight)
         else:
-            hidden_states, _ = torch.ops.torch_ipex.rms_norm(hidden_states,
-                                                             [self.weight.size(0)], self.weight,
-                                                             self.variance_epsilon)
+            hidden_states, _ = torch.ops.torch_ipex.fast_rms_norm(hidden_states,
+                                                                  [self.weight.size(0)],
+                                                                  self.weight,
+                                                                  None,
+                                                                  self.variance_epsilon)
     else:
         input_dtype = hidden_states.dtype
         hidden_states = hidden_states.to(torch.float32)
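
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the patch): a hypothetical standalone helper
# mirroring the version-gated RMSNorm dispatch that the three hunks above add.
# Assumes intel_extension_for_pytorch (IPEX) with XPU support is installed;
# get_ipex_version() below is a simplified stand-in for the cached helper in
# llama.py. The return handling of fast_rms_norm follows the chatglm2 hunk
# (single tensor); the baichuan2/llama hunks unpack a tuple instead.
import torch


def get_ipex_version():
    import intel_extension_for_pytorch as ipex
    return ipex.__version__


def rms_norm_forward(hidden_states, weight, eps):
    if hidden_states.device.type == "xpu" and not hidden_states.requires_grad:
        if get_ipex_version() <= "2.0.110+xpu":
            # older IPEX exposes rms_norm(input, normalized_shape, weight)
            hidden_states, _ = torch.ops.torch_ipex.rms_norm(
                hidden_states, [weight.size(0)], weight)
        else:
            # IPEX >= 2.1 exposes fast_rms_norm, which also takes bias and eps
            hidden_states = torch.ops.torch_ipex.fast_rms_norm(
                hidden_states, [weight.size(0)], weight, None, eps)
        return hidden_states
    # standard eager RMSNorm fallback (the diff context above shows only its
    # first two lines)
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + eps)
    return weight * hidden_states.to(input_dtype)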