328 lines
		
	
	
	
		
			13 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			328 lines
		
	
	
	
		
			13 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
from transformers.modeling_utils import PreTrainedModel
 | 
						|
from transformers.models.llama.modeling_llama import LlamaConfig, LlamaDecoderLayer, LlamaRMSNorm, LlamaPreTrainedModel
 | 
						|
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 | 
						|
 | 
						|
from torch import nn
 | 
						|
import torch
 | 
						|
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 | 
						|
from typing import List, Optional, Tuple, Union, Iterator
 | 
						|
from transformers.utils import logging
 | 
						|
logger = logging.get_logger(__name__)
 | 
						|
import numpy as np
 | 
						|
import time
 | 
						|
from transformers import AutoTokenizer, AutoConfig
 | 
						|
import torch.distributed as dist
 | 
						|
from pipeline_models import (
 | 
						|
    _make_causal_mask, _expand_mask, DummyLayer, PPConfig,
 | 
						|
    PipelineBaseModel,
 | 
						|
)
 | 
						|
 | 
						|
 | 
						|
class LlamaModel(LlamaPreTrainedModel):
    """
    Pipeline-parallel Llama decoder stack.

    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`].
    The layers are split into contiguous slices, one per pipeline rank (``dist.get_rank()`` out of
    ``dist.get_world_size()``, so a default process group must already be initialized). This rank only
    materializes the ``LlamaDecoderLayer``s of its own slice ``[layer_start, layer_end)``. Every other position
    in ``self.layers`` holds a ``DummyLayer`` placeholder, so each real layer keeps its global index in
    ``self.layers`` (and hence its ``layers.<i>.`` parameter prefix).

    The token embedding (``embed_tokens``) only exists on the head stage and the final ``norm`` only on the
    tail stage.

    Args:
        config: LlamaConfig
    """

    def __init__(self, config: LlamaConfig):
        super().__init__(config)
        self.config = config

        # pp modification: one pipeline stage per distributed rank.
        self.pp_config = PPConfig(pp_rank=dist.get_rank(), pp_world_size=dist.get_world_size())
        self.layer_start, self.layer_end = self._stage_layer_range(
            self.config.num_hidden_layers, self.pp_config.pp_world_size, self.pp_config.pp_rank
        )
        self.num_layers = self.layer_end - self.layer_start

        # Real decoder layers only inside this stage's slice; placeholders elsewhere keep indices global.
        layers = []
        for i in range(self.config.num_hidden_layers):
            if self.layer_start <= i < self.layer_end:
                layers.append(LlamaDecoderLayer(config))
            else:
                layers.append(DummyLayer())
        self.layers = nn.ModuleList(layers)

        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        if self.pp_config.is_head:
            self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        if self.pp_config.is_tail:
            self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    @staticmethod
    def _stage_layer_range(num_hidden_layers: int, nr_slices: int, rank: int) -> Tuple[int, int]:
        """Return the half-open ``(start, end)`` range of global layer indices owned by pipeline ``rank``.

        Layers are cut into contiguous slices of ``ceil(num_hidden_layers / nr_slices)`` layers. Both ends
        are clamped to ``num_hidden_layers``, so surplus trailing ranks (more ranks than needed) get an empty
        range instead of a negative layer count.
        """
        slice_size = (num_hidden_layers + nr_slices - 1) // nr_slices
        start = min(slice_size * rank, num_hidden_layers)
        end = min(start + slice_size, num_hidden_layers)
        return start, end

    def get_input_embeddings(self):
        """Return the token embedding, or ``None`` on non-head stages (which never create one)."""
        return getattr(self, "embed_tokens", None)

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
    def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
        """Combine the causal mask with the (expanded) padding mask.

        The causal mask is only built when the query length ``input_shape[-1]`` is greater than 1. The padding
        mask is expanded to ``[bsz, 1, tgt_seq_len, src_seq_len]`` and added to the causal mask. Returns
        ``None`` when neither mask applies.
        """
        # create causal mask
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        combined_attention_mask = None
        if input_shape[-1] > 1:
            combined_attention_mask = _make_causal_mask(
                input_shape,
                inputs_embeds.dtype,
                device=inputs_embeds.device,
                past_key_values_length=past_key_values_length,
            )

        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
                inputs_embeds.device
            )
            combined_attention_mask = (
                expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
            )

        return combined_attention_mask

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Run this pipeline stage's slice of decoder layers.

        The head stage must be called with ``input_ids``, which are embedded here with ``embed_tokens``. Every
        other stage must be called with ``inputs_embeds``, presumably the hidden states produced by the
        previous stage (TODO confirm against the pipeline driver). The final ``norm`` is applied only on the
        tail stage.

        ``past_key_values`` and the returned cache hold one entry per layer *of this stage*, indexed from 0 in
        stage-local order, not by global layer index.

        Returns:
            ``BaseModelOutputWithPast`` when ``return_dict`` is true. Otherwise a tuple of the non-``None`` values
            among ``(hidden_states, next_cache, all_hidden_states, all_self_attns)``.

        Raises:
            ValueError: if both or neither of ``input_ids`` and ``inputs_embeds`` are given.
            AssertionError: if ``input_ids`` is given on a non-head stage, or ``inputs_embeds`` on the head stage.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds for pp
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
        elif input_ids is not None:
            assert self.pp_config.is_head, "input_ids is only supported on the head stage"
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            # Any non-head stage (middle or tail) receives embeddings/hidden states instead of token ids.
            assert not self.pp_config.is_head, "inputs_embeds is only supported on non-head stages"
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

        seq_length_with_past = seq_length
        past_key_values_length = 0

        if past_key_values is not None:
            past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length

        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(
                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
            )
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        # embed positions
        if attention_mask is None:
            attention_mask = torch.ones(
                (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
            )
        attention_mask = self._prepare_decoder_attention_mask(
            attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
        )

        hidden_states = inputs_embeds

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        for idx in range(self.num_layers):
            # idx is stage-local; the real layer lives at its global index in self.layers.
            decoder_layer = self.layers[self.layer_start + idx]
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            past_key_value = past_key_values[idx] if past_key_values is not None else None

            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )

            hidden_states = layer_outputs[0]

            if use_cache:
                # The cache entry follows the attention weights in layer_outputs when those are requested.
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        if self.pp_config.is_tail:
            hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
 | 
						|
 | 
						|
 | 
						|
