diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py
index 5d530405..37757fbd 100644
--- a/python/llm/src/bigdl/llm/transformers/convert.py
+++ b/python/llm/src/bigdl/llm/transformers/convert.py
@@ -311,6 +311,10 @@ def optimize(model):
                         module.InternLMAttention,
                         internlm_attention_forward
                         )
+        convert_forward(model,
+                        module.InternLMRMSNorm,
+                        llama_rms_norm_forward
+                        )
     elif model.config.model_type == "qwen":
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
diff --git a/python/llm/src/bigdl/llm/transformers/models/internlm.py b/python/llm/src/bigdl/llm/transformers/models/internlm.py
index 7f3b7f7d..5afd06b6 100644
--- a/python/llm/src/bigdl/llm/transformers/models/internlm.py
+++ b/python/llm/src/bigdl/llm/transformers/models/internlm.py
@@ -74,15 +74,20 @@ def internlm_attention_forward(
     kv_seq_len = key_states.shape[-2]
     if past_key_value is not None:
         kv_seq_len += past_key_value[0].shape[-2]
-    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-    query_states, key_states = apply_rotary_pos_emb(
-        query_states,
-        key_states,
-        cos,
-        sin,
-        position_ids,
-        "internlm"
-    )
+    if query_states.device.type == "xpu" and not (self.training and query_states.requires_grad):
+        query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
+                                                                      key_states,
+                                                                      position_ids,
+                                                                      "internlm")
+    else:
+        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        query_states, key_states = apply_rotary_pos_emb(
+            query_states,
+            key_states,
+            cos,
+            sin,
+            position_ids,
+            "internlm")
     # [bsz, nh, t, hd]

     if past_key_value is not None:
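
The first hunk routes InternLM's RMSNorm through the optimized llama_rms_norm_forward; the second hunk adds a device-based dispatch so that, for XPU tensors outside the training/autograd path, rotary position embeddings are applied with the fused apply_rotary_pos_emb_no_cache_xpu kernel instead of first materializing the cos/sin cache. For context, below are two minimal, self-contained PyTorch sketches of the operations being swapped; the function names, signatures, and tensor layouts are illustrative assumptions, not BigDL's actual API.

A reference RMSNorm forward (the kind of computation the first hunk is redirecting), assuming an input of shape [..., hidden_size] and a learned per-channel weight:

    import torch

    def rms_norm_reference(hidden_states: torch.Tensor,
                           weight: torch.Tensor,
                           eps: float = 1e-6) -> torch.Tensor:
        # Normalize by the root mean square over the hidden dimension (no mean
        # subtraction, unlike LayerNorm), computed in float32 for stability.
        input_dtype = hidden_states.dtype
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        normed = hidden_states.to(torch.float32) * torch.rsqrt(variance + eps)
        return weight * normed.to(input_dtype)

A reference for the non-fused rotary path kept in the else branch, assuming cos/sin tables of shape [seq_len, head_dim] and q/k of shape [bsz, num_heads, q_len, head_dim]; the fused XPU kernel computes the same rotation without building and indexing these tables:

    import torch

    def rotate_half(x: torch.Tensor) -> torch.Tensor:
        # Split the head dimension in half and rotate: (x1, x2) -> (-x2, x1).
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    def apply_rotary_reference(q, k, cos, sin, position_ids):
        # Gather the cos/sin rows for the current positions and broadcast them
        # over the batch and head dimensions: [bsz, 1, q_len, head_dim].
        cos = cos[position_ids].unsqueeze(1)
        sin = sin[position_ids].unsqueeze(1)
        q_embed = (q * cos) + (rotate_half(q) * sin)
        k_embed = (k * cos) + (rotate_half(k) * sin)
        return q_embed, k_embed

The guard in the diff, query_states.device.type == "xpu" and not (self.training and query_states.requires_grad), limits the fused kernel to XPU inference, which suggests it does not participate in autograd; everything else falls back to the original cos/sin-cache path.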