fix qwen2 cpu (#11240)

This commit is contained in:
Yishuo Wang 2024-06-06 16:24:19 +08:00 committed by GitHub
parent e738ec38f4
commit 2e4ccd541c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 6 additions and 0 deletions

View file

@ -1279,6 +1279,9 @@ def _optimize_post(model, lightweight_bmm=False):
convert_forward(model,
module.Qwen2Attention,
qwen2_attention_forward)
convert_forward(model,
module.Qwen2SdpaAttention,
qwen2_attention_forward)
elif model.config.model_type == "qwen2_moe":
# for Qwen1.5-MOE-A2.7B
modeling_module_name = model.__class__.__module__

View file

@ -326,6 +326,9 @@ def qwen2_attention_forward(
attn_weights = None
if query_states.device.type == "cpu":
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
attn_output = sdpa(query_states,
key_states,
value_states,