LLM: add mlp fusion for fp8e5 and update related check (#9860)

* update mlp fusion

* fix style

* update
Ruonan Wang 2024-01-09 09:56:32 +08:00 committed by GitHub
parent 294fd32787
commit fea6f16057
5 changed files with 32 additions and 16 deletions


@@ -27,6 +27,7 @@ from bigdl.llm.ggml.quantize import ggml_tensor_qtype
 from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
+from bigdl.llm.transformers.models.utils import mlp_fusion_check
 from transformers.utils import logging
 logger = logging.get_logger(__name__)
@@ -66,15 +67,15 @@ def baichuan_mlp_forward(
     x: torch.Tensor,
 ) -> torch.Tensor:
     x_2d = x.view(-1, x.shape[-1])
-    if x_2d.shape[0] == 1 and x.device.type == 'xpu' \
-            and self.gate_proj.qtype == ggml_tensor_qtype["sym_int4"] \
-            and not (self.training and x.requires_grad):
+    qtype = getattr(self.gate_proj, "qtype", None)
+    if mlp_fusion_check(x_2d, qtype, self.training):
         import linear_q4_0
         if not x_2d.is_contiguous():
             x_2d = x_2d.contiguous()
-        return self.down_proj(linear_q4_0.mlp_forward_q4_0_xpu(
+        return self.down_proj(linear_q4_0.mlp_forward_xpu(
             x_2d, self.gate_proj.weight.data, self.up_proj.weight.data,
             x_2d.shape[0], x_2d.shape[1], self.gate_proj.out_len,
+            qtype
         ))
     return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))


@@ -44,6 +44,7 @@ from bigdl.llm.transformers.models.utils import is_enough_kv_cache_room_4_31, \
     apply_rotary_pos_emb, is_enough_kv_cache_room_4_36
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu
 from bigdl.llm.transformers.models.utils import use_flash_attention, use_esimd_sdp
+from bigdl.llm.transformers.models.utils import mlp_fusion_check
 from transformers.modeling_outputs import BaseModelOutputWithPast
 from bigdl.llm.transformers.low_bit_linear import SYM_INT4
 from bigdl.llm.ggml.quantize import ggml_tensor_qtype
@@ -104,15 +105,14 @@ def llama_mlp_forward(
 ) -> torch.Tensor:
     x_2d = x.view(-1, x.shape[-1])
     qtype = getattr(self.gate_proj, "qtype", None)
-    if x_2d.shape[0] == 1 and x.device.type == 'xpu' \
-            and qtype == ggml_tensor_qtype["sym_int4"] \
-            and not (self.training and x.requires_grad):
+    if mlp_fusion_check(x_2d, qtype, self.training):
         import linear_q4_0
         if not x_2d.is_contiguous():
             x_2d = x_2d.contiguous()
-        return self.down_proj(linear_q4_0.mlp_forward_q4_0_xpu(
+        return self.down_proj(linear_q4_0.mlp_forward_xpu(
             x_2d, self.gate_proj.weight.data, self.up_proj.weight.data,
             x_2d.shape[0], x_2d.shape[1], self.gate_proj.out_len,
+            qtype
         ))
     return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))


@@ -314,13 +314,13 @@ def mixtral_mlp_forward(
     x: torch.Tensor,
     routing_weights
 ) -> torch.Tensor:
-    if x.shape[0] == 1 and x.device.type == 'xpu' \
-            and self.w1.qtype == ggml_tensor_qtype["sym_int4"] \
-            and not (self.training and x.requires_grad):
+    qtype = getattr(self.w1, "qtype", None)
+    if mlp_fusion_check(x, qtype, self.training):
         import linear_q4_0
-        return self.w2(linear_q4_0.mlp_forward_q4_0_xpu(
+        return self.w2(linear_q4_0.mlp_forward_xpu(
             x, self.w1.weight.data, self.w3.weight.data,
             x.shape[0], x.shape[1], self.w1.out_len,
+            qtype,
         )) * routing_weights
     else:
         current_hidden_states = self.act_fn(self.w1(x)) * self.w3(x)


@@ -40,6 +40,7 @@ from bigdl.llm.transformers.models.utils import extend_kv_cache, init_kv_cache,
 from bigdl.llm.transformers.models.utils import init_fp8_kv_cache, extend_fp8_kv_cache, \
     append_fp8_kv_cache, restore_fp8_kv_cache
 from bigdl.llm.transformers.models.utils import rotate_half, quantize_kv_cache
+from bigdl.llm.transformers.models.utils import mlp_fusion_check
 from bigdl.llm.utils.common import invalidInputError, invalidOperationError
 from bigdl.llm.ggml.quantize import ggml_tensor_qtype
@@ -276,14 +277,14 @@ def core_attn(self, query, key, value, causal_mask=None, attention_mask=None, he
 def qwen_mlp_forward(self, x: torch.Tensor) -> torch.Tensor:
     x_2d = x.view(-1, x.shape[-1])
-    if x_2d.shape[0] == 1 and x.device.type == 'xpu' \
-            and self.w2.qtype == ggml_tensor_qtype["sym_int4"] \
-            and not (self.training and x.requires_grad):
+    qtype = getattr(self.w1, "qtype", None)
+    if mlp_fusion_check(x_2d, qtype, self.training):
         import linear_q4_0
         if not x_2d.is_contiguous():
             x_2d = x_2d.contiguous()
-        return self.c_proj(linear_q4_0.mlp_forward_q4_0_xpu(
+        return self.c_proj(linear_q4_0.mlp_forward_xpu(
             x_2d, self.w2.weight.data, self.w1.weight.data,
             x_2d.shape[0], x_2d.shape[1], self.w2.out_len,
+            qtype
         ))
     return self.c_proj(F.silu(self.w2(x)) * self.w1(x))


@@ -248,3 +248,17 @@ def use_esimd_sdp(q_len, head_dim, query_states):
         return False
     else:
         return False
+
+
+def mlp_fusion_check(x, qtype, training):
+    invalidInputError(x.dim() == 2,
+                      "Here input x's dim should be 2.")
+    if x.shape[0] != 1:
+        return False
+    if x.device.type != 'xpu':
+        return False
+    if qtype not in [ggml_tensor_qtype["sym_int4"], ggml_tensor_qtype["fp8_e5m2"]]:
+        return False
+    if training or x.requires_grad:
+        return False
+    return True
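
For reference, a minimal usage sketch (not part of this commit) of the pattern the forward functions above now share. It assumes a hypothetical module whose gate_proj/up_proj/down_proj attributes are BigDL low-bit linear layers exposing qtype, weight.data and out_len, and whose act_fn is SiLU; the name fused_mlp_forward is illustrative only:

import torch
from bigdl.llm.transformers.models.utils import mlp_fusion_check

def fused_mlp_forward(self, x: torch.Tensor) -> torch.Tensor:
    # Flatten to (tokens, hidden); the fused path only fires for a single token.
    x_2d = x.view(-1, x.shape[-1])
    qtype = getattr(self.gate_proj, "qtype", None)
    # mlp_fusion_check now accepts both sym_int4 and fp8_e5m2 weights, and only
    # returns True for single-token, inference-only inputs on an XPU device.
    if mlp_fusion_check(x_2d, qtype, self.training):
        import linear_q4_0  # XPU-only extension, imported lazily
        if not x_2d.is_contiguous():
            x_2d = x_2d.contiguous()
        # The kernel takes qtype explicitly now that it is no longer q4_0-specific
        # (mlp_forward_q4_0_xpu was renamed to mlp_forward_xpu in this commit).
        return self.down_proj(linear_q4_0.mlp_forward_xpu(
            x_2d, self.gate_proj.weight.data, self.up_proj.weight.data,
            x_2d.shape[0], x_2d.shape[1], self.gate_proj.out_len,
            qtype
        ))
    # Fallback: unfused SiLU(gate(x)) * up(x) followed by the down projection.
    return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))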