qwen2 cpu fix (#10187)
This commit is contained in:
parent
39d37bd042
commit
56ad781f2f
1 changed file with 16 additions and 4 deletions
|
|
@@ -135,6 +135,7 @@ def qwen2_attention_forward_quantized(
|
||||||
"Passing `padding_mask` is deprecated and will be removed in v4.37. "
|
"Passing `padding_mask` is deprecated and will be removed in v4.37. "
|
||||||
"Please make sure use `attention_mask` instead.`"
|
"Please make sure use `attention_mask` instead.`"
|
||||||
)
|
)
|
||||||
|
use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
|
||||||
bsz, q_len, _ = hidden_states.size()
|
bsz, q_len, _ = hidden_states.size()
|
||||||
|
|
||||||
query_states = self.q_proj(hidden_states)
|
query_states = self.q_proj(hidden_states)
|
||||||
|
|
@@ -158,8 +159,14 @@ def qwen2_attention_forward_quantized(
|
||||||
"with a layer index.")
|
"with a layer index.")
|
||||||
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
||||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||||
query_states, key_states = apply_rotary_pos_emb_cache_freq_xpu(query_states, key_states,
|
|
||||||
sin, cos, "qwen2", position_ids)
|
if use_fuse_rope:
|
||||||
|
query_states, key_states = apply_rotary_pos_emb_cache_freq_xpu(query_states, key_states,
|
||||||
|
sin, cos, "qwen2",
|
||||||
|
position_ids)
|
||||||
|
else:
|
||||||
|
query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
|
||||||
|
cos, sin, position_ids)
|
||||||
|
|
||||||
if past_key_value is not None:
|
if past_key_value is not None:
|
||||||
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
|
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
|
||||||
|
|
@@ -263,8 +270,13 @@ def qwen2_attention_forward_origin(
|
||||||
)
|
)
|
||||||
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
||||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||||
query_states, key_states = apply_rotary_pos_emb_cache_freq_xpu(query_states, key_states,
|
if use_fuse_rope:
|
||||||
sin, cos, "qwen2", position_ids)
|
query_states, key_states = apply_rotary_pos_emb_cache_freq_xpu(query_states, key_states,
|
||||||
|
sin, cos, "qwen2",
|
||||||
|
position_ids)
|
||||||
|
else:
|
||||||
|
query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
|
||||||
|
cos, sin, position_ids)
|
||||||
|
|
||||||
if past_key_value is not None:
|
if past_key_value is not None:
|
||||||
# update the number of seen tokens
|
# update the number of seen tokens
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue