diff --git a/python/llm/src/bigdl/llm/transformers/models/baichuan2.py b/python/llm/src/bigdl/llm/transformers/models/baichuan2.py
index e6a2ab49..da45b733 100644
--- a/python/llm/src/bigdl/llm/transformers/models/baichuan2.py
+++ b/python/llm/src/bigdl/llm/transformers/models/baichuan2.py
@@ -70,7 +70,7 @@ def baichuan_mlp_forward(
     x: torch.Tensor,
 ) -> torch.Tensor:
     x_2d = x.view(-1, x.shape[-1])
-    if x_2d.shape[0] == 1 and x.dtype == torch.float32 and x.device.type == 'xpu' \
+    if x_2d.shape[0] == 1 and x.device.type == 'xpu' \
             and self.gate_proj.qtype == ggml_tensor_qtype["sym_int4"] \
             and not (self.training and x.requires_grad):
         import linear_q4_0
diff --git a/python/llm/src/bigdl/llm/transformers/models/llama.py b/python/llm/src/bigdl/llm/transformers/models/llama.py
index 8572b92c..7594c74b 100644
--- a/python/llm/src/bigdl/llm/transformers/models/llama.py
+++ b/python/llm/src/bigdl/llm/transformers/models/llama.py
@@ -98,7 +98,7 @@ def llama_mlp_forward(
     x: torch.Tensor,
 ) -> torch.Tensor:
     x_2d = x.view(-1, x.shape[-1])
-    if x_2d.shape[0] == 1 and x.dtype == torch.float32 and x.device.type == 'xpu' \
+    if x_2d.shape[0] == 1 and x.device.type == 'xpu' \
             and self.gate_proj.qtype == ggml_tensor_qtype["sym_int4"] \
             and not (self.training and x.requires_grad):
         import linear_q4_0
diff --git a/python/llm/src/bigdl/llm/transformers/models/mixtral.py b/python/llm/src/bigdl/llm/transformers/models/mixtral.py
index fd05c963..bc312aad 100644
--- a/python/llm/src/bigdl/llm/transformers/models/mixtral.py
+++ b/python/llm/src/bigdl/llm/transformers/models/mixtral.py
@@ -258,7 +258,7 @@ def mixtral_mlp_forward(
     x: torch.Tensor,
     routing_weights
 ) -> torch.Tensor:
-    if x.shape[0] == 1 and x.dtype == torch.float32 and x.device.type == 'xpu' \
+    if x.shape[0] == 1 and x.device.type == 'xpu' \
             and self.w1.qtype == ggml_tensor_qtype["sym_int4"] \
             and not (self.training and x.requires_grad):
         import linear_q4_0
diff --git a/python/llm/src/bigdl/llm/transformers/models/qwen.py b/python/llm/src/bigdl/llm/transformers/models/qwen.py
index 71f29236..5e4f5dcf 100644
--- a/python/llm/src/bigdl/llm/transformers/models/qwen.py
+++ b/python/llm/src/bigdl/llm/transformers/models/qwen.py
@@ -241,7 +241,7 @@ def qwen_attention_forward(
 
 def qwen_mlp_forward(self, x: torch.Tensor) -> torch.Tensor:
     x_2d = x.view(-1, x.shape[-1])
-    if x_2d.shape[0] == 1 and x.dtype == torch.float32 and x.device.type == 'xpu' \
+    if x_2d.shape[0] == 1 and x.device.type == 'xpu' \
             and self.w2.qtype == ggml_tensor_qtype["sym_int4"] \
             and not (self.training and x.requires_grad):
         import linear_q4_0