fix mlp batch size check (#9718)
commit f2e6abb563
parent 1fa7793fc0

3 changed files with 8 additions and 6 deletions
@@ -69,11 +69,11 @@ def baichuan_mlp_forward(
     self,
     x: torch.Tensor,
 ) -> torch.Tensor:
-    if x.shape[1] == 1 and x.dtype == torch.float32 and x.device.type == 'xpu' \
+    x_2d = x.view(-1, x.shape[-1])
+    if x_2d.shape[0] == 1 and x.dtype == torch.float32 and x.device.type == 'xpu' \
             and self.gate_proj.qtype == ggml_tensor_qtype["sym_int4"] \
             and not (self.training and x.requires_grad):
         import linear_q4_0
-        x_2d = x.view(-1, x.shape[-1])
         if not x_2d.is_contiguous():
             x_2d = x_2d.contiguous()
         return self.down_proj(linear_q4_0.mlp_forward_q4_0_xpu(
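The same gate is fixed in all three forwards. Previously the fused XPU path was taken when x.shape[1] == 1, i.e. when the sequence dimension of a [batch, seq_len, hidden] input was 1, which also holds for a batched decode step where every sequence contributes one new token. The patch flattens first and gates on the total row count. A minimal sketch of the difference, with illustrative shapes that are not taken from this commit:

import torch

# Illustrative shapes (assumed, not from the patch): prefill feeds many prompt
# tokens for one sequence, batched decode feeds one new token per sequence.
prefill = torch.randn(1, 32, 4096)   # 1 sequence, 32 tokens
decode = torch.randn(4, 1, 4096)     # 4 sequences, 1 token each

for x in (prefill, decode):
    x_2d = x.view(-1, x.shape[-1])
    old_check = x.shape[1] == 1      # True for the batched decode step as well
    new_check = x_2d.shape[0] == 1   # True only when there is a single row overall
    print(tuple(x.shape), old_check, new_check)

With these shapes the old check would route the 4-row decode batch to the fused path, while the new check keeps it on the regular path.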
@@ -42,6 +42,7 @@ from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache,
 from bigdl.llm.transformers.models.utils import rotate_half, apply_rotary_pos_emb
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
 from bigdl.llm.transformers.low_bit_linear import SYM_INT4
+from bigdl.llm.ggml.quantize import ggml_tensor_qtype


 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
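The new import backs the sym_int4 condition added to llama_mlp_forward below. A hedged aside, assuming (this diff does not show it) that SYM_INT4 in low_bit_linear is defined from the same quantization table, so the two spellings of the check agree:

# Assumption, not confirmed by this diff: ggml_tensor_qtype maps a textual
# quantization name to the integer code used by the low-bit linear layers,
# and SYM_INT4 is that same code for "sym_int4".
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
from bigdl.llm.transformers.low_bit_linear import SYM_INT4

assert ggml_tensor_qtype["sym_int4"] == SYM_INT4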
@@ -96,10 +97,11 @@ def llama_mlp_forward(
     self,
     x: torch.Tensor,
 ) -> torch.Tensor:
-    if x.shape[1] == 1 and x.dtype == torch.float32 and x.device.type == 'xpu' \
+    x_2d = x.view(-1, x.shape[-1])
+    if x_2d.shape[0] == 1 and x.dtype == torch.float32 and x.device.type == 'xpu' \
+            and self.gate_proj.qtype == ggml_tensor_qtype["sym_int4"] \
             and not (self.training and x.requires_grad):
         import linear_q4_0
-        x_2d = x.view(-1, x.shape[-1])
         if not x_2d.is_contiguous():
             x_2d = x_2d.contiguous()
         return self.down_proj(linear_q4_0.mlp_forward_q4_0_xpu(
@@ -240,11 +240,11 @@ def qwen_attention_forward(


 def qwen_mlp_forward(self, x: torch.Tensor) -> torch.Tensor:
-    if x.shape[1] == 1 and x.dtype == torch.float32 and x.device.type == 'xpu' \
+    x_2d = x.view(-1, x.shape[-1])
+    if x_2d.shape[0] == 1 and x.dtype == torch.float32 and x.device.type == 'xpu' \
             and self.w2.qtype == ggml_tensor_qtype["sym_int4"] \
             and not (self.training and x.requires_grad):
         import linear_q4_0
-        x_2d = x.view(-1, x.shape[-1])
         if not x_2d.is_contiguous():
             x_2d = x_2d.contiguous()
         return self.c_proj(linear_q4_0.mlp_forward_q4_0_xpu(