From a20271ffe41a1a81b3ab7161e70101bac7c0db6b Mon Sep 17 00:00:00 2001
From: "Wang, Jian4" <61138589+hzjane@users.noreply.github.com>
Date: Wed, 17 Apr 2024 16:49:59 +0800
Subject: [PATCH] LLM: Fix yi-6b fp16 error on pvc (#10781)

* update for yi fp16

* update

* update
---
 python/llm/src/ipex_llm/transformers/models/llama.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/models/llama.py b/python/llm/src/ipex_llm/transformers/models/llama.py
index 589d7149..6c170bbf 100644
--- a/python/llm/src/ipex_llm/transformers/models/llama.py
+++ b/python/llm/src/ipex_llm/transformers/models/llama.py
@@ -547,7 +547,7 @@ def llama_attention_forward_4_31_original(
             value_states = torch.cat(value_states, dim=-1)
     else:
         if fp16_fusion_check(self.q_proj, hidden_states, self.training) and \
-                hidden_size == 4096:
+                hidden_size == 4096 and self.q_proj.out_features == self.k_proj.out_features:
             # only use mm_qkv_out on pvc for llama-7b
             if not hasattr(self, "qkv_proj_weight"):
                 self.qkv_proj_weight = torch.stack([self.q_proj.weight,
@@ -1200,7 +1200,7 @@ def llama_attention_forward_4_36_original(
             value_states = torch.cat(value_states, dim=-1)
     else:
         if fp16_fusion_check(self.q_proj, hidden_states, self.training) and \
-                hidden_size == 4096:
+                hidden_size == 4096 and self.q_proj.out_features == self.k_proj.out_features:
             # only use mm_qkv_out on pvc for llama-7b
             if not hasattr(self, "qkv_proj_weight"):
                 self.qkv_proj_weight = torch.stack([self.q_proj.weight,
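
Note on the fix (a minimal sketch, not part of the patch): the fused mm_qkv_out path stacks the q/k/v projection weights into a single tensor with torch.stack, which only works when those weights share the same shape. Yi-6B apparently matched the old `hidden_size == 4096` guard like llama-7b, but it uses grouped-query attention, so its k_proj/v_proj are narrower than q_proj and the stack fails on the fp16 PVC path. The added `out_features` equality check keeps such models on the unfused path. The snippet below mirrors the patched condition with standalone nn.Linear layers; the 512-wide k_proj for the GQA case is an assumed Yi-6B-like configuration (4 KV heads of head_dim 128), not taken from the patch.

# Minimal sketch of the patched guard (assumptions: Yi-6B-like GQA with
# k_proj/v_proj out_features == 512; this is not the actual ipex_llm code).
import torch
import torch.nn as nn


def can_use_fused_qkv(q_proj: nn.Linear, k_proj: nn.Linear, hidden_size: int) -> bool:
    # Mirrors the patched condition: only take the fused path when the hidden
    # size is 4096 and q/k share out_features, so their weights can be stacked.
    return hidden_size == 4096 and q_proj.out_features == k_proj.out_features


# Llama-7B-style multi-head attention: q/k/v all project 4096 -> 4096.
q_mha = nn.Linear(4096, 4096, bias=False)
k_mha = nn.Linear(4096, 4096, bias=False)
print(can_use_fused_qkv(q_mha, k_mha, 4096))  # True: equal shapes, stacking works

# Yi-6B-style grouped-query attention (assumed): k/v project 4096 -> 512.
q_gqa = nn.Linear(4096, 4096, bias=False)
k_gqa = nn.Linear(4096, 512, bias=False)
print(can_use_fused_qkv(q_gqa, k_gqa, 4096))  # False: fall back to unfused q/k/v

# Without the extra check, torch.stack([q_gqa.weight, k_gqa.weight]) would
# raise a shape-mismatch RuntimeError, which is the failure mode the patch avoids.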