From 36dd4afd61dfa4e2c8333e2564dc88265087e36a Mon Sep 17 00:00:00 2001
From: Yang Wang
Date: Sat, 7 Oct 2023 04:27:37 +0800
Subject: [PATCH] Fix llama when rope scaling is not None (#9086)

* Fix llama when rope scaling is not None

* fix style

* fix style
---
 python/llm/src/bigdl/llm/transformers/models/llama.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/python/llm/src/bigdl/llm/transformers/models/llama.py b/python/llm/src/bigdl/llm/transformers/models/llama.py
index 791f5acd..7953670a 100644
--- a/python/llm/src/bigdl/llm/transformers/models/llama.py
+++ b/python/llm/src/bigdl/llm/transformers/models/llama.py
@@ -118,7 +118,11 @@ def llama_attention_forward_4_31(
     if past_key_value is not None:
         kv_seq_len += past_key_value[0].shape[-2]
 
-    if query_states.device.type == "xpu" and not (self.training and query_states.requires_grad):
+    use_fuse_rope = query_states.device.type == "xpu"
+    use_fuse_rope = use_fuse_rope and not (self.training and query_states.requires_grad)
+    use_fuse_rope = use_fuse_rope and self.config.rope_scaling is None
+
+    if use_fuse_rope:
         query_states, key_states = apply_rotary_pos_emb_no_cache_xpu(query_states,
                                                                      key_states,
                                                                      position_ids,
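
Note (not part of the patch): a minimal sketch of the gating logic the change introduces, assuming a hypothetical helper should_use_fused_rope and plain Python values in place of the tensor and config attributes used in llama.py (query_states.device.type, self.training, query_states.requires_grad, self.config.rope_scaling). The added condition makes the model fall back to the regular rotary-embedding path whenever rope_scaling is configured, presumably because the fused XPU kernel does not account for scaled RoPE.

    # Illustrative sketch only; names and plain-bool arguments are assumptions,
    # not part of the BigDL-LLM API.
    def should_use_fused_rope(device_type, training, requires_grad, rope_scaling):
        """Return True only when the fused XPU rotary-embedding path is safe to take."""
        use_fuse_rope = device_type == "xpu"                                 # fused kernel is XPU-only
        use_fuse_rope = use_fuse_rope and not (training and requires_grad)   # inference-style forward only
        use_fuse_rope = use_fuse_rope and rope_scaling is None               # new check added by this patch
        return use_fuse_rope

    # A model configured with rope scaling now takes the non-fused fallback.
    assert should_use_fused_rope("xpu", False, False, None) is True
    assert should_use_fused_rope("xpu", False, False, {"type": "linear", "factor": 2.0}) is False
    assert should_use_fused_rope("cpu", False, False, None) is False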