diff --git a/python/llm/src/ipex_llm/transformers/models/llama.py b/python/llm/src/ipex_llm/transformers/models/llama.py
index 589d7149..6c170bbf 100644
--- a/python/llm/src/ipex_llm/transformers/models/llama.py
+++ b/python/llm/src/ipex_llm/transformers/models/llama.py
@@ -547,7 +547,7 @@ def llama_attention_forward_4_31_original(
         value_states = torch.cat(value_states, dim=-1)
     else:
         if fp16_fusion_check(self.q_proj, hidden_states, self.training) and \
-                hidden_size == 4096:
+                hidden_size == 4096 and self.q_proj.out_features == self.k_proj.out_features:
             # only use mm_qkv_out on pvc for llama-7b
             if not hasattr(self, "qkv_proj_weight"):
                 self.qkv_proj_weight = torch.stack([self.q_proj.weight,
@@ -1200,7 +1200,7 @@ def llama_attention_forward_4_36_original(
         value_states = torch.cat(value_states, dim=-1)
     else:
         if fp16_fusion_check(self.q_proj, hidden_states, self.training) and \
-                hidden_size == 4096:
+                hidden_size == 4096 and self.q_proj.out_features == self.k_proj.out_features:
             # only use mm_qkv_out on pvc for llama-7b
             if not hasattr(self, "qkv_proj_weight"):
                 self.qkv_proj_weight = torch.stack([self.q_proj.weight,
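
A minimal sketch (not part of the patch, using a hypothetical grouped-query config) of why the added `out_features` comparison matters: the fused path stacks the q/k/v projection weights with `torch.stack`, which requires identical shapes. That holds for llama-7b-style multi-head attention, but not for models whose k/v projections are narrower than the q projection, so the fused QKV path must be skipped there.

import torch
import torch.nn as nn

hidden_size = 4096
num_kv_heads, head_dim = 8, 128  # hypothetical grouped-query configuration

q_proj = nn.Linear(hidden_size, hidden_size, bias=False)
k_proj = nn.Linear(hidden_size, num_kv_heads * head_dim, bias=False)
v_proj = nn.Linear(hidden_size, num_kv_heads * head_dim, bias=False)

if q_proj.out_features == k_proj.out_features:
    # Shapes match: safe to build a single fused QKV weight tensor.
    qkv_proj_weight = torch.stack([q_proj.weight, k_proj.weight, v_proj.weight])
else:
    # Shapes differ: torch.stack would raise, so keep separate q/k/v projections.
    print("q/k out_features differ; skipping fused QKV path")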