LLM: support Qwen1.5-MoE-A2.7B-Chat pipeline parallel inference (#10864)
This commit is contained in:
		
							parent
							
								
									2d210817ff
								
							
						
					
					
						commit
						c9feffff9a
					
				
					 1 changed files with 197 additions and 1 deletions
				
			
		| 
						 | 
				
			
			@ -61,6 +61,20 @@ import os
 | 
			
		|||
 | 
			
		||||
KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))
 | 
			
		||||
 | 
			
		||||
from transformers.models.qwen2.modeling_qwen2 import _prepare_4d_causal_attention_mask_for_sdpa
 | 
			
		||||
from transformers.models.qwen2.modeling_qwen2 import _prepare_4d_causal_attention_mask
 | 
			
		||||
from transformers.modeling_outputs import MoeModelOutputWithPast
 | 
			
		||||
 | 
			
		||||
try:
 | 
			
		||||
    from transformers.cache_utils import Cache, DynamicCache
 | 
			
		||||
except ImportError:
 | 
			
		||||
    Cache = Tuple[torch.Tensor]
 | 
			
		||||
import logging
 | 
			
		||||
from transformers import logging
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
logger = logging.get_logger(__name__)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def qwen2moe_model_forward(
 | 
			
		||||
    self,
 | 
			
		||||
| 
						 | 
				
			
			@ -79,7 +93,7 @@ def qwen2moe_model_forward(
 | 
			
		|||
    if use_cache and use_quantize_kv_cache(self.layers[0].mlp.shared_expert.up_proj, input_ids):
 | 
			
		||||
        if not isinstance(past_key_values, DynamicFp8Cache):
 | 
			
		||||
            past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
 | 
			
		||||
    return Qwen2MoeModel.forward(
 | 
			
		||||
    return qwen2_moe_model_forward_internal(
 | 
			
		||||
        self=self,
 | 
			
		||||
        input_ids=input_ids,
 | 
			
		||||
        attention_mask=attention_mask,
 | 
			
		||||
| 
						 | 
				
			
			@ -94,6 +108,188 @@ def qwen2moe_model_forward(
 | 
			
		|||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def qwen2_moe_model_forward_internal(
 | 
			
		||||
        self,
 | 
			
		||||
        input_ids: torch.LongTensor = None,
 | 
			
		||||
        attention_mask: Optional[torch.Tensor] = None,
 | 
			
		||||
        position_ids: Optional[torch.LongTensor] = None,
 | 
			
		||||
        past_key_values: Optional[List[torch.FloatTensor]] = None,
 | 
			
		||||
        inputs_embeds: Optional[torch.FloatTensor] = None,
 | 
			
		||||
        use_cache: Optional[bool] = None,
 | 
			
		||||
        output_attentions: Optional[bool] = None,
 | 
			
		||||
        output_hidden_states: Optional[bool] = None,
 | 
			
		||||
        output_router_logits: Optional[bool] = None,
 | 
			
		||||
        return_dict: Optional[bool] = None,
 | 
			
		||||
) -> Union[Tuple, MoeModelOutputWithPast]:
 | 
			
		||||
        output_attentions = output_attentions if output_attentions is not None \
 | 
			
		||||
            else self.config.output_attentions
 | 
			
		||||
        output_router_logits = (
 | 
			
		||||
            output_router_logits if output_router_logits is not None else
 | 
			
		||||
            self.config.output_router_logits
 | 
			
		||||
        )
 | 
			
		||||
        output_hidden_states = (
 | 
			
		||||
            output_hidden_states if output_hidden_states is not None else
 | 
			
		||||
            self.config.output_hidden_states
 | 
			
		||||
        )
 | 
			
		||||
        use_cache = use_cache if use_cache is not None else self.config.use_cache
 | 
			
		||||
 | 
			
		||||
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 | 
			
		||||
 | 
			
		||||
        # retrieve input_ids and inputs_embeds
 | 
			
		||||
        if input_ids is not None and inputs_embeds is not None:
 | 
			
		||||
            invalidInputError(False, "You cannot specify both decoder_input_ids and "
 | 
			
		||||
                              "decoder_inputs_embeds at the same time")
 | 
			
		||||
        elif input_ids is not None:
 | 
			
		||||
            batch_size, seq_length = input_ids.shape
 | 
			
		||||
        elif inputs_embeds is not None:
 | 
			
		||||
            batch_size, seq_length, _ = inputs_embeds.shape
 | 
			
		||||
        else:
 | 
			
		||||
            invalidInputError(False,
 | 
			
		||||
                              "You have to specify decoder_input_ids or decoder_inputs_embeds")
 | 
			
		||||
 | 
			
		||||
        if self.gradient_checkpointing and self.training:
 | 
			
		||||
            if use_cache:
 | 
			
		||||
                logger.warning_once(
 | 
			
		||||
                    "`use_cache=True` is incompatible with gradient checkpointing."
 | 
			
		||||
                    " Setting `use_cache=False`..."
 | 
			
		||||
                )
 | 
			
		||||
                use_cache = False
 | 
			
		||||
 | 
			
		||||
        past_key_values_length = 0
 | 
			
		||||
 | 
			
		||||
        if use_cache:
 | 
			
		||||
            use_legacy_cache = not isinstance(past_key_values, Cache)
 | 
			
		||||
            if use_legacy_cache:
 | 
			
		||||
                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
 | 
			
		||||
            past_key_values_length = past_key_values.get_usable_length(seq_length)
 | 
			
		||||
 | 
			
		||||
        if position_ids is None:
 | 
			
		||||
            device = input_ids.device if input_ids is not None else inputs_embeds.device
 | 
			
		||||
            position_ids = torch.arange(
 | 
			
		||||
                past_key_values_length, seq_length + past_key_values_length,
 | 
			
		||||
                dtype=torch.long, device=device
 | 
			
		||||
            )
 | 
			
		||||
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
 | 
			
		||||
        else:
 | 
			
		||||
            position_ids = position_ids.view(-1, seq_length).long()
 | 
			
		||||
 | 
			
		||||
        if inputs_embeds is None:
 | 
			
		||||
            inputs_embeds = self.embed_tokens(input_ids)
 | 
			
		||||
 | 
			
		||||
        if attention_mask is not None and self._attn_implementation == "flash_attention_2" \
 | 
			
		||||
                and use_cache:
 | 
			
		||||
            is_padding_right = attention_mask[:, -1].sum().item() != batch_size
 | 
			
		||||
            if is_padding_right:
 | 
			
		||||
                invalidInputError(
 | 
			
		||||
                    False,
 | 
			
		||||
                    "You are attempting to perform batched generation with padding_side='right'"
 | 
			
		||||
                    " this may lead to unexpected behaviour for Flash Attention version of"
 | 
			
		||||
                    " Qwen2MoE. Make sure to call `tokenizer.padding_side='left'`"
 | 
			
		||||
                    " before tokenizing the input."
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
        if self._attn_implementation == "flash_attention_2":
 | 
			
		||||
            # 2d mask is passed through the layers
 | 
			
		||||
            attention_mask = attention_mask if (attention_mask is not None and
 | 
			
		||||
                                                0 in attention_mask) else None
 | 
			
		||||
        elif self._attn_implementation == "sdpa" and not output_attentions:
 | 
			
		||||
            # output_attentions=True can not be supported when using SDPA, and we fall back on
 | 
			
		||||
            # the manual implementation that requires a 4D causal mask in all cases.
 | 
			
		||||
            attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
 | 
			
		||||
                attention_mask,
 | 
			
		||||
                (batch_size, seq_length),
 | 
			
		||||
                inputs_embeds,
 | 
			
		||||
                past_key_values_length,
 | 
			
		||||
                sliding_window=self.config.sliding_window,
 | 
			
		||||
            )
 | 
			
		||||
        else:
 | 
			
		||||
            # 4d mask is passed through the layers
 | 
			
		||||
            attention_mask = _prepare_4d_causal_attention_mask(
 | 
			
		||||
                attention_mask,
 | 
			
		||||
                (batch_size, seq_length),
 | 
			
		||||
                inputs_embeds,
 | 
			
		||||
                past_key_values_length,
 | 
			
		||||
                sliding_window=self.config.sliding_window,
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        hidden_states = inputs_embeds
 | 
			
		||||
 | 
			
		||||
        # decoder layers
 | 
			
		||||
        all_hidden_states = () if output_hidden_states else None
 | 
			
		||||
        all_self_attns = () if output_attentions else None
 | 
			
		||||
        all_router_logits = () if output_router_logits else None
 | 
			
		||||
        next_decoder_cache = None
 | 
			
		||||
 | 
			
		||||
        for decoder_layer in self.layers:
 | 
			
		||||
            if output_hidden_states:
 | 
			
		||||
                all_hidden_states += (hidden_states,)
 | 
			
		||||
 | 
			
		||||
            if self.gradient_checkpointing and self.training:
 | 
			
		||||
                layer_outputs = self._gradient_checkpointing_func(
 | 
			
		||||
                    decoder_layer.__call__,
 | 
			
		||||
                    hidden_states,
 | 
			
		||||
                    attention_mask,
 | 
			
		||||
                    position_ids,
 | 
			
		||||
                    past_key_values,
 | 
			
		||||
                    output_attentions,
 | 
			
		||||
                    output_router_logits,
 | 
			
		||||
                    use_cache,
 | 
			
		||||
                )
 | 
			
		||||
            else:
 | 
			
		||||
                # ipex-llm changes
 | 
			
		||||
                curr_device = decoder_layer.input_layernorm.weight.device
 | 
			
		||||
                if attention_mask is not None:
 | 
			
		||||
                    attention_mask = attention_mask.to(curr_device)
 | 
			
		||||
                if position_ids is not None:
 | 
			
		||||
                    position_ids = position_ids.to(curr_device)
 | 
			
		||||
                # ipex-llm changes end
 | 
			
		||||
                layer_outputs = decoder_layer(
 | 
			
		||||
                    hidden_states,
 | 
			
		||||
                    attention_mask=attention_mask,
 | 
			
		||||
                    position_ids=position_ids,
 | 
			
		||||
                    past_key_value=past_key_values,
 | 
			
		||||
                    output_attentions=output_attentions,
 | 
			
		||||
                    output_router_logits=output_router_logits,
 | 
			
		||||
                    use_cache=use_cache,
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
            hidden_states = layer_outputs[0]
 | 
			
		||||
 | 
			
		||||
            if use_cache:
 | 
			
		||||
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]
 | 
			
		||||
 | 
			
		||||
            if output_attentions:
 | 
			
		||||
                all_self_attns += (layer_outputs[1],)
 | 
			
		||||
 | 
			
		||||
            if output_router_logits and layer_outputs[-1] is not None:
 | 
			
		||||
                all_router_logits += (layer_outputs[-1],)
 | 
			
		||||
 | 
			
		||||
        hidden_states = self.norm(hidden_states)
 | 
			
		||||
 | 
			
		||||
        # add hidden states from the last decoder layer
 | 
			
		||||
        if output_hidden_states:
 | 
			
		||||
            all_hidden_states += (hidden_states,)
 | 
			
		||||
 | 
			
		||||
        next_cache = None
 | 
			
		||||
        if use_cache:
 | 
			
		||||
            next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache \
 | 
			
		||||
                else next_decoder_cache
 | 
			
		||||
 | 
			
		||||
        if not return_dict:
 | 
			
		||||
            return tuple(
 | 
			
		||||
                v
 | 
			
		||||
                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns,
 | 
			
		||||
                          all_router_logits] if v is not None
 | 
			
		||||
            )
 | 
			
		||||
        return MoeModelOutputWithPast(
 | 
			
		||||
            last_hidden_state=hidden_states,
 | 
			
		||||
            past_key_values=next_cache,
 | 
			
		||||
            hidden_states=all_hidden_states,
 | 
			
		||||
            attentions=all_self_attns,
 | 
			
		||||
            router_logits=all_router_logits,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def qwen2moe_attention_forward(
 | 
			
		||||
    self,
 | 
			
		||||
    hidden_states: torch.Tensor,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in a new issue