From 2e4ccd541c54839454012c17029a48044308e786 Mon Sep 17 00:00:00 2001
From: Yishuo Wang
Date: Thu, 6 Jun 2024 16:24:19 +0800
Subject: [PATCH] fix qwen2 cpu (#11240)

---
 python/llm/src/ipex_llm/transformers/convert.py      | 3 +++
 python/llm/src/ipex_llm/transformers/models/qwen2.py | 3 +++
 2 files changed, 6 insertions(+)

diff --git a/python/llm/src/ipex_llm/transformers/convert.py b/python/llm/src/ipex_llm/transformers/convert.py
index 37aa3f6e..abc5195c 100644
--- a/python/llm/src/ipex_llm/transformers/convert.py
+++ b/python/llm/src/ipex_llm/transformers/convert.py
@@ -1279,6 +1279,9 @@ def _optimize_post(model, lightweight_bmm=False):
         convert_forward(model,
                         module.Qwen2Attention,
                         qwen2_attention_forward)
+        convert_forward(model,
+                        module.Qwen2SdpaAttention,
+                        qwen2_attention_forward)
     elif model.config.model_type == "qwen2_moe":
         # for Qwen1.5-MOE-A2.7B
         modeling_module_name = model.__class__.__module__
diff --git a/python/llm/src/ipex_llm/transformers/models/qwen2.py b/python/llm/src/ipex_llm/transformers/models/qwen2.py
index 7a6ed90b..64cbcb70 100644
--- a/python/llm/src/ipex_llm/transformers/models/qwen2.py
+++ b/python/llm/src/ipex_llm/transformers/models/qwen2.py
@@ -326,6 +326,9 @@ def qwen2_attention_forward(
     attn_weights = None
 
     if query_states.device.type == "cpu":
+        # repeat k/v heads if n_kv_heads < n_heads
+        key_states = repeat_kv(key_states, self.num_key_value_groups)
+        value_states = repeat_kv(value_states, self.num_key_value_groups)
         attn_output = sdpa(query_states,
                            key_states,
                            value_states,