LLM: Speed-up mixtral in pipeline parallel inference (#10472)
* speed-up mixtral * fix style
This commit is contained in:
parent
b9d4280892
commit
34d0a9328c
2 changed files with 191 additions and 2 deletions
|
|
@ -1131,7 +1131,7 @@ def _optimize_post(model, lightweight_bmm=False):
|
|||
modeling_module_name = model.__class__.__module__
|
||||
module = importlib.import_module(modeling_module_name)
|
||||
from bigdl.llm.transformers.models.mixtral import mixtral_moeblock_forward, \
|
||||
mixtral_attention_forward, mixtral_mlp_forward
|
||||
mixtral_attention_forward, mixtral_mlp_forward, mixtral_model_forward
|
||||
convert_forward(model,
|
||||
module.MixtralAttention,
|
||||
mixtral_attention_forward)
|
||||
|
|
@ -1144,6 +1144,10 @@ def _optimize_post(model, lightweight_bmm=False):
|
|||
convert_forward(model,
|
||||
module.MixtralBLockSparseTop2MLP,
|
||||
mixtral_mlp_forward)
|
||||
convert_forward(model,
|
||||
module.MixtralModel,
|
||||
mixtral_model_forward)
|
||||
|
||||
elif model.config.model_type == "phi-msft" and \
|
||||
hasattr(model.config, "num_local_experts"):
|
||||
# For phixtral, limit the condition to avoid applying on phi-2 hosted by ModelScope
|
||||
|
|
|
|||
|
|
@ -38,7 +38,12 @@
|
|||
|
||||
""" PyTorch Mixtral model."""
|
||||
import math
|
||||
from typing import Optional, Tuple
|
||||
from typing import Optional, Tuple, Union, List
|
||||
from transformers.modeling_outputs import MoeModelOutputWithPast
|
||||
from transformers.cache_utils import Cache, DynamicCache
|
||||
from transformers.modeling_attn_mask_utils import (
|
||||
_prepare_4d_causal_attention_mask,
|
||||
)
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
|
@ -378,3 +383,183 @@ def mixtral_mlp_forward(
|
|||
current_hidden_states = self.act_fn(self.w1(x)) * self.w3(x)
|
||||
current_hidden_states = self.w2(current_hidden_states)
|
||||
return routing_weights * current_hidden_states
|
||||
|
||||
|
||||
def mixtral_model_forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
output_router_logits: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, MoeModelOutputWithPast]:
|
||||
output_attentions = output_attentions if output_attentions is not None \
|
||||
else self.config.output_attentions
|
||||
output_router_logits = (
|
||||
output_router_logits if output_router_logits is not None
|
||||
else self.config.output_router_logits
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else
|
||||
self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# retrieve input_ids and inputs_embeds
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
invalidInputError(False, "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") # noqa
|
||||
elif input_ids is not None:
|
||||
batch_size, seq_length = input_ids.shape
|
||||
elif inputs_embeds is not None:
|
||||
batch_size, seq_length, _ = inputs_embeds.shape
|
||||
else:
|
||||
invalidInputError(False, "You have to specify either decoder_input_ids or decoder_inputs_embeds") # noqa
|
||||
|
||||
past_key_values_length = 0
|
||||
|
||||
if use_cache:
|
||||
use_legacy_cache = not isinstance(past_key_values, Cache)
|
||||
if use_legacy_cache:
|
||||
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
||||
past_key_values_length = past_key_values.get_usable_length(seq_length)
|
||||
|
||||
if position_ids is None:
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
position_ids = torch.arange(
|
||||
past_key_values_length, seq_length + past_key_values_length,
|
||||
dtype=torch.long, device=device
|
||||
)
|
||||
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
||||
else:
|
||||
position_ids = position_ids.view(-1, seq_length).long()
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
if attention_mask is not None and self._use_flash_attention_2 and use_cache:
|
||||
is_padding_right = attention_mask[:, -1].sum().item() != batch_size
|
||||
if is_padding_right:
|
||||
invalidInputError(
|
||||
False,
|
||||
"You are attempting to perform batched generation with padding_side='right'"
|
||||
" this may lead to unexpected behaviour for Flash Attention version of Mixtral. Make sure to " # noqa
|
||||
" call `tokenizer.padding_side = 'left'` before tokenizing the input. "
|
||||
)
|
||||
|
||||
if self._use_flash_attention_2:
|
||||
# 2d mask is passed through the layers
|
||||
attention_mask = attention_mask \
|
||||
if (attention_mask is not None and 0 in attention_mask) else None
|
||||
else:
|
||||
# 4d mask is passed through the layers
|
||||
attention_mask = _prepare_4d_causal_attention_mask(
|
||||
attention_mask,
|
||||
(batch_size, seq_length),
|
||||
inputs_embeds,
|
||||
past_key_values_length,
|
||||
sliding_window=self.config.sliding_window,
|
||||
)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." # noqa
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
all_router_logits = () if output_router_logits else None
|
||||
next_decoder_cache = None
|
||||
|
||||
for decoder_layer in self.layers:
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
decoder_layer.__call__,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_values,
|
||||
output_attentions,
|
||||
output_router_logits,
|
||||
use_cache,
|
||||
)
|
||||
else:
|
||||
# bigdl-llm changes:
|
||||
#
|
||||
# Avoid moving `attention_mask`` and `position_ids`` to other devices multiple times.
|
||||
#
|
||||
# When the model is partitioned on two different devices using
|
||||
# `accelerate`'s `dispatch``, a hook to move inputs to the correct device is
|
||||
# added to each layer's `forward``, which will result in moving `attention_mask`
|
||||
# and `position_ids`, which allocated on device:0, to other devices for each
|
||||
# decoder layer not in device:0.
|
||||
#
|
||||
# To avoid this, we move `attention_mask` and `position_ids` to the device of
|
||||
# the current layer before the forward call, so that the moving is only done once
|
||||
# for each devices other than devie:0.
|
||||
#
|
||||
curr_device = decoder_layer.input_layernorm.weight.device
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask.to(curr_device)
|
||||
if position_ids is not None:
|
||||
position_ids = position_ids.to(curr_device)
|
||||
# bigdl-llm changes end
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
output_router_logits=output_router_logits,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if use_cache:
|
||||
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
if output_router_logits:
|
||||
all_router_logits += (layer_outputs[-1],)
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
next_cache = None
|
||||
if use_cache:
|
||||
next_cache = next_decoder_cache.to_legacy_cache() \
|
||||
if use_legacy_cache else next_decoder_cache
|
||||
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v
|
||||
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits] # noqa
|
||||
if v is not None
|
||||
)
|
||||
return MoeModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_cache,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attns,
|
||||
router_logits=all_router_logits,
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in a new issue