add bf16/fp16 fuse mlp support (#9726)
parent 612651cb5d · commit e54c428d30
4 changed files with 4 additions and 4 deletions
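All four hunks make the same one-line change: the fused-MLP fast path previously required float32 activations, and dropping the dtype check lets fp16 and bf16 activations take the fused kernel as well, provided the weights are sym_int4-quantized, the tensor lives on an XPU device, and the model is in inference mode. A minimal sketch of the relaxed guard (the helper name fuse_mlp_ok is hypothetical, not part of the codebase):

import torch

def fuse_mlp_ok(x: torch.Tensor, weight_qtype: str, training: bool) -> bool:
    # Mirrors the guard shared by the four forwards below; "sym_int4" stands
    # in for ggml_tensor_qtype["sym_int4"]. Before this commit the guard also
    # required x.dtype == torch.float32.
    x_2d = x.view(-1, x.shape[-1])
    return (x_2d.shape[0] == 1                        # single-token decode step
            and x.device.type == 'xpu'                # Intel GPU tensors only
            and weight_qtype == "sym_int4"            # 4-bit quantized weights
            and not (training and x.requires_grad))   # inference only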
@@ -70,7 +70,7 @@ def baichuan_mlp_forward(
     x: torch.Tensor,
 ) -> torch.Tensor:
     x_2d = x.view(-1, x.shape[-1])
-    if x_2d.shape[0] == 1 and x.dtype == torch.float32 and x.device.type == 'xpu' \
+    if x_2d.shape[0] == 1 and x.device.type == 'xpu' \
             and self.gate_proj.qtype == ggml_tensor_qtype["sym_int4"] \
             and not (self.training and x.requires_grad):
         import linear_q4_0
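For example, using the fuse_mlp_ok sketch above, a single-token fp16 or bf16 input now passes the same check that previously admitted only fp32 (assuming intel_extension_for_pytorch is installed so 'xpu' is a valid device):

import intel_extension_for_pytorch as ipex  # noqa: F401, registers 'xpu'
for dtype in (torch.float32, torch.float16, torch.bfloat16):
    x = torch.zeros(1, 1, 4096, dtype=dtype, device='xpu')  # one decode token
    assert fuse_mlp_ok(x, "sym_int4", training=False)       # all dtypes pass now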
@@ -98,7 +98,7 @@ def llama_mlp_forward(
     x: torch.Tensor,
 ) -> torch.Tensor:
     x_2d = x.view(-1, x.shape[-1])
-    if x_2d.shape[0] == 1 and x.dtype == torch.float32 and x.device.type == 'xpu' \
+    if x_2d.shape[0] == 1 and x.device.type == 'xpu' \
             and self.gate_proj.qtype == ggml_tensor_qtype["sym_int4"] \
             and not (self.training and x.requires_grad):
         import linear_q4_0
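When the guard is false, the Baichuan and LLaMA forwards above fall back to the ordinary gated (SwiGLU) MLP; schematically, the fused linear_q4_0 path computes the equivalent of this reference in one XPU call:

def mlp_reference(self, x):
    # Standard SwiGLU MLP: activation of the gate projection, elementwise
    # product with the up projection, then the down projection.
    return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))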
@@ -258,7 +258,7 @@ def mixtral_mlp_forward(
     x: torch.Tensor,
     routing_weights
 ) -> torch.Tensor:
-    if x.shape[0] == 1 and x.dtype == torch.float32 and x.device.type == 'xpu' \
+    if x.shape[0] == 1 and x.device.type == 'xpu' \
             and self.w1.qtype == ggml_tensor_qtype["sym_int4"] \
             and not (self.training and x.requires_grad):
         import linear_q4_0
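Mixtral's expert MLP names its projections w1 (gate), w3 (up), and w2 (down) and scales the output by the router weight, which is why this hunk checks self.w1.qtype and the forward takes routing_weights; schematically:

def expert_mlp_reference(self, x, routing_weights):
    # Per-expert SwiGLU MLP in Hugging Face's Mixtral naming; routing_weights
    # is this expert's scalar weight from the top-k router.
    return routing_weights * self.w2(self.act_fn(self.w1(x)) * self.w3(x))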
@@ -241,7 +241,7 @@ def qwen_attention_forward(
 
 def qwen_mlp_forward(self, x: torch.Tensor) -> torch.Tensor:
     x_2d = x.view(-1, x.shape[-1])
-    if x_2d.shape[0] == 1 and x.dtype == torch.float32 and x.device.type == 'xpu' \
+    if x_2d.shape[0] == 1 and x.device.type == 'xpu' \
             and self.w2.qtype == ggml_tensor_qtype["sym_int4"] \
             and not (self.training and x.requires_grad):
         import linear_q4_0
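A usage sketch of how an fp16 run now reaches the fused path (BigDL-LLM load API; the Qwen model id is a placeholder for any supported checkpoint):

import intel_extension_for_pytorch as ipex  # registers the 'xpu' device
from bigdl.llm.transformers import AutoModelForCausalLM

# load_in_4bit=True quantizes the linear weights to sym_int4 on load.
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-7B-Chat",
                                             load_in_4bit=True,
                                             trust_remote_code=True)
model = model.half().to('xpu')  # fp16 activations now satisfy the relaxed guard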