qwen2 cpu fix (#10187)

This commit is contained in:
Xin Qiu 2024-02-21 11:23:51 +08:00 committed by GitHub
parent 39d37bd042
commit 56ad781f2f

View file

@ -135,6 +135,7 @@ def qwen2_attention_forward_quantized(
"Passing `padding_mask` is deprecated and will be removed in v4.37. "
"Please make sure use `attention_mask` instead.`"
)
use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
@ -158,8 +159,14 @@ def qwen2_attention_forward_quantized(
"with a layer index.")
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb_cache_freq_xpu(query_states, key_states,
sin, cos, "qwen2", position_ids)
if use_fuse_rope:
query_states, key_states = apply_rotary_pos_emb_cache_freq_xpu(query_states, key_states,
sin, cos, "qwen2",
position_ids)
else:
query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
cos, sin, position_ids)
if past_key_value is not None:
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
@ -263,8 +270,13 @@ def qwen2_attention_forward_origin(
)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb_cache_freq_xpu(query_states, key_states,
sin, cos, "qwen2", position_ids)
if use_fuse_rope:
query_states, key_states = apply_rotary_pos_emb_cache_freq_xpu(query_states, key_states,
sin, cos, "qwen2",
position_ids)
else:
query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
cos, sin, position_ids)
if past_key_value is not None:
# update the number of seen tokens