From 56ad781f2fa10bc4c1df57501cd1e8968dbc4e9b Mon Sep 17 00:00:00 2001 From: Xin Qiu Date: Wed, 21 Feb 2024 11:23:51 +0800 Subject: [PATCH] qwen2 cpu fix (#10187) --- .../bigdl/llm/transformers/models/qwen2.py | 20 +++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/python/llm/src/bigdl/llm/transformers/models/qwen2.py b/python/llm/src/bigdl/llm/transformers/models/qwen2.py index 1b9265c9..1bb7e6b8 100644 --- a/python/llm/src/bigdl/llm/transformers/models/qwen2.py +++ b/python/llm/src/bigdl/llm/transformers/models/qwen2.py @@ -135,6 +135,7 @@ def qwen2_attention_forward_quantized( "Passing `padding_mask` is deprecated and will be removed in v4.37. " "Please make sure use `attention_mask` instead.`" ) + use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids) bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) @@ -158,8 +159,14 @@ def qwen2_attention_forward_quantized( "with a layer index.") kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb_cache_freq_xpu(query_states, key_states, - sin, cos, "qwen2", position_ids) + + if use_fuse_rope: + query_states, key_states = apply_rotary_pos_emb_cache_freq_xpu(query_states, key_states, + sin, cos, "qwen2", + position_ids) + else: + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, + cos, sin, position_ids) if past_key_value is not None: cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models @@ -263,8 +270,13 @@ def qwen2_attention_forward_origin( ) kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb_cache_freq_xpu(query_states, key_states, - sin, cos, "qwen2", position_ids) + if use_fuse_rope: + query_states, key_states = apply_rotary_pos_emb_cache_freq_xpu(query_states, key_states, + sin, cos, "qwen2", + position_ids) + else: + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, + cos, sin, position_ids) if past_key_value is not None: # update the number of seen tokens