diff --git a/python/llm/src/bigdl/llm/transformers/models/qwen2.py b/python/llm/src/bigdl/llm/transformers/models/qwen2.py
index de9ccb61..e71a1df6 100644
--- a/python/llm/src/bigdl/llm/transformers/models/qwen2.py
+++ b/python/llm/src/bigdl/llm/transformers/models/qwen2.py
@@ -167,6 +167,8 @@ def qwen2_attention_forward_quantized(

     if q_len != 1:
         key, value = restore_fp8_kv_cache(key_states, value_states, query_states.dtype)
+        key = repeat_kv(key, self.num_key_value_groups)
+        value = repeat_kv(value, self.num_key_value_groups)
         attn_weights = torch.matmul(query_states, key.transpose(2, 3))
     else:
         import linear_q4_0
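
Note: the fix assumes `repeat_kv` behaves like the standard Hugging Face helper, i.e. it tiles the grouped key/value heads `n_rep` times so the restored fp8 cache has as many heads as `query_states` before the `torch.matmul`. A minimal sketch under that assumption:

    import torch

    def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
        # Expand (batch, num_kv_heads, seq_len, head_dim)
        # to (batch, num_kv_heads * n_rep, seq_len, head_dim)
        batch, num_kv_heads, seq_len, head_dim = hidden_states.shape
        if n_rep == 1:
            return hidden_states
        hidden_states = hidden_states[:, :, None, :, :].expand(
            batch, num_kv_heads, n_rep, seq_len, head_dim
        )
        return hidden_states.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)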