LLM: Speed-up mixtral in pipeline parallel inference (#10472)
* speed-up mixtral
* fix style
This commit is contained in:
parent b9d4280892
commit 34d0a9328c

					 2 changed files with 191 additions and 2 deletions
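
The change speeds up pipeline-parallel Mixtral inference: `_optimize_post` now also patches `MixtralModel.forward` with `mixtral_model_forward`, which moves `attention_mask` and `position_ids` to each decoder layer's device itself instead of leaving that to the per-layer dispatch hooks that `accelerate` installs when the model is split across devices. A minimal sketch of the kind of run this affects follows; the checkpoint name, prompt, and `device_map="auto"` split are illustrative assumptions, not taken from this commit.

# Illustrative sketch only: pipeline-parallel Mixtral inference with bigdl-llm.
# The checkpoint name and the device_map choice are assumptions for demonstration.
import torch
from transformers import AutoTokenizer
from bigdl.llm.transformers import AutoModelForCausalLM

model_path = "mistralai/Mixtral-8x7B-Instruct-v0.1"   # hypothetical checkpoint
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    load_in_4bit=True,        # bigdl-llm low-bit optimization
    device_map="auto",        # accelerate dispatches decoder layers across devices
)
tokenizer = AutoTokenizer.from_pretrained(model_path)

inputs = tokenizer("What is pipeline parallelism?", return_tensors="pt")
with torch.inference_mode():
    # With this patch, attention_mask/position_ids are moved to a layer's
    # device at most once per device instead of once per off-device layer.
    output = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(output[0], skip_special_tokens=True))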
				
			
@@ -1131,7 +1131,7 @@ def _optimize_post(model, lightweight_bmm=False):
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
         from bigdl.llm.transformers.models.mixtral import mixtral_moeblock_forward, \
-            mixtral_attention_forward, mixtral_mlp_forward
+            mixtral_attention_forward, mixtral_mlp_forward, mixtral_model_forward
         convert_forward(model,
                         module.MixtralAttention,
                         mixtral_attention_forward)
@@ -1144,6 +1144,10 @@ def _optimize_post(model, lightweight_bmm=False):
         convert_forward(model,
                         module.MixtralBLockSparseTop2MLP,
                         mixtral_mlp_forward)
+        convert_forward(model,
+                        module.MixtralModel,
+                        mixtral_model_forward)
+
     elif model.config.model_type == "phi-msft" and \
             hasattr(model.config, "num_local_experts"):
         # For phixtral, limit the condition to avoid applying on phi-2 hosted by ModelScope
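
For context, `convert_forward` is the bigdl-llm helper used above to swap the patched forwards in; converting `MixtralModel` itself is what lets the pipeline-parallel fix in `mixtral_model_forward` take effect. A rough sketch of what such a helper does, assuming it simply rebinds `forward` on every submodule of the target class, is shown here; the actual implementation lives in `bigdl.llm.transformers.convert`.

# Rough sketch only; the real helper is defined in bigdl.llm.transformers.convert.
def convert_forward(m, target_module_class, new_forward):
    # Walk the module tree and rebind `forward` on every matching submodule.
    for _, sub_m in m.named_children():
        if isinstance(sub_m, target_module_class):
            bound_method = new_forward.__get__(sub_m, sub_m.__class__)
            setattr(sub_m, "forward", bound_method)
        convert_forward(sub_m, target_module_class, new_forward)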
@@ -38,7 +38,12 @@

 """ PyTorch Mixtral model."""
 import math
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Union, List
+from transformers.modeling_outputs import MoeModelOutputWithPast
+from transformers.cache_utils import Cache, DynamicCache
+from transformers.modeling_attn_mask_utils import (
+    _prepare_4d_causal_attention_mask,
+)

 import torch
 from torch import nn
@@ -378,3 +383,183 @@ def mixtral_mlp_forward(
         current_hidden_states = self.act_fn(self.w1(x)) * self.w3(x)
         current_hidden_states = self.w2(current_hidden_states)
         return routing_weights * current_hidden_states
+
+
+def mixtral_model_forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        output_router_logits: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+) -> Union[Tuple, MoeModelOutputWithPast]:
+    output_attentions = output_attentions if output_attentions is not None \
+        else self.config.output_attentions
+    output_router_logits = (
+        output_router_logits if output_router_logits is not None
+        else self.config.output_router_logits
+    )
+    output_hidden_states = (
+        output_hidden_states if output_hidden_states is not None else
+        self.config.output_hidden_states
+    )
+    use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+    # retrieve input_ids and inputs_embeds
+    if input_ids is not None and inputs_embeds is not None:
+        invalidInputError(False, "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")  # noqa
+    elif input_ids is not None:
+        batch_size, seq_length = input_ids.shape
+    elif inputs_embeds is not None:
+        batch_size, seq_length, _ = inputs_embeds.shape
+    else:
+        invalidInputError(False, "You have to specify either decoder_input_ids or decoder_inputs_embeds")  # noqa
+
+    past_key_values_length = 0
+
+    if use_cache:
+        use_legacy_cache = not isinstance(past_key_values, Cache)
+        if use_legacy_cache:
+            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+        past_key_values_length = past_key_values.get_usable_length(seq_length)
+
+    if position_ids is None:
+        device = input_ids.device if input_ids is not None else inputs_embeds.device
+        position_ids = torch.arange(
+            past_key_values_length, seq_length + past_key_values_length,
+            dtype=torch.long, device=device
+        )
+        position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+    else:
+        position_ids = position_ids.view(-1, seq_length).long()
+
+    if inputs_embeds is None:
+        inputs_embeds = self.embed_tokens(input_ids)
+
+    if attention_mask is not None and self._use_flash_attention_2 and use_cache:
+        is_padding_right = attention_mask[:, -1].sum().item() != batch_size
+        if is_padding_right:
+            invalidInputError(
+                False,
+                "You are attempting to perform batched generation with padding_side='right'"
+                " this may lead to unexpected behaviour for Flash Attention version of Mixtral. Make sure to "  # noqa
+                " call `tokenizer.padding_side  = 'left'` before tokenizing the input. "
+            )
+
+    if self._use_flash_attention_2:
+        # 2d mask is passed through the layers
+        attention_mask = attention_mask \
+            if (attention_mask is not None and 0 in attention_mask) else None
+    else:
+        # 4d mask is passed through the layers
+        attention_mask = _prepare_4d_causal_attention_mask(
+            attention_mask,
+            (batch_size, seq_length),
+            inputs_embeds,
+            past_key_values_length,
+            sliding_window=self.config.sliding_window,
+        )
+
+    hidden_states = inputs_embeds
+
+    if self.gradient_checkpointing and self.training:
+        if use_cache:
+            logger.warning_once(
+                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."  # noqa
+            )
+            use_cache = False
+
+    # decoder layers
+    all_hidden_states = () if output_hidden_states else None
+    all_self_attns = () if output_attentions else None
+    all_router_logits = () if output_router_logits else None
+    next_decoder_cache = None
+
+    for decoder_layer in self.layers:
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        if self.gradient_checkpointing and self.training:
+            layer_outputs = self._gradient_checkpointing_func(
+                decoder_layer.__call__,
+                hidden_states,
+                attention_mask,
+                position_ids,
+                past_key_values,
+                output_attentions,
+                output_router_logits,
+                use_cache,
+            )
+        else:
+            # bigdl-llm changes:
+            #
+            # Avoid moving `attention_mask` and `position_ids` to other devices multiple times.
+            #
+            # When the model is partitioned on two different devices using
+            # `accelerate`'s `dispatch`, a hook that moves inputs to the correct device is
+            # added to each layer's `forward`, which results in moving `attention_mask`
+            # and `position_ids`, which are allocated on device:0, to other devices for each
+            # decoder layer not on device:0.
+            #
+            # To avoid this, we move `attention_mask` and `position_ids` to the device of
+            # the current layer before the forward call, so that the moving is only done once
+            # for each device other than device:0.
+            #
+            curr_device = decoder_layer.input_layernorm.weight.device
+            if attention_mask is not None:
+                attention_mask = attention_mask.to(curr_device)
+            if position_ids is not None:
+                position_ids = position_ids.to(curr_device)
+            # bigdl-llm changes end
+            layer_outputs = decoder_layer(
+                hidden_states,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                past_key_value=past_key_values,
+                output_attentions=output_attentions,
+                output_router_logits=output_router_logits,
+                use_cache=use_cache,
+            )
+
+        hidden_states = layer_outputs[0]
+
+        if use_cache:
+            next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+
+        if output_attentions:
+            all_self_attns += (layer_outputs[1],)
+
+        if output_router_logits:
+            all_router_logits += (layer_outputs[-1],)
+
+    hidden_states = self.norm(hidden_states)
+
+    # add hidden states from the last decoder layer
+    if output_hidden_states:
+        all_hidden_states += (hidden_states,)
+
+    next_cache = None
+    if use_cache:
+        next_cache = next_decoder_cache.to_legacy_cache() \
+            if use_legacy_cache else next_decoder_cache
+
+    if not return_dict:
+        return tuple(
+            v
+            for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_logits]  # noqa
+            if v is not None
+        )
+    return MoeModelOutputWithPast(
+        last_hidden_state=hidden_states,
+        past_key_values=next_cache,
+        hidden_states=all_hidden_states,
+        attentions=all_self_attns,
+        router_logits=all_router_logits,
+    )
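
A note on the device pre-move in the decoder-layer loop above: `Tensor.to` returns the tensor itself when it is already on the target device, so reassigning `attention_mask` and `position_ids` to the current layer's device means the copy happens at most once per device rather than once for every layer dispatched off device:0. A stripped-down illustration of the idea, not taken from the commit (the toy layers and the two-device split are made up):

# Conceptual illustration only; the layer layout and shapes are assumptions.
import torch
import torch.nn as nn

class ToyLayer(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

    def forward(self, hidden_states, attention_mask):
        # stand-in for a decoder layer that consumes the mask
        return self.proj(hidden_states) * attention_mask

def run_partitioned(layers, hidden_states, attention_mask):
    for layer in layers:
        curr_device = next(layer.parameters()).device
        # Reassign the mask to the current layer's device; .to() is a no-op
        # when the tensor is already there, so the copy happens at most once
        # per device instead of once per off-device layer.
        attention_mask = attention_mask.to(curr_device)
        hidden_states = layer(hidden_states.to(curr_device), attention_mask)
    return hidden_states

layers = [ToyLayer(), ToyLayer()]   # imagine these dispatched to xpu:0 and xpu:1
out = run_partitioned(layers, torch.randn(1, 8), torch.ones(1, 8))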