diff --git a/python/llm/src/bigdl/llm/transformers/models/baichuan2.py b/python/llm/src/bigdl/llm/transformers/models/baichuan2.py
index f5b7dcfb..bbb48eed 100644
--- a/python/llm/src/bigdl/llm/transformers/models/baichuan2.py
+++ b/python/llm/src/bigdl/llm/transformers/models/baichuan2.py
@@ -69,11 +69,11 @@ def baichuan_mlp_forward(
     self,
     x: torch.Tensor,
 ) -> torch.Tensor:
-    if x.shape[1] == 1 and x.dtype == torch.float32 and x.device.type == 'xpu' \
+    x_2d = x.view(-1, x.shape[-1])
+    if x_2d.shape[0] == 1 and x.dtype == torch.float32 and x.device.type == 'xpu' \
             and self.gate_proj.qtype == ggml_tensor_qtype["sym_int4"] \
             and not (self.training and x.requires_grad):
         import linear_q4_0
-        x_2d = x.view(-1, x.shape[-1])
         if not x_2d.is_contiguous():
             x_2d = x_2d.contiguous()
         return self.down_proj(linear_q4_0.mlp_forward_q4_0_xpu(
diff --git a/python/llm/src/bigdl/llm/transformers/models/llama.py b/python/llm/src/bigdl/llm/transformers/models/llama.py
index 5cc6fd54..8572b92c 100644
--- a/python/llm/src/bigdl/llm/transformers/models/llama.py
+++ b/python/llm/src/bigdl/llm/transformers/models/llama.py
@@ -42,6 +42,7 @@ from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache,
 from bigdl.llm.transformers.models.utils import rotate_half, apply_rotary_pos_emb
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
 from bigdl.llm.transformers.low_bit_linear import SYM_INT4
+from bigdl.llm.ggml.quantize import ggml_tensor_qtype
 
 
 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
@@ -96,10 +97,11 @@ def llama_mlp_forward(
     self,
     x: torch.Tensor,
 ) -> torch.Tensor:
-    if x.shape[1] == 1 and x.dtype == torch.float32 and x.device.type == 'xpu' \
+    x_2d = x.view(-1, x.shape[-1])
+    if x_2d.shape[0] == 1 and x.dtype == torch.float32 and x.device.type == 'xpu' \
+            and self.gate_proj.qtype == ggml_tensor_qtype["sym_int4"] \
             and not (self.training and x.requires_grad):
         import linear_q4_0
-        x_2d = x.view(-1, x.shape[-1])
         if not x_2d.is_contiguous():
             x_2d = x_2d.contiguous()
         return self.down_proj(linear_q4_0.mlp_forward_q4_0_xpu(
diff --git a/python/llm/src/bigdl/llm/transformers/models/qwen.py b/python/llm/src/bigdl/llm/transformers/models/qwen.py
index 43dca325..71f29236 100644
--- a/python/llm/src/bigdl/llm/transformers/models/qwen.py
+++ b/python/llm/src/bigdl/llm/transformers/models/qwen.py
@@ -240,11 +240,11 @@ def qwen_attention_forward(
 
 
 def qwen_mlp_forward(self, x: torch.Tensor) -> torch.Tensor:
-    if x.shape[1] == 1 and x.dtype == torch.float32 and x.device.type == 'xpu' \
+    x_2d = x.view(-1, x.shape[-1])
+    if x_2d.shape[0] == 1 and x.dtype == torch.float32 and x.device.type == 'xpu' \
             and self.w2.qtype == ggml_tensor_qtype["sym_int4"] \
             and not (self.training and x.requires_grad):
         import linear_q4_0
-        x_2d = x.view(-1, x.shape[-1])
         if not x_2d.is_contiguous():
             x_2d = x_2d.contiguous()
         return self.c_proj(linear_q4_0.mlp_forward_q4_0_xpu(
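
Note (not part of the patch): a minimal standalone sketch of the behavioral difference between the old and new condition, assuming x is shaped [batch, seq_len, hidden]. The old check only inspects seq_len, so a batched single-token decode such as [4, 1, hidden] would still pass, while the new check flattens to 2D first and requires batch * seq_len == 1. The helper names below are hypothetical and exist only for illustration.

import torch

def old_check(x: torch.Tensor) -> bool:
    # Original condition: only the second dimension (seq_len) is inspected.
    return x.shape[1] == 1

def new_check(x: torch.Tensor) -> bool:
    # Updated condition: flatten to 2D and require exactly one row,
    # i.e. batch * seq_len == 1, before taking the fused XPU path.
    x_2d = x.view(-1, x.shape[-1])
    return x_2d.shape[0] == 1

single_token = torch.randn(1, 1, 4096)    # one token in total
batched_decode = torch.randn(4, 1, 4096)  # four sequences, one token each

print(old_check(single_token), new_check(single_token))      # True True
print(old_check(batched_decode), new_check(batched_decode))  # True False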