LLM: update moe block convert to optimize rest token latency of Mixtral (#9669)
* update moe block convert * further accelerate final_hidden_states * fix style * fix style
This commit is contained in:
parent
503880809c
commit
c7741c4e84
2 changed files with 112 additions and 0 deletions
|
|
@ -616,9 +616,13 @@ def _optimize_post(model, lightweight_bmm=False):
|
|||
"to run Mixtral models.")
|
||||
modeling_module_name = model.__class__.__module__
|
||||
module = importlib.import_module(modeling_module_name)
|
||||
from bigdl.llm.transformers.models.mixtral import mixtral_moeblock_forward
|
||||
convert_forward(model,
|
||||
module.MixtralRMSNorm,
|
||||
llama_rms_norm_forward)
|
||||
convert_forward(model,
|
||||
module.MixtralSparseMoeBlock,
|
||||
mixtral_moeblock_forward)
|
||||
elif model.config.model_type == "mistral":
|
||||
if model.config.architectures is not None and \
|
||||
model.config.architectures[0] == "MixtralForCausalLM":
|
||||
|
|
|
|||
108
python/llm/src/bigdl/llm/transformers/models/mixtral.py
Normal file
108
python/llm/src/bigdl/llm/transformers/models/mixtral.py
Normal file
|
|
@ -0,0 +1,108 @@
|
|||
#
|
||||
# Copyright 2016 The BigDL Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# Some parts of this file is adapted from
|
||||
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral/modeling_mixtral.py
|
||||
|
||||
# coding=utf-8
|
||||
# Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||
# and OPT implementations in this library. It has been modified from its
|
||||
# original forms to accommodate minor architectural differences compared
|
||||
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
""" PyTorch Mixtral model."""
|
||||
import math
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from bigdl.llm.utils.common import invalidInputError
|
||||
|
||||
|
||||
def mixtral_moeblock_forward(self,
|
||||
hidden_states: torch.Tensor):
|
||||
batch_size, sequence_length, hidden_dim = hidden_states.shape
|
||||
hidden_states = hidden_states.view(-1, hidden_dim)
|
||||
bs = hidden_states.shape[0]
|
||||
# router_logits: (batch * sequence_length, n_experts)
|
||||
router_logits = self.gate(hidden_states)
|
||||
|
||||
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
|
||||
routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
|
||||
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
|
||||
# we cast back to the input dtype
|
||||
routing_weights = routing_weights.to(hidden_states.dtype)
|
||||
|
||||
if bs > 1:
|
||||
final_hidden_states = torch.zeros(
|
||||
(batch_size * sequence_length, hidden_dim),
|
||||
dtype=hidden_states.dtype,
|
||||
device=hidden_states.device
|
||||
)
|
||||
# One hot encode the selected experts to create an expert mask
|
||||
# this will be used to easily index which expert is going to be sollicitated
|
||||
expert_mask = torch.nn.functional.one_hot(selected_experts,
|
||||
num_classes=self.num_experts).permute(2, 1, 0)
|
||||
|
||||
# Loop over all available experts in the model and perform the computation on each expert
|
||||
for expert_idx in range(self.num_experts):
|
||||
expert_layer = self.experts[expert_idx]
|
||||
idx, top_x = torch.where(expert_mask[expert_idx])
|
||||
|
||||
if top_x.shape[0] == 0:
|
||||
continue
|
||||
|
||||
# in torch it is faster to index using lists than torch tensors
|
||||
top_x_list = top_x.tolist()
|
||||
idx_list = idx.tolist()
|
||||
|
||||
# Index the correct hidden states and compute the expert hidden state for
|
||||
# the current expert. We need to make sure to multiply the output hidden
|
||||
# states by `routing_weights` on the corresponding tokens (top-1 and top-2)
|
||||
current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
|
||||
current_hidden_states = expert_layer(current_state,
|
||||
routing_weights[top_x_list, idx_list, None])
|
||||
|
||||
# However `index_add_` only support torch tensors for indexing so we'll use
|
||||
# the `top_x` tensor here.
|
||||
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
|
||||
else:
|
||||
selected_experts = selected_experts[0].cpu().tolist()
|
||||
for idx in range(self.top_k):
|
||||
exp_id = selected_experts[idx]
|
||||
expert_layer = self.experts[exp_id]
|
||||
weight = routing_weights[:, idx]
|
||||
if idx == 0:
|
||||
final_hidden_states = expert_layer(hidden_states, weight)
|
||||
else:
|
||||
final_hidden_states = final_hidden_states + expert_layer(hidden_states, weight)
|
||||
|
||||
final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
|
||||
return final_hidden_states, router_logits
|
||||
Loading…
Reference in a new issue