LLM: add mlp optimization of mixtral (#9709)

This commit is contained in:
Ruonan Wang 2023-12-18 16:59:52 +08:00 committed by GitHub
parent b3647507c0
commit 8ed89557e5
2 changed files with 24 additions and 1 deletion
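
This change patches Mixtral's per-expert MLP (transformers' MixtralBLockSparseTop2MLP) with an optimized forward, mixtral_mlp_forward, which can hand single-token decoding on Intel GPUs to a fused 4-bit kernel. A minimal usage sketch of how a user would reach the optimized path, assuming bigdl-llm's low-bit loading API of this period (load_in_4bit=True corresponds to sym_int4 weights) and an Intel GPU exposed as the 'xpu' device; the model id and prompt are placeholders:

    import torch
    import intel_extension_for_pytorch as ipex  # needed so torch registers the 'xpu' device
    from transformers import AutoTokenizer
    from bigdl.llm.transformers import AutoModelForCausalLM

    model_path = "mistralai/Mixtral-8x7B-Instruct-v0.1"  # placeholder model id or local path
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    # load_in_4bit=True quantizes the Linear layers to sym_int4, the qtype the fused kernel expects
    model = AutoModelForCausalLM.from_pretrained(model_path,
                                                 load_in_4bit=True,
                                                 trust_remote_code=True).to("xpu")

    inputs = tokenizer("What is AI?", return_tensors="pt").to("xpu")
    with torch.inference_mode():
        output = model.generate(**inputs, max_new_tokens=32)
    print(tokenizer.decode(output[0], skip_special_tokens=True))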


@@ -621,7 +621,7 @@ def _optimize_post(model, lightweight_bmm=False):
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
         from bigdl.llm.transformers.models.mixtral import mixtral_moeblock_forward, \
-            mixtral_attention_forward
+            mixtral_attention_forward, mixtral_mlp_forward
         convert_forward(model,
                         module.MixtralAttention,
                         mixtral_attention_forward)
@@ -631,6 +631,9 @@ def _optimize_post(model, lightweight_bmm=False):
         convert_forward(model,
                         module.MixtralSparseMoeBlock,
                         mixtral_moeblock_forward)
+        convert_forward(model,
+                        module.MixtralBLockSparseTop2MLP,
+                        mixtral_mlp_forward)
     elif model.config.model_type == "mistral":
         if model.config.architectures is not None and \
            model.config.architectures[0] == "MixtralForCausalLM":
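
The added convert_forward call registers mixtral_mlp_forward for every MixtralBLockSparseTop2MLP expert in the model (the "BLock" capitalization matches the class name in transformers 4.36). A minimal sketch of what such a conversion helper does, assuming it simply walks the module tree and rebinds forward on matching submodules; convert_forward_sketch is a hypothetical stand-in, not bigdl-llm's actual implementation:

    import types
    import torch.nn as nn

    def convert_forward_sketch(model: nn.Module, target_cls, new_forward):
        # Rebind forward on every instance of target_cls so subsequent calls
        # go through the optimized implementation instead of the stock one.
        for module in model.modules():
            if isinstance(module, target_cls):
                module.forward = types.MethodType(new_forward, module)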


@@ -43,6 +43,7 @@ from typing import Optional, Tuple
 import torch
 from torch import nn
 import torch.nn.functional as F
+from bigdl.llm.ggml.quantize import ggml_tensor_qtype
 from bigdl.llm.utils.common import invalidInputError
 from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb,\
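
The new import brings in ggml_tensor_qtype, the mapping from low-bit scheme names (e.g. "sym_int4") to the integer ggml type ids that bigdl-llm stores on its quantized Linear layers as .qtype; mixtral_mlp_forward below compares against it before taking the fused path. A hedged illustration of that gate (uses_sym_int4 is a hypothetical helper, not part of the commit):

    from bigdl.llm.ggml.quantize import ggml_tensor_qtype

    def uses_sym_int4(linear_layer) -> bool:
        # The fused XPU kernel only handles 4-bit symmetric weights,
        # so the optimized branch is gated on this qtype check.
        return getattr(linear_layer, "qtype", None) == ggml_tensor_qtype["sym_int4"]
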
@@ -250,3 +251,22 @@ def mixtral_attention_forward(
         attn_weights = None

     return attn_output, attn_weights, past_key_value
+
+
+def mixtral_mlp_forward(
+    self,
+    x: torch.Tensor,
+    routing_weights
+) -> torch.Tensor:
+    if x.shape[0] == 1 and x.dtype == torch.float32 and x.device.type == 'xpu' \
+            and self.w1.qtype == ggml_tensor_qtype["sym_int4"] \
+            and not (self.training and x.requires_grad):
+        import linear_q4_0
+        return self.w2(linear_q4_0.mlp_forward_q4_0_xpu(
+            x, self.w1.weight.data, self.w3.weight.data,
+            x.shape[0], x.shape[1], self.w1.out_len,
+        )) * routing_weights
+    else:
+        current_hidden_states = self.act_fn(self.w1(x)) * self.w3(x)
+        current_hidden_states = self.w2(current_hidden_states)
+        return routing_weights * current_hidden_states
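
The new mixtral_mlp_forward takes the fused path only when the expert sees a single routed token (x.shape[0] == 1, i.e. decoding), activations are fp32 on an 'xpu' device, the expert's weights are sym_int4, and no gradient is required; linear_q4_0.mlp_forward_q4_0_xpu then computes the gated projection act_fn(w1(x)) * w3(x) in one 4-bit kernel call, leaving only w2 and the routing-weight scaling in Python. Every other case falls back to the stock SwiGLU-style expert MLP. A plain-PyTorch reference of what both branches compute, assuming unquantized w1/w2/w3 Linear layers and SiLU as act_fn (expert_mlp_reference is illustrative, not part of the commit):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    def expert_mlp_reference(x: torch.Tensor,
                             w1: nn.Linear, w2: nn.Linear, w3: nn.Linear,
                             routing_weights: torch.Tensor) -> torch.Tensor:
        # Gated (SwiGLU-style) expert MLP; the fused XPU kernel replaces exactly
        # the F.silu(w1(x)) * w3(x) part, while w2 and the routing scale stay outside.
        gated = F.silu(w1(x)) * w3(x)
        return routing_weights * w2(gated)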