fix qwen2 cpu (#11240)
This commit is contained in:
parent
e738ec38f4
commit
2e4ccd541c
2 changed files with 6 additions and 0 deletions
|
|
@ -1279,6 +1279,9 @@ def _optimize_post(model, lightweight_bmm=False):
|
||||||
convert_forward(model,
|
convert_forward(model,
|
||||||
module.Qwen2Attention,
|
module.Qwen2Attention,
|
||||||
qwen2_attention_forward)
|
qwen2_attention_forward)
|
||||||
|
convert_forward(model,
|
||||||
|
module.Qwen2SdpaAttention,
|
||||||
|
qwen2_attention_forward)
|
||||||
elif model.config.model_type == "qwen2_moe":
|
elif model.config.model_type == "qwen2_moe":
|
||||||
# for Qwen1.5-MOE-A2.7B
|
# for Qwen1.5-MOE-A2.7B
|
||||||
modeling_module_name = model.__class__.__module__
|
modeling_module_name = model.__class__.__module__
|
||||||
|
|
|
||||||
|
|
@ -326,6 +326,9 @@ def qwen2_attention_forward(
|
||||||
|
|
||||||
attn_weights = None
|
attn_weights = None
|
||||||
if query_states.device.type == "cpu":
|
if query_states.device.type == "cpu":
|
||||||
|
# repeat k/v heads if n_kv_heads < n_heads
|
||||||
|
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||||
|
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||||
attn_output = sdpa(query_states,
|
attn_output = sdpa(query_states,
|
||||||
key_states,
|
key_states,
|
||||||
value_states,
|
value_states,
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue