From fea6f16057cda511544f8b4670e5ba1158885bbc Mon Sep 17 00:00:00 2001 From: Ruonan Wang Date: Tue, 9 Jan 2024 09:56:32 +0800 Subject: [PATCH] LLM: add mlp fusion for fp8e5 and update related check (#9860) * update mlp fusion * fix style * update --- .../src/bigdl/llm/transformers/models/baichuan2.py | 9 +++++---- .../llm/src/bigdl/llm/transformers/models/llama.py | 8 ++++---- .../src/bigdl/llm/transformers/models/mixtral.py | 8 ++++---- .../llm/src/bigdl/llm/transformers/models/qwen.py | 9 +++++---- .../llm/src/bigdl/llm/transformers/models/utils.py | 14 ++++++++++++++ 5 files changed, 32 insertions(+), 16 deletions(-) diff --git a/python/llm/src/bigdl/llm/transformers/models/baichuan2.py b/python/llm/src/bigdl/llm/transformers/models/baichuan2.py index 5a0e537d..802b8871 100644 --- a/python/llm/src/bigdl/llm/transformers/models/baichuan2.py +++ b/python/llm/src/bigdl/llm/transformers/models/baichuan2.py @@ -27,6 +27,7 @@ from bigdl.llm.ggml.quantize import ggml_tensor_qtype from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu +from bigdl.llm.transformers.models.utils import mlp_fusion_check from transformers.utils import logging logger = logging.get_logger(__name__) @@ -66,15 +67,15 @@ def baichuan_mlp_forward( x: torch.Tensor, ) -> torch.Tensor: x_2d = x.view(-1, x.shape[-1]) - if x_2d.shape[0] == 1 and x.device.type == 'xpu' \ - and self.gate_proj.qtype == ggml_tensor_qtype["sym_int4"] \ - and not (self.training and x.requires_grad): + qtype = getattr(self.gate_proj, "qtype", None) + if mlp_fusion_check(x_2d, qtype, self.training): import linear_q4_0 if not x_2d.is_contiguous(): x_2d = x_2d.contiguous() - return self.down_proj(linear_q4_0.mlp_forward_q4_0_xpu( + return self.down_proj(linear_q4_0.mlp_forward_xpu( x_2d, self.gate_proj.weight.data, self.up_proj.weight.data, x_2d.shape[0], x_2d.shape[1], self.gate_proj.out_len, + qtype )) return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) diff --git a/python/llm/src/bigdl/llm/transformers/models/llama.py b/python/llm/src/bigdl/llm/transformers/models/llama.py index 11473803..55bcfb37 100644 --- a/python/llm/src/bigdl/llm/transformers/models/llama.py +++ b/python/llm/src/bigdl/llm/transformers/models/llama.py @@ -44,6 +44,7 @@ from bigdl.llm.transformers.models.utils import is_enough_kv_cache_room_4_31, \ apply_rotary_pos_emb, is_enough_kv_cache_room_4_36 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb_no_cache_xpu from bigdl.llm.transformers.models.utils import use_flash_attention, use_esimd_sdp +from bigdl.llm.transformers.models.utils import mlp_fusion_check from transformers.modeling_outputs import BaseModelOutputWithPast from bigdl.llm.transformers.low_bit_linear import SYM_INT4 from bigdl.llm.ggml.quantize import ggml_tensor_qtype @@ -104,15 +105,14 @@ def llama_mlp_forward( ) -> torch.Tensor: x_2d = x.view(-1, x.shape[-1]) qtype = getattr(self.gate_proj, "qtype", None) - if x_2d.shape[0] == 1 and x.device.type == 'xpu' \ - and qtype == ggml_tensor_qtype["sym_int4"] \ - and not (self.training and x.requires_grad): + if mlp_fusion_check(x_2d, qtype, self.training): import linear_q4_0 if not x_2d.is_contiguous(): x_2d = x_2d.contiguous() - return self.down_proj(linear_q4_0.mlp_forward_q4_0_xpu( + return self.down_proj(linear_q4_0.mlp_forward_xpu( x_2d, self.gate_proj.weight.data, self.up_proj.weight.data, 
x_2d.shape[0], x_2d.shape[1], self.gate_proj.out_len, + qtype )) return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) diff --git a/python/llm/src/bigdl/llm/transformers/models/mixtral.py b/python/llm/src/bigdl/llm/transformers/models/mixtral.py index 36251834..72e4f110 100644 --- a/python/llm/src/bigdl/llm/transformers/models/mixtral.py +++ b/python/llm/src/bigdl/llm/transformers/models/mixtral.py @@ -314,13 +314,13 @@ def mixtral_mlp_forward( x: torch.Tensor, routing_weights ) -> torch.Tensor: - if x.shape[0] == 1 and x.device.type == 'xpu' \ - and self.w1.qtype == ggml_tensor_qtype["sym_int4"] \ - and not (self.training and x.requires_grad): + qtype = getattr(self.w1, "qtype", None) + if mlp_fusion_check(x, qtype, self.training): import linear_q4_0 - return self.w2(linear_q4_0.mlp_forward_q4_0_xpu( + return self.w2(linear_q4_0.mlp_forward_xpu( x, self.w1.weight.data, self.w3.weight.data, x.shape[0], x.shape[1], self.w1.out_len, + qtype, )) * routing_weights else: current_hidden_states = self.act_fn(self.w1(x)) * self.w3(x) diff --git a/python/llm/src/bigdl/llm/transformers/models/qwen.py b/python/llm/src/bigdl/llm/transformers/models/qwen.py index c2c6ef0e..e0198c68 100644 --- a/python/llm/src/bigdl/llm/transformers/models/qwen.py +++ b/python/llm/src/bigdl/llm/transformers/models/qwen.py @@ -40,6 +40,7 @@ from bigdl.llm.transformers.models.utils import extend_kv_cache, init_kv_cache, from bigdl.llm.transformers.models.utils import init_fp8_kv_cache, extend_fp8_kv_cache, \ append_fp8_kv_cache, restore_fp8_kv_cache from bigdl.llm.transformers.models.utils import rotate_half, quantize_kv_cache +from bigdl.llm.transformers.models.utils import mlp_fusion_check from bigdl.llm.utils.common import invalidInputError, invalidOperationError from bigdl.llm.ggml.quantize import ggml_tensor_qtype @@ -276,14 +277,14 @@ def core_attn(self, query, key, value, causal_mask=None, attention_mask=None, he def qwen_mlp_forward(self, x: torch.Tensor) -> torch.Tensor: x_2d = x.view(-1, x.shape[-1]) - if x_2d.shape[0] == 1 and x.device.type == 'xpu' \ - and self.w2.qtype == ggml_tensor_qtype["sym_int4"] \ - and not (self.training and x.requires_grad): + qtype = getattr(self.w1, "qtype", None) + if mlp_fusion_check(x_2d, qtype, self.training): import linear_q4_0 if not x_2d.is_contiguous(): x_2d = x_2d.contiguous() - return self.c_proj(linear_q4_0.mlp_forward_q4_0_xpu( + return self.c_proj(linear_q4_0.mlp_forward_xpu( x_2d, self.w2.weight.data, self.w1.weight.data, x_2d.shape[0], x_2d.shape[1], self.w2.out_len, + qtype )) return self.c_proj(F.silu(self.w2(x)) * self.w1(x)) diff --git a/python/llm/src/bigdl/llm/transformers/models/utils.py b/python/llm/src/bigdl/llm/transformers/models/utils.py index 2bee81f7..fbd802ca 100644 --- a/python/llm/src/bigdl/llm/transformers/models/utils.py +++ b/python/llm/src/bigdl/llm/transformers/models/utils.py @@ -248,3 +248,17 @@ def use_esimd_sdp(q_len, head_dim, query_states): return False else: return False + + +def mlp_fusion_check(x, qtype, training): + invalidInputError(x.dim() == 2, + "Here input x's dim should be 2.") + if x.shape[0] != 1: + return False + if x.device.type != 'xpu': + return False + if qtype not in [ggml_tensor_qtype["sym_int4"], ggml_tensor_qtype["fp8_e5m2"]]: + return False + if training or x.requires_grad: + return False + return True
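
Note for reviewers: every model touched by this patch follows the same gate-then-dispatch pattern, so a single sketch covers all four hunks. The snippet below is distilled from the llama_mlp_forward change above and is illustrative only, not part of the patch: it assumes bigdl-llm is installed with its Intel XPU kernel extension (the linear_q4_0 module) importable, and that `mlp` is a LLaMA-style MLP already converted by bigdl-llm, so its gate_proj carries the low-bit `qtype` and `out_len` attributes. The helper name `fused_or_eager_mlp` is invented for the example.

import torch
from bigdl.llm.transformers.models.utils import mlp_fusion_check

def fused_or_eager_mlp(mlp, x: torch.Tensor) -> torch.Tensor:
    # Flatten to 2D, as every converted forward in this patch does.
    x_2d = x.view(-1, x.shape[-1])
    qtype = getattr(mlp.gate_proj, "qtype", None)
    if mlp_fusion_check(x_2d, qtype, mlp.training):
        import linear_q4_0
        if not x_2d.is_contiguous():
            x_2d = x_2d.contiguous()
        # One fused XPU kernel covers gate_proj, up_proj and the activation;
        # the trailing qtype argument selects sym_int4 or fp8_e5m2 weights.
        return mlp.down_proj(linear_q4_0.mlp_forward_xpu(
            x_2d, mlp.gate_proj.weight.data, mlp.up_proj.weight.data,
            x_2d.shape[0], x_2d.shape[1], mlp.gate_proj.out_len,
            qtype,
        ))
    # Eager fallback: the standard SwiGLU composition of the three projections.
    return mlp.down_proj(mlp.act_fn(mlp.gate_proj(x)) * mlp.up_proj(x))

In short, mlp_fusion_check centralizes the gating that each model previously hard-coded for sym_int4: the fused path is taken only for a single-row 2D input (single-token decoding) on an XPU device, with weights quantized to sym_int4 or fp8_e5m2, and never during training or when the input requires grad; everything else falls back to the eager path.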