From 9a330bfc2be4d0e089df7092bb2474eb4ea77480 Mon Sep 17 00:00:00 2001
From: Yishuo Wang
Date: Thu, 14 Dec 2023 16:16:05 +0800
Subject: [PATCH] fix fuse mlp when using q5_0 or fp8 (#9689)

---
 python/llm/src/bigdl/llm/transformers/models/baichuan2.py | 2 ++
 python/llm/src/bigdl/llm/transformers/models/qwen.py      | 2 ++
 2 files changed, 4 insertions(+)

diff --git a/python/llm/src/bigdl/llm/transformers/models/baichuan2.py b/python/llm/src/bigdl/llm/transformers/models/baichuan2.py
index 39c4a2f0..7dfbca8e 100644
--- a/python/llm/src/bigdl/llm/transformers/models/baichuan2.py
+++ b/python/llm/src/bigdl/llm/transformers/models/baichuan2.py
@@ -26,6 +26,7 @@ from torch import nn
 from torch.nn import functional as F
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from bigdl.llm.utils.common import invalidInputError
+from bigdl.llm.ggml.quantize import ggml_tensor_qtype
 from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 from bigdl.llm.transformers.models.utils import rotate_half, apply_rotary_pos_emb
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
@@ -75,6 +76,7 @@ def baichuan_mlp_forward(
     x: torch.Tensor,
 ) -> torch.Tensor:
     if x.shape[1] == 1 and x.dtype == torch.float32 and x.device.type == 'xpu' \
+            and self.gate_proj.qtype == ggml_tensor_qtype["sym_int4"] \
             and not (self.training and x.requires_grad):
         import linear_q4_0
         x_2d = x.view(-1, x.shape[-1])
diff --git a/python/llm/src/bigdl/llm/transformers/models/qwen.py b/python/llm/src/bigdl/llm/transformers/models/qwen.py
index 18642434..2ef6062c 100644
--- a/python/llm/src/bigdl/llm/transformers/models/qwen.py
+++ b/python/llm/src/bigdl/llm/transformers/models/qwen.py
@@ -39,6 +39,7 @@ except ImportError:
 from bigdl.llm.transformers.models.utils import extend_kv_cache, init_kv_cache, append_kv_cache
 from bigdl.llm.transformers.models.utils import rotate_half
 from bigdl.llm.utils.common import invalidInputError
+from bigdl.llm.ggml.quantize import ggml_tensor_qtype
 
 apply_rotary_emb_func = None
 
@@ -214,6 +215,7 @@ def qwen_attention_forward(
 
 def qwen_mlp_forward(self, x: torch.Tensor) -> torch.Tensor:
     if x.shape[1] == 1 and x.dtype == torch.float32 and x.device.type == 'xpu' \
+            and self.w2.qtype == ggml_tensor_qtype["sym_int4"] \
             and not (self.training and x.requires_grad):
         import linear_q4_0
         x_2d = x.view(-1, x.shape[-1])
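
Note (illustration, not part of the patch): both hunks add the same guard, so the dispatch can be summarized as a small stand-alone sketch. The helper name should_use_fused_mlp and the qtype values below are invented for illustration; the real mapping is ggml_tensor_qtype in bigdl.llm.ggml.quantize.

    import torch

    # Placeholder values for illustration only; in practice use the real
    # mapping from bigdl.llm.ggml.quantize.
    ggml_tensor_qtype = {"sym_int4": 2, "q5_0": 6, "fp8": 15}

    def should_use_fused_mlp(x: torch.Tensor, weight_qtype: int,
                             training: bool) -> bool:
        # Mirrors the condition in baichuan_mlp_forward / qwen_mlp_forward:
        # the fused MLP kernel in linear_q4_0 handles sym_int4 weights only,
        # and only for single-token float32 decode on an XPU device.
        return (x.shape[1] == 1
                and x.dtype == torch.float32
                and x.device.type == "xpu"
                and weight_qtype == ggml_tensor_qtype["sym_int4"]  # the new check
                and not (training and x.requires_grad))

    # A q5_0 weight is now rejected, so the caller falls through to the
    # regular (unfused) MLP forward instead of entering the q4_0-only path.
    x = torch.randn(1, 1, 16)  # CPU tensor here, just to keep the example runnable
    print(should_use_fused_mlp(x, ggml_tensor_qtype["q5_0"], training=False))  # False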