From e7aa67e141600eb02e817f5794bf1728212a1c54 Mon Sep 17 00:00:00 2001
From: SONG Ge <38711238+sgwhat@users.noreply.github.com>
Date: Fri, 13 Oct 2023 14:18:28 +0800
Subject: [PATCH] [LLM] Add rope optimization for internlm (#9159)

* add rope and norm optimization for internlm and gptneox

* revert gptneox back and split with pr#9155

* add norm_forward

* style fix

* update

* update
---
 .../llm/src/bigdl/llm/transformers/convert.py |  4 ++++
 .../bigdl/llm/transformers/models/internlm.py | 23 +++++++++++--------
 2 files changed, 18 insertions(+), 9 deletions(-)

diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py
index 5d530405..37757fbd 100644
--- a/python/llm/src/bigdl/llm/transformers/convert.py
+++ b/python/llm/src/bigdl/llm/transformers/convert.py
@@ -311,6 +311,10 @@ def optimize(model):
                         module.InternLMAttention,
                         internlm_attention_forward
                         )
+        convert_forward(model,
+                        module.InternLMRMSNorm,
+                        llama_rms_norm_forward
+                        )
     elif model.config.model_type == "qwen":
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
diff --git a/python/llm/src/bigdl/llm/transformers/models/internlm.py b/python/llm/src/bigdl/llm/transformers/models/internlm.py
index 7f3b7f7d..5afd06b6 100644
--- a/python/llm/src/bigdl/llm/transformers/models/internlm.py
+++ b/python/llm/src/bigdl/llm/transformers/models/internlm.py
@@ -74,15 +74,20 @@ def internlm_attention_forward(
     kv_seq_len = key_states.shape[-2]
     if past_key_value is not None:
         kv_seq_len += past_key_value[0].shape[-2]
-    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-    query_states, key_states = apply_rotary_pos_emb(
-        query_states,
-        key_states,
-        cos,
-        sin,
-        position_ids,
-        "internlm"
-    )
+    if query_states.device.type == "xpu" and not (self.training and query_states.requires_grad):
+        query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
+                                                                     key_states,
+                                                                     position_ids,
+                                                                     "internlm")
+    else:
+        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        query_states, key_states = apply_rotary_pos_emb(
+            query_states,
+            key_states,
+            cos,
+            sin,
+            position_ids,
+            "internlm")
 
     # [bsz, nh, t, hd]
     if past_key_value is not None:
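
Note on the attention change: on XPU at inference time the patch routes through apply_rotary_pos_emb_no_cache_xpu, which applies the rotary embedding in a fused kernel instead of first materializing cos/sin from the rotary_emb cache. As a minimal reference sketch of the eager-mode math the fused path replaces (standard LLaMA/InternLM-style rotary embedding; function names here are illustrative, not the BigDL helpers themselves):

import torch

def rotate_half(x):
    # Split the head dimension in half and rotate: (x1, x2) -> (-x2, x1)
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb_reference(q, k, cos, sin, position_ids):
    # cos/sin come from rotary_emb with shape [1, 1, seq_len, head_dim];
    # select the rows for the current positions and broadcast over heads.
    cos = cos.squeeze(1).squeeze(0)[position_ids].unsqueeze(1)  # [bsz, 1, q_len, head_dim]
    sin = sin.squeeze(1).squeeze(0)[position_ids].unsqueeze(1)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

The device/training guard keeps the original slow path for CPU and for any case where gradients are needed, so training behavior is unchanged.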
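
Usage sketch: both the fused rope path and the llama_rms_norm_forward replacement registered in convert.py are applied automatically when an InternLM checkpoint is loaded through bigdl.llm.transformers with optimization enabled (the default). The model id, the ipex import, and the generation settings below are illustrative assumptions, not part of this patch:

import torch
import intel_extension_for_pytorch as ipex  # typically needed for the XPU backend
from transformers import AutoTokenizer
from bigdl.llm.transformers import AutoModelForCausalLM

# optimize_model defaults to True, which routes through optimize() in convert.py
# and swaps in internlm_attention_forward and llama_rms_norm_forward.
model = AutoModelForCausalLM.from_pretrained("internlm/internlm-chat-7b",
                                             load_in_4bit=True,
                                             trust_remote_code=True).to("xpu")
tokenizer = AutoTokenizer.from_pretrained("internlm/internlm-chat-7b",
                                          trust_remote_code=True)

with torch.inference_mode():
    inputs = tokenizer("AI is", return_tensors="pt").to("xpu")
    output = model.generate(**inputs, max_new_tokens=32)
    print(tokenizer.decode(output[0], skip_special_tokens=True))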