From 8ed89557e541ec4ea1f150ad9b63dd08a50c4c0f Mon Sep 17 00:00:00 2001
From: Ruonan Wang
Date: Mon, 18 Dec 2023 16:59:52 +0800
Subject: [PATCH] LLM: add mlp optimization of mixtral (#9709)

---
 .../llm/src/bigdl/llm/transformers/convert.py |  5 ++++-
 .../bigdl/llm/transformers/models/mixtral.py  | 20 +++++++++++++++++++
 2 files changed, 24 insertions(+), 1 deletion(-)

diff --git a/python/llm/src/bigdl/llm/transformers/convert.py b/python/llm/src/bigdl/llm/transformers/convert.py
index fbbd280b..e9a3bcdb 100644
--- a/python/llm/src/bigdl/llm/transformers/convert.py
+++ b/python/llm/src/bigdl/llm/transformers/convert.py
@@ -621,7 +621,7 @@ def _optimize_post(model, lightweight_bmm=False):
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
         from bigdl.llm.transformers.models.mixtral import mixtral_moeblock_forward, \
-            mixtral_attention_forward
+            mixtral_attention_forward, mixtral_mlp_forward
         convert_forward(model,
                         module.MixtralAttention,
                         mixtral_attention_forward)
@@ -631,6 +631,9 @@ def _optimize_post(model, lightweight_bmm=False):
         convert_forward(model,
                         module.MixtralSparseMoeBlock,
                         mixtral_moeblock_forward)
+        convert_forward(model,
+                        module.MixtralBLockSparseTop2MLP,
+                        mixtral_mlp_forward)
     elif model.config.model_type == "mistral":
         if model.config.architectures is not None and \
            model.config.architectures[0] == "MixtralForCausalLM":
diff --git a/python/llm/src/bigdl/llm/transformers/models/mixtral.py b/python/llm/src/bigdl/llm/transformers/models/mixtral.py
index fda47df5..fd05c963 100644
--- a/python/llm/src/bigdl/llm/transformers/models/mixtral.py
+++ b/python/llm/src/bigdl/llm/transformers/models/mixtral.py
@@ -43,6 +43,7 @@ from typing import Optional, Tuple
 import torch
 from torch import nn
 import torch.nn.functional as F
+from bigdl.llm.ggml.quantize import ggml_tensor_qtype
 from bigdl.llm.utils.common import invalidInputError
 from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb,\
@@ -250,3 +251,22 @@
         attn_weights = None

     return attn_output, attn_weights, past_key_value
+
+
+def mixtral_mlp_forward(
+    self,
+    x: torch.Tensor,
+    routing_weights
+) -> torch.Tensor:
+    if x.shape[0] == 1 and x.dtype == torch.float32 and x.device.type == 'xpu' \
+            and self.w1.qtype == ggml_tensor_qtype["sym_int4"] \
+            and not (self.training and x.requires_grad):
+        import linear_q4_0
+        return self.w2(linear_q4_0.mlp_forward_q4_0_xpu(
+            x, self.w1.weight.data, self.w3.weight.data,
+            x.shape[0], x.shape[1], self.w1.out_len,
+        )) * routing_weights
+    else:
+        current_hidden_states = self.act_fn(self.w1(x)) * self.w3(x)
+        current_hidden_states = self.w2(current_hidden_states)
+        return routing_weights * current_hidden_states
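
Usage note (illustrative, not part of the patch): the fused q4_0 MLP kernel above is only taken for single-token decode on an Intel GPU with sym_int4 weights; otherwise the expert falls back to the regular act_fn(w1(x)) * w3(x) -> w2 path. A minimal sketch of how a Mixtral model would be loaded so that _optimize_post swaps in these forwards, assuming the standard bigdl.llm.transformers AutoModelForCausalLM API; the checkpoint id, prompt, and generation settings are placeholders:

    import torch
    import intel_extension_for_pytorch as ipex  # makes the 'xpu' device available
    from transformers import AutoTokenizer
    from bigdl.llm.transformers import AutoModelForCausalLM

    model_path = "mistralai/Mixtral-8x7B-Instruct-v0.1"  # illustrative checkpoint

    # load_in_4bit=True quantizes Linear layers to sym_int4, the qtype that
    # mixtral_mlp_forward checks before dispatching to the fused XPU kernel.
    model = AutoModelForCausalLM.from_pretrained(model_path,
                                                 load_in_4bit=True,
                                                 trust_remote_code=True)
    model = model.to("xpu")

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    inputs = tokenizer("What is AI?", return_tensors="pt").to("xpu")
    with torch.inference_mode():
        output = model.generate(**inputs, max_new_tokens=32)
    print(tokenizer.decode(output[0], skip_special_tokens=True))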