fix qwen2 cpu (#11240)
This commit is contained in:
parent
e738ec38f4
commit
2e4ccd541c
2 changed files with 6 additions and 0 deletions
|
|
@@ -1279,6 +1279,9 @@ def _optimize_post(model, lightweight_bmm=False):
|
|||
convert_forward(model,
|
||||
module.Qwen2Attention,
|
||||
qwen2_attention_forward)
|
||||
convert_forward(model,
|
||||
module.Qwen2SdpaAttention,
|
||||
qwen2_attention_forward)
|
||||
elif model.config.model_type == "qwen2_moe":
|
||||
# for Qwen1.5-MOE-A2.7B
|
||||
modeling_module_name = model.__class__.__module__
|
||||
|
|
|
|||
|
|
@@ -326,6 +326,9 @@ def qwen2_attention_forward(
|
|||
|
||||
attn_weights = None
|
||||
if query_states.device.type == "cpu":
|
||||
# repeat k/v heads if n_kv_heads < n_heads
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
attn_output = sdpa(query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
|
|
|
|||
Loading…
Reference in a new issue