From 26239446049a146438069602eada2ffad1eb5602 Mon Sep 17 00:00:00 2001
From: Yishuo Wang
Date: Fri, 7 Jun 2024 14:42:18 +0800
Subject: [PATCH] qwen2 sdpa small fix (#11261)

---
 python/llm/src/ipex_llm/transformers/models/qwen2.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/python/llm/src/ipex_llm/transformers/models/qwen2.py b/python/llm/src/ipex_llm/transformers/models/qwen2.py
index 64cbcb70..2e236acb 100644
--- a/python/llm/src/ipex_llm/transformers/models/qwen2.py
+++ b/python/llm/src/ipex_llm/transformers/models/qwen2.py
@@ -337,6 +337,9 @@ def qwen2_attention_forward(
                            is_causal=self.is_causal and attention_mask is None and q_len > 1)
     elif not self.training and not hidden_states.requires_grad and \
             use_flash_attention(query_states, key_states, attention_mask):
+        # repeat k/v heads if n_kv_heads < n_heads
+        key_states = repeat_kv(key_states, self.num_key_value_groups)
+        value_states = repeat_kv(value_states, self.num_key_value_groups)
         attn_output = sdpa(query_states.to(device, dtype=torch.float16),
                            key_states.to(device, dtype=torch.float16),
                            value_states.to(device, dtype=torch.float16),
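
Note (not part of the patch): the fix expands the key/value heads so that a grouped-query-attention (GQA) model like Qwen2, which has fewer k/v heads than query heads, hands the SDPA kernel tensors with matching head counts. Below is a minimal sketch of what a repeat_kv helper of this kind does, assuming the (batch, num_kv_heads, seq_len, head_dim) layout used in Hugging Face Transformers-style attention code; the exact helper imported by qwen2.py may differ in detail.

import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # (batch, num_kv_heads, seq_len, head_dim) -> (batch, num_kv_heads * n_rep, seq_len, head_dim)
    batch, num_kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # insert a repeat axis, broadcast along it, then fold it into the head axis
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_kv_heads, n_rep, seq_len, head_dim)
    return hidden_states.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)

# usage sketch with hypothetical Qwen2-7B-like sizes: 28 query heads, 4 kv heads -> n_rep = 7
key_states = torch.randn(1, 4, 128, 128)
key_states = repeat_kv(key_states, 7)   # shape becomes (1, 28, 128, 128)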