class LlamaForCausalLM(LlamaPreTrainedModel):
    """Pipeline-parallel Llama causal LM: this rank's stage of ``LlamaModel`` plus, on the tail stage, an LM head.

    The pipeline rank comes from ``dist.get_rank()`` / ``dist.get_world_size()``, so a default process group
    must already be initialized. Non-tail stages return the decoder outputs of their stage unchanged. Only the
    tail stage owns ``lm_head``, computes logits and, when ``labels`` are given, the shifted cross-entropy loss.
    """

    def __init__(self, config: LlamaConfig):
        super().__init__(config=config)
        self.config = config
        self.pp_config = PPConfig(pp_rank=dist.get_rank(), pp_world_size=dist.get_world_size())
        self.model = LlamaModel(config)
        self.pretraining_tp = config.pretraining_tp
        self.vocab_size = config.vocab_size
        if self.pp_config.is_tail:
            self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

    def get_input_embeddings(self):
        """Return the token embedding, or ``None`` on non-head stages (which never create one)."""
        return getattr(self.model, "embed_tokens", None)

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        """Return the LM head, or ``None`` on non-tail stages (which never create one)."""
        return getattr(self, "lm_head", None)

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run this stage of the model; on the tail stage also compute logits and the optional LM loss.

        On non-tail stages the result of ``self.model(...)`` is returned as-is and ``labels`` is ignored. On
        the tail stage the loss is ``CrossEntropyLoss`` between ``logits[..., :-1, :]`` and ``labels[..., 1:]``,
        so position *t* predicts token *t + 1*.

        Returns:
            On the tail stage, ``CausalLMOutputWithPast`` when ``return_dict`` is true, otherwise the tuple
            ``(loss?, logits, *model_outputs[1:])``, where ``loss`` is present only if ``labels`` was given.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if not self.pp_config.is_tail:
            # Intermediate stage: hand the decoder outputs on unchanged.
            return outputs

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        """Build the keyword arguments for one ``forward`` call during generation.

        Once a cache exists, only the last token (and its position) is fed. ``position_ids`` are derived from
        ``attention_mask`` when not supplied, so left-padded batches get positions that start at 0 on the
        first real token. ``inputs_embeds`` is used only on the first step (``past_key_values is None``).
        """
        if past_key_values:
            input_ids = input_ids[:, -1:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -1].unsqueeze(-1)

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
            }
        )
        return model_inputs

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        """Reorder every cached tensor along dim 0 (the batch/beam axis) by ``beam_idx``."""
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (
                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
            )
        return reordered_past
 |