LLM: add mlp optimization of mixtral (#9709)
parent b3647507c0
commit 8ed89557e5
2 changed files with 24 additions and 1 deletion
@@ -621,7 +621,7 @@ def _optimize_post(model, lightweight_bmm=False):
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
         from bigdl.llm.transformers.models.mixtral import mixtral_moeblock_forward, \
-            mixtral_attention_forward
+            mixtral_attention_forward, mixtral_mlp_forward
         convert_forward(model,
                         module.MixtralAttention,
                         mixtral_attention_forward)
@@ -631,6 +631,9 @@ def _optimize_post(model, lightweight_bmm=False):
         convert_forward(model,
                         module.MixtralSparseMoeBlock,
                         mixtral_moeblock_forward)
+        convert_forward(model,
+                        module.MixtralBLockSparseTop2MLP,
+                        mixtral_mlp_forward)
     elif model.config.model_type == "mistral":
         if model.config.architectures is not None and \
            model.config.architectures[0] == "MixtralForCausalLM":
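For reference, convert_forward is the helper that rebinds the forward method on every submodule of a given class, which is how mixtral_mlp_forward ends up running inside each MixtralBLockSparseTop2MLP expert without touching the transformers source. The sketch below only illustrates that monkey-patching pattern; it is an assumption about the general technique, not the actual implementation in bigdl.llm.transformers.convert, and the name convert_forward_sketch is hypothetical.

    import torch.nn as nn

    def convert_forward_sketch(model: nn.Module, target_cls: type, new_forward) -> None:
        # Walk the module tree; wherever an instance of target_cls is found,
        # rebind its forward so that `self` inside new_forward is that submodule.
        for sub_module in model.children():
            if isinstance(sub_module, target_cls):
                sub_module.forward = new_forward.__get__(sub_module, sub_module.__class__)
            convert_forward_sketch(sub_module, target_cls, new_forward)

With such a helper in place, the call convert_forward(model, module.MixtralBLockSparseTop2MLP, mixtral_mlp_forward) added above swaps in the optimized expert-MLP forward for the whole model.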
@@ -43,6 +43,7 @@ from typing import Optional, Tuple
 import torch
 from torch import nn
 import torch.nn.functional as F
+from bigdl.llm.ggml.quantize import ggml_tensor_qtype
 from bigdl.llm.utils.common import invalidInputError
 from bigdl.llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
 from bigdl.llm.transformers.models.utils import apply_rotary_pos_emb,\
@@ -250,3 +251,22 @@ def mixtral_attention_forward(
         attn_weights = None
 
     return attn_output, attn_weights, past_key_value
+
+
+def mixtral_mlp_forward(
+    self,
+    x: torch.Tensor,
+    routing_weights
+) -> torch.Tensor:
+    if x.shape[0] == 1 and x.dtype == torch.float32 and x.device.type == 'xpu' \
+            and self.w1.qtype == ggml_tensor_qtype["sym_int4"] \
+            and not (self.training and x.requires_grad):
+        import linear_q4_0
+        return self.w2(linear_q4_0.mlp_forward_q4_0_xpu(
+            x, self.w1.weight.data, self.w3.weight.data,
+            x.shape[0], x.shape[1], self.w1.out_len,
+        )) * routing_weights
+    else:
+        current_hidden_states = self.act_fn(self.w1(x)) * self.w3(x)
+        current_hidden_states = self.w2(current_hidden_states)
+        return routing_weights * current_hidden_states
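The new mixtral_mlp_forward dispatches to a fused int4 XPU kernel (linear_q4_0.mlp_forward_q4_0_xpu) only for single-token float32 inputs with sym_int4-quantized weights outside of training, and otherwise falls back to the stock SwiGLU expert math. As a reference for what the fused path has to reproduce, here is a self-contained plain-PyTorch version of that fallback branch; the class name and the default layer sizes are illustrative assumptions, not part of the commit.

    import torch
    import torch.nn as nn

    class RefMixtralExpertMLP(nn.Module):
        # Reference (unoptimized) Mixtral expert MLP: SwiGLU gating followed by
        # scaling with the router's weight for this expert, matching the else-branch above.
        def __init__(self, hidden_size: int = 4096, intermediate_size: int = 14336):
            super().__init__()
            self.w1 = nn.Linear(hidden_size, intermediate_size, bias=False)   # gate projection
            self.w3 = nn.Linear(hidden_size, intermediate_size, bias=False)   # up projection
            self.w2 = nn.Linear(intermediate_size, hidden_size, bias=False)   # down projection
            self.act_fn = nn.SiLU()

        def forward(self, x: torch.Tensor, routing_weights: torch.Tensor) -> torch.Tensor:
            current_hidden_states = self.act_fn(self.w1(x)) * self.w3(x)   # silu(w1(x)) * w3(x)
            current_hidden_states = self.w2(current_hidden_states)         # down projection
            return routing_weights * current_hidden_states                 # per-expert router weight

    # Example: one routed token scaled by a router weight of 0.6.
    mlp = RefMixtralExpertMLP()
    token = torch.randn(1, 4096)
    out = mlp(token, routing_weights=torch.tensor(0.6))

Because the optimized branch applies routing_weights after self.w2, both paths compute the same expert output up to kernel-level numerical differences.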