diff --git a/python/llm/src/bigdl/llm/transformers/models/qwen2.py b/python/llm/src/bigdl/llm/transformers/models/qwen2.py
index 1b9265c9..1bb7e6b8 100644
--- a/python/llm/src/bigdl/llm/transformers/models/qwen2.py
+++ b/python/llm/src/bigdl/llm/transformers/models/qwen2.py
@@ -135,6 +135,7 @@ def qwen2_attention_forward_quantized(
             "Passing `padding_mask` is deprecated and will be removed in v4.37. "
             "Please make sure use `attention_mask` instead.`"
         )
+    use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
     bsz, q_len, _ = hidden_states.size()
 
     query_states = self.q_proj(hidden_states)
@@ -158,8 +159,14 @@ def qwen2_attention_forward_quantized(
                               "with a layer index.")
         kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
     cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-    query_states, key_states = apply_rotary_pos_emb_cache_freq_xpu(query_states, key_states,
-                                                                   sin, cos, "qwen2", position_ids)
+
+    if use_fuse_rope:
+        query_states, key_states = apply_rotary_pos_emb_cache_freq_xpu(query_states, key_states,
+                                                                       sin, cos, "qwen2",
+                                                                       position_ids)
+    else:
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
+                                                        cos, sin, position_ids)
 
     if past_key_value is not None:
         cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
@@ -263,8 +270,13 @@ def qwen2_attention_forward_origin(
         )
         kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
     cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-    query_states, key_states = apply_rotary_pos_emb_cache_freq_xpu(query_states, key_states,
-                                                                   sin, cos, "qwen2", position_ids)
+    if use_fuse_rope:
+        query_states, key_states = apply_rotary_pos_emb_cache_freq_xpu(query_states, key_states,
+                                                                       sin, cos, "qwen2",
+                                                                       position_ids)
+    else:
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
+                                                        cos, sin, position_ids)
 
     if past_key_value is not None:
         # update the number of seen tokens
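
For context, below is a minimal, self-contained sketch of the unfused fallback path that the new `else` branch dispatches to when `should_use_fuse_rope` returns False (for example, when the tensors are not on an XPU device). The `rotate_half`/`apply_rotary_pos_emb` pair mirrors the standard HuggingFace reference implementation; whether the helper imported in this file is exactly that function is an assumption, and the fused XPU kernel (`apply_rotary_pos_emb_cache_freq_xpu`) is not reproduced here. The toy rotary table stands in for `self.rotary_emb(value_states, seq_len=kv_seq_len)` purely for illustration.

# Sketch only: unfused RoPE fallback, modeled on the HuggingFace reference
# implementation. Shapes and the toy rotary table below are illustrative,
# not taken from the BigDL source.
import torch


def rotate_half(x):
    # Split the last (head) dimension in half and swap the halves with a sign flip.
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    # cos/sin are [max_seq_len, head_dim]; index by position and broadcast over
    # the head dimension of q/k, which are [bsz, num_heads, seq_len, head_dim].
    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin = sin[position_ids].unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


if __name__ == "__main__":
    bsz, heads, seq, dim = 1, 2, 4, 8
    q = torch.randn(bsz, heads, seq, dim)
    k = torch.randn(bsz, heads, seq, dim)
    # Toy rotary table; in the patched code cos/sin come from self.rotary_emb.
    inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
    freqs = torch.outer(torch.arange(seq).float(), inv_freq)
    emb = torch.cat((freqs, freqs), dim=-1)
    position_ids = torch.arange(seq).unsqueeze(0)
    q_rot, k_rot = apply_rotary_pos_emb(q, k, emb.cos(), emb.sin(), position_ids)
    print(q_rot.shape, k_rot.shape)  # torch.Size([1, 2, 4, 8]) for both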