add bf16/fp16 fuse mlp support (#9726)

Yishuo Wang 2023-12-20 10:40:45 +08:00 committed by GitHub
parent 612651cb5d
commit e54c428d30
4 changed files with 4 additions and 4 deletions
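
All four MLP forward paths touched here (baichuan, llama, mixtral, qwen) drop the x.dtype == torch.float32 test from the fused-MLP dispatch condition, so bf16/fp16 activations on XPU can now take the fused linear_q4_0 kernel as well. A minimal sketch of the relaxed predicate follows; the helper name and argument list are illustrative, not part of this patch:

import torch

def can_use_fused_mlp(x: torch.Tensor, weight_qtype: int, sym_int4: int, training: bool) -> bool:
    # Hypothetical helper mirroring the new condition: the old
    # x.dtype == torch.float32 check is gone, so fp32, fp16 and bf16
    # inputs all qualify as long as the other constraints hold.
    x_2d = x.view(-1, x.shape[-1])
    return (x_2d.shape[0] == 1                        # single-token (decode) step
            and x.device.type == 'xpu'                # Intel GPU only
            and weight_qtype == sym_int4              # 4-bit symmetric quantized weights
            and not (training and x.requires_grad))   # inference path only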

@@ -70,7 +70,7 @@ def baichuan_mlp_forward(
     x: torch.Tensor,
 ) -> torch.Tensor:
     x_2d = x.view(-1, x.shape[-1])
-    if x_2d.shape[0] == 1 and x.dtype == torch.float32 and x.device.type == 'xpu' \
+    if x_2d.shape[0] == 1 and x.device.type == 'xpu' \
             and self.gate_proj.qtype == ggml_tensor_qtype["sym_int4"] \
             and not (self.training and x.requires_grad):
         import linear_q4_0

@@ -98,7 +98,7 @@ def llama_mlp_forward(
     x: torch.Tensor,
 ) -> torch.Tensor:
     x_2d = x.view(-1, x.shape[-1])
-    if x_2d.shape[0] == 1 and x.dtype == torch.float32 and x.device.type == 'xpu' \
+    if x_2d.shape[0] == 1 and x.device.type == 'xpu' \
             and self.gate_proj.qtype == ggml_tensor_qtype["sym_int4"] \
             and not (self.training and x.requires_grad):
         import linear_q4_0

@@ -258,7 +258,7 @@ def mixtral_mlp_forward(
     x: torch.Tensor,
     routing_weights
 ) -> torch.Tensor:
-    if x.shape[0] == 1 and x.dtype == torch.float32 and x.device.type == 'xpu' \
+    if x.shape[0] == 1 and x.device.type == 'xpu' \
             and self.w1.qtype == ggml_tensor_qtype["sym_int4"] \
             and not (self.training and x.requires_grad):
         import linear_q4_0

@@ -241,7 +241,7 @@ def qwen_attention_forward(

 def qwen_mlp_forward(self, x: torch.Tensor) -> torch.Tensor:
     x_2d = x.view(-1, x.shape[-1])
-    if x_2d.shape[0] == 1 and x.dtype == torch.float32 and x.device.type == 'xpu' \
+    if x_2d.shape[0] == 1 and x.device.type == 'xpu' \
             and self.w2.qtype == ggml_tensor_qtype["sym_int4"] \
             and not (self.training and x.requires_grad):
         import linear_q4_0
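
For context, a sketch of how the fused path would now be reached with half-precision activations. The loading calls follow bigdl-llm's transformers-style API from this period, but the model id and exact arguments are illustrative assumptions, not part of this commit:

import torch
from transformers import AutoTokenizer
from bigdl.llm.transformers import AutoModelForCausalLM  # assumed import path

# Load a model with sym_int4 (4-bit) weights, cast activations to fp16 and
# move to XPU; with this patch, single-token decode steps dispatch to the
# fused MLP kernel even though the activations are no longer fp32.
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf",
                                             load_in_4bit=True,
                                             trust_remote_code=True)
model = model.half().to('xpu')
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

input_ids = tokenizer("What is AI?", return_tensors="pt").input_ids.to('xpu')
with torch.inference_mode():
    output = model.generate(input_ids, max_new_tokens=32)
print(tokenizer.decode(output[0], skip_special_tokens=True))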