qwen2 cpu fix (#10187)
parent 39d37bd042
commit 56ad781f2f
1 changed file with 16 additions and 4 deletions
@@ -135,6 +135,7 @@ def qwen2_attention_forward_quantized(
             "Passing `padding_mask` is deprecated and will be removed in v4.37. "
             "Please make sure use `attention_mask` instead.`"
         )
+    use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
     bsz, q_len, _ = hidden_states.size()

     query_states = self.q_proj(hidden_states)
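The one-line addition above wires the quantized forward path through a `use_fuse_rope` flag that the hunks below branch on. The implementation of `should_use_fuse_rope` is not part of this diff; as a loose sketch (an assumption about what such a gate typically checks, not the actual ipex-llm helper), it boils down to a device test along these lines:

```python
import torch


def should_use_fuse_rope_sketch(hidden_states: torch.Tensor) -> bool:
    # Hypothetical gate, not the real ipex-llm helper: only report the fused
    # XPU rotary-embedding kernel as usable when the activations actually live
    # on an XPU device and no autograd graph is being recorded.
    return hidden_states.device.type == "xpu" and not hidden_states.requires_grad
```

On CPU such a gate evaluates to False, which is what routes execution into the `apply_rotary_pos_emb` fallback introduced in the hunks below.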
@@ -158,8 +159,14 @@ def qwen2_attention_forward_quantized(
                               "with a layer index.")
         kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
     cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-    query_states, key_states = apply_rotary_pos_emb_cache_freq_xpu(query_states, key_states,
-                                                                   sin, cos, "qwen2", position_ids)
+
+    if use_fuse_rope:
+        query_states, key_states = apply_rotary_pos_emb_cache_freq_xpu(query_states, key_states,
+                                                                       sin, cos, "qwen2",
+                                                                       position_ids)
+    else:
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
+                                                        cos, sin, position_ids)

     if past_key_value is not None:
         cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
@@ -263,8 +270,13 @@ def qwen2_attention_forward_origin(
             )
         kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
     cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-    query_states, key_states = apply_rotary_pos_emb_cache_freq_xpu(query_states, key_states,
-                                                                   sin, cos, "qwen2", position_ids)
+    if use_fuse_rope:
+        query_states, key_states = apply_rotary_pos_emb_cache_freq_xpu(query_states, key_states,
+                                                                       sin, cos, "qwen2",
+                                                                       position_ids)
+    else:
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
+                                                        cos, sin, position_ids)

     if past_key_value is not None:
         # update the number of seen tokens
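For readers who want to try the pattern outside of ipex-llm, here is a small self-contained sketch of the same dispatch idea: a fused-rope path that is only taken when a fused kernel is actually usable, with the plain PyTorch rotary embedding as the CPU fallback. The helper names and the toy cos/sin cache are illustrative assumptions, not ipex-llm code; only the overall `if use_fuse_rope` / `else` shape mirrors the diff.

```python
import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Standard RoPE helper: split the last dimension in half and rotate.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    # Plain (non-fused) rotary embedding in the transformers-4.36 style,
    # where cos/sin are [seq_len, head_dim] tables indexed by position_ids.
    cos = cos[position_ids].unsqueeze(1)  # -> [bsz, 1, q_len, head_dim]
    sin = sin[position_ids].unsqueeze(1)
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin


def apply_rope(q, k, cos, sin, position_ids, use_fuse_rope: bool):
    # Hypothetical dispatcher mirroring the shape of the fix: take the fused
    # XPU path only when it is usable, otherwise fall back to plain PyTorch
    # so that CPU inference keeps working.
    if use_fuse_rope:
        # Placeholder for apply_rotary_pos_emb_cache_freq_xpu(q, k, sin, cos,
        # "qwen2", position_ids); the fused kernel is not available in this sketch.
        raise NotImplementedError("fused XPU rope not available in this sketch")
    return apply_rotary_pos_emb(q, k, cos, sin, position_ids)


if __name__ == "__main__":
    bsz, n_heads, q_len, head_dim = 1, 2, 4, 8
    q = torch.randn(bsz, n_heads, q_len, head_dim)
    k = torch.randn(bsz, n_heads, q_len, head_dim)

    # Toy cos/sin cache of shape [seq_len, head_dim].
    pos = torch.arange(q_len)
    inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2).float() / head_dim))
    freqs = torch.outer(pos.float(), inv_freq)
    emb = torch.cat((freqs, freqs), dim=-1)

    position_ids = pos.unsqueeze(0)  # [bsz, q_len]
    q_rot, k_rot = apply_rope(q, k, emb.cos(), emb.sin(), position_ids,
                              use_fuse_rope=False)
    print(q_rot.shape, k_rot.shape)  # torch.Size([1, 2, 4, 8]) twice
```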