LLM: add mlp fusion for fp8e5 and update related check (#9860)

* update mlp fusion

* fix style

* update
Ruonan Wang 2024-01-09 09:56:32 +08:00 committed by GitHub
parent 294fd32787
commit fea6f16057
5 changed files with 32 additions and 16 deletions
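Before the per-file diffs, an illustrative sketch (not an additional file in this commit) of the pattern all four model forwards converge on: the old inline, sym_int4-only eligibility test is replaced by the shared mlp_fusion_check helper (bigdl.llm.transformers.models.utils, last diff below), and the fused kernel entry point is renamed from linear_q4_0.mlp_forward_q4_0_xpu to linear_q4_0.mlp_forward_xpu, which now also takes the ggml qtype so fp8_e5m2 weights qualify as well as sym_int4. The function name fused_mlp_forward is hypothetical; the projection names follow the llama diff.

# Illustrative sketch only -- not part of the commit.  `linear_q4_0` is the
# XPU-only native extension already used by these forwards; the names
# gate_proj / up_proj / down_proj, act_fn and out_len follow the llama diff.
def fused_mlp_forward(self, x):
    x_2d = x.view(-1, x.shape[-1])
    qtype = getattr(self.gate_proj, "qtype", None)        # ggml qtype of the low-bit linear
    if mlp_fusion_check(x_2d, qtype, self.training):      # shared gate: 1 token, XPU, sym_int4 or fp8_e5m2
        import linear_q4_0
        if not x_2d.is_contiguous():
            x_2d = x_2d.contiguous()
        return self.down_proj(linear_q4_0.mlp_forward_xpu(   # renamed kernel, qtype passed through
            x_2d, self.gate_proj.weight.data, self.up_proj.weight.data,
            x_2d.shape[0], x_2d.shape[1], self.gate_proj.out_len,
            qtype
        ))
    return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))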

View file

@@ -27,6 +27,7 @@ from bigdl.llm.ggml.quantize import ggml_tensor_qtype
 from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
+from bigdl.llm.transformers.models.utils import mlp_fusion_check
 from transformers.utils import logging
 logger = logging.get_logger(__name__)
@@ -66,15 +67,15 @@ def baichuan_mlp_forward(
     x: torch.Tensor,
 ) -> torch.Tensor:
     x_2d = x.view(-1, x.shape[-1])
-    if x_2d.shape[0] == 1 and x.device.type == 'xpu' \
-            and self.gate_proj.qtype == ggml_tensor_qtype["sym_int4"] \
-            and not (self.training and x.requires_grad):
+    qtype = getattr(self.gate_proj, "qtype", None)
+    if mlp_fusion_check(x_2d, qtype, self.training):
         import linear_q4_0
         if not x_2d.is_contiguous():
             x_2d = x_2d.contiguous()
-        return self.down_proj(linear_q4_0.mlp_forward_q4_0_xpu(
+        return self.down_proj(linear_q4_0.mlp_forward_xpu(
             x_2d, self.gate_proj.weight.data, self.up_proj.weight.data,
             x_2d.shape[0], x_2d.shape[1], self.gate_proj.out_len,
+            qtype
         ))
     return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

View file

@@ -44,6 +44,7 @@ from bigdl.llm.transformers.models.utils import is_enough_kv_cache_room_4_31, \
     apply_rotary_pos_emb, is_enough_kv_cache_room_4_36
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
 from bigdl.llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
+from bigdl.llm.transformers.models.utils import mlp_fusion_check
 from transformers.modeling_outputs import BaseModelOutputWithPast
 from bigdl.llm.transformers.low_bit_linear import SYM_INT4
 from bigdl.llm.ggml.quantize import ggml_tensor_qtype
@@ -104,15 +105,14 @@ def llama_mlp_forward(
 ) -> torch.Tensor:
     x_2d = x.view(-1, x.shape[-1])
     qtype = getattr(self.gate_proj, "qtype", None)
-    if x_2d.shape[0] == 1 and x.device.type == 'xpu' \
-            and qtype == ggml_tensor_qtype["sym_int4"] \
-            and not (self.training and x.requires_grad):
+    if mlp_fusion_check(x_2d, qtype, self.training):
         import linear_q4_0
         if not x_2d.is_contiguous():
             x_2d = x_2d.contiguous()
-        return self.down_proj(linear_q4_0.mlp_forward_q4_0_xpu(
+        return self.down_proj(linear_q4_0.mlp_forward_xpu(
             x_2d, self.gate_proj.weight.data, self.up_proj.weight.data,
             x_2d.shape[0], x_2d.shape[1], self.gate_proj.out_len,
+            qtype
         ))
     return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

View file

@@ -314,13 +314,13 @@ def mixtral_mlp_forward(
     x: torch.Tensor,
     routing_weights
 ) -> torch.Tensor:
-    if x.shape[0] == 1 and x.device.type == 'xpu' \
-            and self.w1.qtype == ggml_tensor_qtype["sym_int4"] \
-            and not (self.training and x.requires_grad):
+    qtype = getattr(self.w1, "qtype", None)
+    if mlp_fusion_check(x, qtype, self.training):
         import linear_q4_0
-        return self.w2(linear_q4_0.mlp_forward_q4_0_xpu(
+        return self.w2(linear_q4_0.mlp_forward_xpu(
             x, self.w1.weight.data, self.w3.weight.data,
             x.shape[0], x.shape[1], self.w1.out_len,
+            qtype,
         )) * routing_weights
     else:
         current_hidden_states = self.act_fn(self.w1(x)) * self.w3(x)

View file

@@ -40,6 +40,7 @@ from bigdl.llm.transformers.models.utils import extend_kv_cache, init_kv_cache,
 from bigdl.llm.transformers.models.utils import init_fp8_kv_cache, extend_fp8_kv_cache, \
     append_fp8_kv_cache, restore_fp8_kv_cache
 from bigdl.llm.transformers.models.utils import rotate_half, quantize_kv_cache
+from bigdl.llm.transformers.models.utils import mlp_fusion_check
 from bigdl.llm.utils.common import invalidInputError, invalidOperationError
 from bigdl.llm.ggml.quantize import ggml_tensor_qtype
@@ -276,14 +277,14 @@ def core_attn(self, query, key, value, causal_mask=None, attention_mask=None, he
 def qwen_mlp_forward(self, x: torch.Tensor) -> torch.Tensor:
     x_2d = x.view(-1, x.shape[-1])
-    if x_2d.shape[0] == 1 and x.device.type == 'xpu' \
-            and self.w2.qtype == ggml_tensor_qtype["sym_int4"] \
-            and not (self.training and x.requires_grad):
+    qtype = getattr(self.w1, "qtype", None)
+    if mlp_fusion_check(x_2d, qtype, self.training):
         import linear_q4_0
         if not x_2d.is_contiguous():
             x_2d = x_2d.contiguous()
-        return self.c_proj(linear_q4_0.mlp_forward_q4_0_xpu(
+        return self.c_proj(linear_q4_0.mlp_forward_xpu(
             x_2d, self.w2.weight.data, self.w1.weight.data,
             x_2d.shape[0], x_2d.shape[1], self.w2.out_len,
+            qtype
         ))
     return self.c_proj(F.silu(self.w2(x)) * self.w1(x))

View file

@@ -248,3 +248,17 @@ def use_esimd_sdp(q_len, head_dim, query_states):
                 return False
         else:
             return False
+
+
+def mlp_fusion_check(x, qtype, training):
+    invalidInputError(x.dim() == 2,
+                      "Here input x's dim should be 2.")
+    if x.shape[0] != 1:
+        return False
+    if x.device.type != 'xpu':
+        return False
+    if qtype not in [ggml_tensor_qtype["sym_int4"], ggml_tensor_qtype["fp8_e5m2"]]:
+        return False
+    if training or x.requires_grad:
+        return False
+    return True
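For reference, a hypothetical usage illustration (not part of the commit) of what the new check accepts: only single-token, 2-D inputs on an XPU device, with sym_int4 or, newly in this PR, fp8_e5m2 weights, and only outside of training. It needs an XPU-enabled PyTorch build to actually run; the expected results follow directly from the function body above.

import torch
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
from bigdl.llm.transformers.models.utils import mlp_fusion_check

x = torch.randn(1, 4096, device="xpu")   # decode-style input: one token, 2-D

mlp_fusion_check(x, ggml_tensor_qtype["sym_int4"], training=False)       # True
mlp_fusion_check(x, ggml_tensor_qtype["fp8_e5m2"], training=False)       # True  (newly allowed here)
mlp_fusion_check(x, ggml_tensor_qtype["fp8_e5m2"], training=True)        # False (no fusion while training)
mlp_fusion_check(x.expand(8, -1), ggml_tensor_qtype["sym_int4"], False)  # False (more than one token)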