LLM: add mlp optimization of mixtral (#9709)
parent b3647507c0
commit 8ed89557e5
2 changed files with 24 additions and 1 deletion
@@ -621,7 +621,7 @@ def _optimize_post(model, lightweight_bmm=False):
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
         from bigdl.llm.transformers.models.mixtral import mixtral_moeblock_forward, \
-            mixtral_attention_forward
+            mixtral_attention_forward, mixtral_mlp_forward
         convert_forward(model,
                         module.MixtralAttention,
                         mixtral_attention_forward)
@@ -631,6 +631,9 @@ def _optimize_post(model, lightweight_bmm=False):
         convert_forward(model,
                         module.MixtralSparseMoeBlock,
                         mixtral_moeblock_forward)
+        convert_forward(model,
+                        module.MixtralBLockSparseTop2MLP,
+                        mixtral_mlp_forward)
     elif model.config.model_type == "mistral":
         if model.config.architectures is not None and \
                 model.config.architectures[0] == "MixtralForCausalLM":
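Note (added for context, not part of the diff): as used here, convert_forward swaps the forward method of every module instance of the given class inside the loaded model, which is how the new mixtral_mlp_forward gets attached to each expert; MixtralBLockSparseTop2MLP is the expert MLP class in the Hugging Face Mixtral implementation. A minimal sketch of what such a helper does, with an illustrative name (the real implementation lives in bigdl.llm.transformers.convert and may differ):

import torch.nn as nn

def convert_forward_sketch(model: nn.Module, target_cls, new_forward):
    # Walk the module tree; for every instance of target_cls, bind new_forward
    # to that instance so later calls go through the optimized implementation.
    for module in model.modules():
        if isinstance(module, target_cls):
            module.forward = new_forward.__get__(module, module.__class__)

Binding via new_forward.__get__(module, ...) makes the replacement behave like an ordinary bound method, so the patched modules keep their parameters and buffers untouched and only the computation path changes.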
@@ -43,6 +43,7 @@ from typing import Optional, Tuple
 import torch
 from torch import nn
 import torch.nn.functional as F
+from bigdl.llm.ggml.quantize import ggml_tensor_qtype
 from bigdl.llm.utils.common import invalidInputError
 from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb,\
@@ -250,3 +251,22 @@ def mixtral_attention_forward(
         attn_weights = None

     return attn_output, attn_weights, past_key_value
+
+
+def mixtral_mlp_forward(
+    self,
+    x: torch.Tensor,
+    routing_weights
+) -> torch.Tensor:
+    if x.shape[0] == 1 and x.dtype == torch.float32 and x.device.type == 'xpu' \
+            and self.w1.qtype == ggml_tensor_qtype["sym_int4"] \
+            and not (self.training and x.requires_grad):
+        import linear_q4_0
+        return self.w2(linear_q4_0.mlp_forward_q4_0_xpu(
+            x, self.w1.weight.data, self.w3.weight.data,
+            x.shape[0], x.shape[1], self.w1.out_len,
+        )) * routing_weights
+    else:
+        current_hidden_states = self.act_fn(self.w1(x)) * self.w3(x)
+        current_hidden_states = self.w2(current_hidden_states)
+        return routing_weights * current_hidden_states
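Note (added for context, not part of the diff): in the Hugging Face Mixtral MoE block, each selected expert receives the hidden states of the tokens routed to it together with the matching routing weights, which is why mixtral_mlp_forward takes routing_weights and scales its output by it. A simplified, hypothetical sketch of that dispatch loop (the name moe_dispatch_sketch and its exact signature are illustrative only):

import torch

def moe_dispatch_sketch(experts, hidden_states, routing_weights, selected_experts):
    # hidden_states: (num_tokens, hidden_dim)
    # routing_weights / selected_experts: (num_tokens, top_k) from softmax + top-k routing
    out = torch.zeros_like(hidden_states)
    for expert_idx, expert in enumerate(experts):
        # Find which tokens (and which of their top-k slots) picked this expert.
        token_idx, k_idx = torch.where(selected_experts == expert_idx)
        if token_idx.numel() == 0:
            continue  # no tokens routed to this expert
        x = hidden_states[token_idx]
        w = routing_weights[token_idx, k_idx].unsqueeze(-1)
        # expert(x, w) computes w * w2(act_fn(w1(x)) * w3(x)), or the fused
        # linear_q4_0 XPU path above when the fast-path conditions hold.
        out.index_add_(0, token_idx, expert(x, w).to(out.dtype))
    return out

The fast path in the patched forward only fires for a single token (x.shape[0] == 1) in float32 on XPU with sym_int4-quantized weights and outside training; presumably that is where fusing the w1/w3 projections into one linear_q4_0 kernel call pays off most for decode latency, while every other case falls back to the eager computation that mirrors the upstream expert MLP.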