LLM: add mlp fusion for fp8e5 and update related check (#9860)

* update mlp fusion
* fix style
* update

commit fea6f16057 (parent 294fd32787)
5 changed files with 32 additions and 16 deletions
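Every model file below is changed with the same pattern: the inline fast-path condition (a single-row input, an 'xpu' device, sym_int4 weights, outside training) is replaced by the shared mlp_fusion_check helper, which also accepts fp8_e5m2, and the fused kernel call is renamed from mlp_forward_q4_0_xpu to mlp_forward_xpu, which now receives the resolved qtype as an extra argument. The sketch below is a consolidated "after" view pieced together from the hunks that follow; fused_mlp_forward and the mlp argument are placeholders, not names from the patch.

import torch
from bigdl.llm.transformers.models.utils import mlp_fusion_check


def fused_mlp_forward(mlp, x: torch.Tensor) -> torch.Tensor:
    # Flatten to (tokens, hidden), as every touched forward does.
    x_2d = x.view(-1, x.shape[-1])
    # Low-bit linears carry a `qtype`; plain layers fall back to None.
    qtype = getattr(mlp.gate_proj, "qtype", None)
    if mlp_fusion_check(x_2d, qtype, mlp.training):
        import linear_q4_0  # XPU extension shipped with bigdl-llm
        if not x_2d.is_contiguous():
            x_2d = x_2d.contiguous()
        return mlp.down_proj(linear_q4_0.mlp_forward_xpu(
            x_2d, mlp.gate_proj.weight.data, mlp.up_proj.weight.data,
            x_2d.shape[0], x_2d.shape[1], mlp.gate_proj.out_len,
            qtype))
    # Unfused SwiGLU fallback, unchanged by this commit.
    return mlp.down_proj(mlp.act_fn(mlp.gate_proj(x)) * mlp.up_proj(x))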
@@ -27,6 +27,7 @@ from bigdl.llm.ggml.quantize import ggml_tensor_qtype
 from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
+from bigdl.llm.transformers.models.utils import mlp_fusion_check
 from transformers.utils import logging
 logger = logging.get_logger(__name__)
@@ -66,15 +67,15 @@ def baichuan_mlp_forward(
     x: torch.Tensor,
 ) -> torch.Tensor:
     x_2d = x.view(-1, x.shape[-1])
-    if x_2d.shape[0] == 1 and x.device.type == 'xpu' \
-            and self.gate_proj.qtype == ggml_tensor_qtype["sym_int4"] \
-            and not (self.training and x.requires_grad):
+    qtype = getattr(self.gate_proj, "qtype", None)
+    if mlp_fusion_check(x_2d, qtype, self.training):
         import linear_q4_0
         if not x_2d.is_contiguous():
             x_2d = x_2d.contiguous()
-        return self.down_proj(linear_q4_0.mlp_forward_q4_0_xpu(
+        return self.down_proj(linear_q4_0.mlp_forward_xpu(
             x_2d, self.gate_proj.weight.data, self.up_proj.weight.data,
             x_2d.shape[0], x_2d.shape[1], self.gate_proj.out_len,
+            qtype
         ))
     return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
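A side note on the getattr default introduced above (and reused in the llama, mixtral and qwen hunks below): if the projection is a plain nn.Linear rather than a BigDL low-bit linear, it has no qtype attribute, the lookup yields None, and mlp_fusion_check rejects the fused path because None is not in its accepted qtype list. A minimal, self-contained illustration (the layer here is hypothetical):

import torch

gate_proj = torch.nn.Linear(16, 16)        # hypothetical non-quantized projection
qtype = getattr(gate_proj, "qtype", None)  # same fallback as in the patch
print(qtype)                               # None -> mlp_fusion_check returns False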
@@ -44,6 +44,7 @@ from bigdl.llm.transformers.models.utils import is_enough_kv_cache_room_4_31, \
     apply_rotary_pos_emb, is_enough_kv_cache_room_4_36
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
 from bigdl.llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
+from bigdl.llm.transformers.models.utils import mlp_fusion_check
 from transformers.modeling_outputs import BaseModelOutputWithPast
 from bigdl.llm.transformers.low_bit_linear import SYM_INT4
 from bigdl.llm.ggml.quantize import ggml_tensor_qtype
@@ -104,15 +105,14 @@ def llama_mlp_forward(
 ) -> torch.Tensor:
     x_2d = x.view(-1, x.shape[-1])
     qtype = getattr(self.gate_proj, "qtype", None)
-    if x_2d.shape[0] == 1 and x.device.type == 'xpu' \
-            and qtype == ggml_tensor_qtype["sym_int4"] \
-            and not (self.training and x.requires_grad):
+    if mlp_fusion_check(x_2d, qtype, self.training):
         import linear_q4_0
         if not x_2d.is_contiguous():
             x_2d = x_2d.contiguous()
-        return self.down_proj(linear_q4_0.mlp_forward_q4_0_xpu(
+        return self.down_proj(linear_q4_0.mlp_forward_xpu(
             x_2d, self.gate_proj.weight.data, self.up_proj.weight.data,
             x_2d.shape[0], x_2d.shape[1], self.gate_proj.out_len,
+            qtype
         ))
     return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
@@ -314,13 +314,13 @@ def mixtral_mlp_forward(
     x: torch.Tensor,
     routing_weights
 ) -> torch.Tensor:
-    if x.shape[0] == 1 and x.device.type == 'xpu' \
-            and self.w1.qtype == ggml_tensor_qtype["sym_int4"] \
-            and not (self.training and x.requires_grad):
+    qtype = getattr(self.w1, "qtype", None)
+    if mlp_fusion_check(x, qtype, self.training):
         import linear_q4_0
-        return self.w2(linear_q4_0.mlp_forward_q4_0_xpu(
+        return self.w2(linear_q4_0.mlp_forward_xpu(
             x, self.w1.weight.data, self.w3.weight.data,
             x.shape[0], x.shape[1], self.w1.out_len,
+            qtype,
         )) * routing_weights
     else:
         current_hidden_states = self.act_fn(self.w1(x)) * self.w3(x)
@@ -40,6 +40,7 @@ from bigdl.llm.transformers.models.utils import extend_kv_cache, init_kv_cache,
 from bigdl.llm.transformers.models.utils import init_fp8_kv_cache, extend_fp8_kv_cache, \
     append_fp8_kv_cache, restore_fp8_kv_cache
 from bigdl.llm.transformers.models.utils import rotate_half, quantize_kv_cache
+from bigdl.llm.transformers.models.utils import mlp_fusion_check
 from bigdl.llm.utils.common import invalidInputError, invalidOperationError
 from bigdl.llm.ggml.quantize import ggml_tensor_qtype
@@ -276,14 +277,14 @@ def core_attn(self, query, key, value, causal_mask=None, attention_mask=None, he
 
 def qwen_mlp_forward(self, x: torch.Tensor) -> torch.Tensor:
     x_2d = x.view(-1, x.shape[-1])
-    if x_2d.shape[0] == 1 and x.device.type == 'xpu' \
-            and self.w2.qtype == ggml_tensor_qtype["sym_int4"] \
-            and not (self.training and x.requires_grad):
+    qtype = getattr(self.w1, "qtype", None)
+    if mlp_fusion_check(x_2d, qtype, self.training):
         import linear_q4_0
         if not x_2d.is_contiguous():
             x_2d = x_2d.contiguous()
-        return self.c_proj(linear_q4_0.mlp_forward_q4_0_xpu(
+        return self.c_proj(linear_q4_0.mlp_forward_xpu(
             x_2d, self.w2.weight.data, self.w1.weight.data,
             x_2d.shape[0], x_2d.shape[1], self.w2.out_len,
+            qtype
         ))
     return self.c_proj(F.silu(self.w2(x)) * self.w1(x))
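Note the weight ordering in the Qwen hunk above: w2 acts as the gate projection (its output goes through F.silu) and w1 as the up projection, so the fused call passes (w2, w1) where the llama-style forwards pass (gate_proj, up_proj). A small sketch of the unfused fallback with hypothetical toy shapes, just to make that ordering concrete:

import torch
import torch.nn.functional as F

hidden_size, intermediate_size = 8, 16  # hypothetical toy dimensions
w1 = torch.nn.Linear(hidden_size, intermediate_size, bias=False)  # "up" projection
w2 = torch.nn.Linear(hidden_size, intermediate_size, bias=False)  # "gate" projection
c_proj = torch.nn.Linear(intermediate_size, hidden_size, bias=False)

x = torch.randn(1, hidden_size)
out = c_proj(F.silu(w2(x)) * w1(x))  # mirrors the fallback line in the hunk
print(out.shape)                     # torch.Size([1, 8])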
@@ -248,3 +248,17 @@ def use_esimd_sdp(q_len, head_dim, query_states):
                 return False
         else:
             return False
+
+
+def mlp_fusion_check(x, qtype, training):
+    invalidInputError(x.dim() == 2,
+                      "Here input x's dim should be 2.")
+    if x.shape[0] != 1:
+        return False
+    if x.device.type != 'xpu':
+        return False
+    if qtype not in [ggml_tensor_qtype["sym_int4"], ggml_tensor_qtype["fp8_e5m2"]]:
+        return False
+    if training or x.requires_grad:
+        return False
+    return True
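For reference, a quick look at what the new helper accepts. This assumes bigdl-llm is installed and intel_extension_for_pytorch is available so the 'xpu' device exists; on any other device the check alone returns False.

import torch
import intel_extension_for_pytorch  # noqa: F401, registers the "xpu" device
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
from bigdl.llm.transformers.models.utils import mlp_fusion_check

x = torch.randn(1, 4096, device="xpu")  # single decode token on XPU
print(mlp_fusion_check(x, ggml_tensor_qtype["sym_int4"], False))   # True
print(mlp_fusion_check(x, ggml_tensor_qtype["fp8_e5m2"], False))   # True, new in this commit
print(mlp_fusion_check(x, None, False))                            # False: unsupported qtype
print(mlp_fusion_check(x.repeat(4, 1), ggml_tensor_qtype["sym_int4"], False))  # False: batch > 1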