diff --git a/python/llm/src/ipex_llm/transformers/models/chatglm2.py b/python/llm/src/ipex_llm/transformers/models/chatglm2.py
index c6dae7a9..3308e93f 100644
--- a/python/llm/src/ipex_llm/transformers/models/chatglm2.py
+++ b/python/llm/src/ipex_llm/transformers/models/chatglm2.py
@@ -184,7 +184,7 @@ def chatglm2_model_forward(
 def chatglm2_attention_forward(
     self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True
 ):
-    if use_quantize_kv_cache(self.query_key_value, hidden_states):
+    if use_quantize_kv_cache(self.query_key_value, hidden_states.transpose(0, 1)):
         forward_function = chatglm2_quantized_attention_forward_8eb45c
     else:
         forward_function = chatglm2_attention_forward_8eb45c