diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py index 794f910f..ec2548b0 100644 --- a/python/llm/src/bigdl/llm/transformers/convert.py +++ b/python/llm/src/bigdl/llm/transformers/convert.py @@ -275,6 +275,9 @@ def optimize(model): module.BaichuanAttention, baichuan_attention_forward_13b ) + convert_forward(model, + module.RMSNorm, + llama_rms_norm_forward) elif model.config.model_type == "baichuan": # baichuan1 @@ -296,6 +299,9 @@ def optimize(model): module.BaichuanAttention, baichuan_attention_forward_13b ) + convert_forward(model, + module.RMSNorm, + llama_rms_norm_forward) elif model.config.model_type == "gpt_neox": from bigdl.llm.transformers.models.gptneox import gptneox_attention_forward diff --git a/python/llm/src/bigdl/llm/transformers/models/baichuan.py b/python/llm/src/bigdl/llm/transformers/models/baichuan.py index 298654f2..2bd8f550 100644 --- a/python/llm/src/bigdl/llm/transformers/models/baichuan.py +++ b/python/llm/src/bigdl/llm/transformers/models/baichuan.py @@ -28,6 +28,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from bigdl.llm.utils.common import invalidInputError from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache from bigdl.llm.transformers.models.utils import rotate_half, apply_rotary_pos_emb +from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu KV_CACHE_ALLOC_BLOCK_LENGTH = 256 @@ -56,9 +57,15 @@ def baichuan_attention_forward_7b( kv_seq_len = key_states.shape[-2] if past_key_value is not None: kv_seq_len += past_key_value[0].shape[-2] - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, - cos, sin, position_ids, "baichuan") + if query_states.device.type == "xpu" and not (self.training and query_states.requires_grad): + query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states, + key_states, + position_ids, + "baichuan") + else: + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, + cos, sin, position_ids, "baichuan") # [bsz, nh, t, hd] # if past_key_value is not None: diff --git a/python/llm/src/bigdl/llm/transformers/models/baichuan2.py b/python/llm/src/bigdl/llm/transformers/models/baichuan2.py index 08d392e8..4576fb6b 100644 --- a/python/llm/src/bigdl/llm/transformers/models/baichuan2.py +++ b/python/llm/src/bigdl/llm/transformers/models/baichuan2.py @@ -28,6 +28,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from bigdl.llm.utils.common import invalidInputError from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache from bigdl.llm.transformers.models.utils import rotate_half, apply_rotary_pos_emb +from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu from transformers.utils import logging, ContextManagers logger = logging.get_logger(__name__) @@ -68,9 +69,15 @@ def baichuan_attention_forward_7b( kv_seq_len = key_states.shape[-2] if past_key_value is not None: kv_seq_len += past_key_value[0].shape[-2] - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, - cos, sin, position_ids, "baichuan") + if query_states.device.type == "xpu" and not (self.training and query_states.requires_grad): + query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states, + key_states, + position_ids, + "baichuan") + else: + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, + cos, sin, position_ids, "baichuan") # [bsz, nh, t, hd] # if past_key_value is not None: