LLM: Add Pipeline-Parallel-FastAPI example (#10917)
Add multi-stage Pipeline-Parallel-FastAPI example

---------

Co-authored-by: hzjane <a1015616934@qq.com>
This commit is contained in:
parent d550af957a
commit 5c8ccf0ba9

5 changed files with 1029 additions and 0 deletions
33  python/llm/example/GPU/Pipeline-Parallel-FastAPI/README.md  Normal file

@@ -0,0 +1,33 @@
# Serve IPEX-LLM on Multiple Intel GPUs in multi-stage pipeline parallel fashion

This example demonstrates how to run IPEX-LLM serving on multiple [Intel GPUs](../README.md) with Pipeline Parallel.

## Requirements

To run this example with IPEX-LLM on Intel GPUs, we have some recommended requirements for your machine; please refer to [here](../README.md#recommended-requirements) for more information. For this particular example, you will need at least two GPUs on your machine.

## Example

### 1. Install

```bash
conda create -n llm python=3.11
conda activate llm
# the below command will install intel_extension_for_pytorch==2.1.10+xpu as default
pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
pip install oneccl_bind_pt==2.1.100 --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
# configure OneAPI environment variables
source /opt/intel/oneapi/setvars.sh
# pip install git+https://github.com/microsoft/DeepSpeed.git@ed8aed5
# pip install git+https://github.com/intel/intel-extension-for-deepspeed.git@0eb734b
pip install mpi4py fastapi uvicorn
conda install -c conda-forge -y gperftools=2.10 # to enable tcmalloc
```

### 2. Run pipeline parallel serving on multiple GPUs

```bash
# You need to set MODEL_PATH in run.sh first
bash run.sh
```
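Once `run.sh` has brought the server up, the `/generate/` endpoint accepts a JSON body matching `PromptRequest` in `pipeline_serving.py`. A minimal client sketch, assuming the default port 8000 and a server on localhost:

```python
# Minimal client sketch; assumes the server started by run.sh listens on localhost:8000.
import requests

resp = requests.post(
    "http://localhost:8000/generate/",
    json={"prompt": "What is AI?", "n_predict": 32},  # fields defined by PromptRequest
    timeout=300,
)
resp.raise_for_status()
print(resp.json()["generated_text"])
```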
							
								
								
									
327  python/llm/example/GPU/Pipeline-Parallel-FastAPI/llama_models.py  Normal file

@@ -0,0 +1,327 @@
from transformers.modeling_utils import PreTrainedModel
from transformers.models.llama.modeling_llama import LlamaConfig, LlamaDecoderLayer, LlamaRMSNorm, LlamaPreTrainedModel
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast

from torch import nn
import torch
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from typing import List, Optional, Tuple, Union, Iterator
from transformers.utils import logging
logger = logging.get_logger(__name__)
import numpy as np
import time
from transformers import AutoTokenizer, AutoConfig
import torch.distributed as dist
from pipeline_models import (
    _make_causal_mask, _expand_mask, DummyLayer, PPConfig,
    PipelineBaseModel,
)


class LlamaModel(LlamaPreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]

    Args:
        config: LlamaConfig
    """

    def __init__(self, config: LlamaConfig):
        super().__init__(config)
        self.config = config

        # pp modification
        self.pp_config = PPConfig(pp_rank=dist.get_rank(), pp_world_size=dist.get_world_size())
        nr_slices = self.pp_config.pp_world_size
        # self.config.num_hidden_layers = 8
        slice_size = (self.config.num_hidden_layers + nr_slices -
                      1) // nr_slices
        self.layer_start = slice_size * self.pp_config.pp_rank
        self.layer_end = self.layer_start + min(slice_size,
                                                self.config.num_hidden_layers - self.layer_start)
        self.num_layers = self.layer_end - self.layer_start
        layers = []
        for i in range(self.config.num_hidden_layers):
            if i < self.layer_start or i >= self.layer_end:
                layers.append(DummyLayer())
            else:
                layers.append(LlamaDecoderLayer(config))
        self.layers = nn.ModuleList(layers)

        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
    def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
        # create causal mask
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        combined_attention_mask = None
        if input_shape[-1] > 1:
            combined_attention_mask = _make_causal_mask(
                input_shape,
                inputs_embeds.dtype,
                device=inputs_embeds.device,
                past_key_values_length=past_key_values_length,
            )

        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
                inputs_embeds.device
            )
            combined_attention_mask = (
                expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
            )

        return combined_attention_mask

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds for pp
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
        elif input_ids is not None:
            assert self.pp_config.is_head, "input_ids is only supported on the head stage"
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            assert not self.pp_config.is_head, "inputs_embeds is only supported on non-head stages"
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

        seq_length_with_past = seq_length
        past_key_values_length = 0

        if past_key_values is not None:
            past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length

        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(
                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
            )
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        # embed positions
        if attention_mask is None:
            attention_mask = torch.ones(
                (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
            )
        attention_mask = self._prepare_decoder_attention_mask(
            attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
        )

        hidden_states = inputs_embeds

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        for idx in range(self.num_layers):
            decoder_layer = self.layers[self.layer_start + idx]
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            past_key_value = past_key_values[idx] if past_key_values is not None else None

            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        if self.pp_config.is_tail:
            hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )


class LlamaForCausalLM(LlamaPreTrainedModel):

    def __init__(self, config: LlamaConfig):
        super().__init__(config=config)
        self.config = config
        self.pp_config = PPConfig(pp_rank=dist.get_rank(), pp_world_size=dist.get_world_size())
        self.model = LlamaModel(config)
        self.pretraining_tp = config.pretraining_tp
        self.vocab_size = config.vocab_size
        if self.pp_config.is_tail:
            self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if self.pp_config.is_tail:
            hidden_states = outputs[0]
            logits = self.lm_head(hidden_states)
            logits = logits.float()

            loss = None
            if labels is not None:
                # Shift so that tokens < n predict n
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = labels[..., 1:].contiguous()
                # Flatten the tokens
                loss_fct = CrossEntropyLoss()
                shift_logits = shift_logits.view(-1, self.config.vocab_size)
                shift_labels = shift_labels.view(-1)
                # Enable model parallelism
                shift_labels = shift_labels.to(shift_logits.device)
                loss = loss_fct(shift_logits, shift_labels)

            if not return_dict:
                output = (logits,) + outputs[1:]
                return (loss,) + output if loss is not None else output

            return CausalLMOutputWithPast(
                loss=loss,
                logits=logits,
                past_key_values=outputs.past_key_values,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )
        return outputs

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        if past_key_values:
            input_ids = input_ids[:, -1:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -1].unsqueeze(-1)

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
            }
        )
        return model_inputs

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (
                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
            )
        return reordered_past
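As an illustration of the layer split performed in `LlamaModel.__init__`, the sketch below reproduces the same ceiling-division arithmetic standalone; the 32-layer, 2-rank values are examples only, not something fixed by this file:

```python
# Standalone sketch of the layer partition used in LlamaModel.__init__ above.
# num_hidden_layers=32 and pp_world_size=2 are illustrative values.
def partition(num_hidden_layers: int, pp_rank: int, pp_world_size: int):
    slice_size = (num_hidden_layers + pp_world_size - 1) // pp_world_size  # ceil division
    layer_start = slice_size * pp_rank
    layer_end = layer_start + min(slice_size, num_hidden_layers - layer_start)
    return layer_start, layer_end

for rank in range(2):
    print(rank, partition(32, rank, 2))
# rank 0 owns layers [0, 16), rank 1 owns layers [16, 32);
# every other position in self.layers is a DummyLayer placeholder.
```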
510  python/llm/example/GPU/Pipeline-Parallel-FastAPI/pipeline_models.py  Normal file

@@ -0,0 +1,510 @@
from torch import nn
import torch
import torch.distributed as dist
import intel_extension_for_pytorch as ipex

from typing import List, Optional, Tuple, Union, Iterator
import time
from transformers import AutoTokenizer, AutoConfig
from transformers.utils import logging
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
import numpy as np
import asyncio, uuid
import threading

logger = logging.get_logger(__name__)


class PPConfig:
    """Configuration for ModelSlices."""

    def __init__(self, pp_rank: int, pp_world_size: int) -> None:
        self.pp_rank = pp_rank
        self.pp_world_size = pp_world_size
        self.is_head = self.pp_rank == 0
        self.is_tail = self.pp_rank == self.pp_world_size - 1


# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(
    input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
):
    """
    Make causal mask used for bi-directional self-attention.
    """
    bsz, tgt_len = input_ids_shape
    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
    mask_cond = torch.arange(mask.size(-1), device=device)
    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
    mask = mask.to(dtype)

    if past_key_values_length > 0:
        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)


# Copied from transformers.models.bart.modeling_bart._expand_mask
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
    """
    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
    """
    bsz, src_len = mask.size()
    tgt_len = tgt_len if tgt_len is not None else src_len

    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)

    inverted_mask = 1.0 - expanded_mask

    return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)


def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
    # create causal mask
    # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
    combined_attention_mask = None
    if input_shape[-1] > 1:
        combined_attention_mask = _make_causal_mask(
            input_shape,
            inputs_embeds.dtype,
            device=inputs_embeds.device,
            past_key_values_length=past_key_values_length,
        )

    if attention_mask is not None:
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
            inputs_embeds.device
        )
        combined_attention_mask = (
            expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
        )

    return combined_attention_mask


class DummyLayer(nn.Module):
    pass


class PipelineBaseModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.pp_config = PPConfig(pp_rank=dist.get_rank(), pp_world_size=dist.get_world_size())
        nr_slices = self.pp_config.pp_world_size
        # self.config.num_hidden_layers = 8
        slice_size = (self.config.num_hidden_layers + nr_slices -
                      1) // nr_slices
        self.layer_start = slice_size * self.pp_config.pp_rank
        self.layer_end = self.layer_start + min(slice_size,
                                                self.config.num_hidden_layers - self.layer_start)
        self.num_layers = self.layer_end - self.layer_start

    # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
    def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
        # create causal mask
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        combined_attention_mask = None
        if input_shape[-1] > 1:
            combined_attention_mask = _make_causal_mask(
                input_shape,
                inputs_embeds.dtype,
                device=inputs_embeds.device,
                past_key_values_length=past_key_values_length,
            )

        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
                inputs_embeds.device
            )
            combined_attention_mask = (
                expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
            )

        return combined_attention_mask

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
        elif input_ids is not None:
            assert self.pp_config.is_head, "input_ids is only supported on the head stage"
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            assert not self.pp_config.is_head, "inputs_embeds is only supported on non-head stages"
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

        seq_length_with_past = seq_length
        past_key_values_length = 0

        if past_key_values is not None:
            past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length

        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(
                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
            )
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        # embed positions
        if attention_mask is None:
            attention_mask = torch.ones(
                (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
            )
        attention_mask = self._prepare_decoder_attention_mask(
            attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
        )

        hidden_states = inputs_embeds

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        for idx in range(self.num_layers):
            decoder_layer = self.layers[self.layer_start + idx]
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            past_key_value = past_key_values[idx] if past_key_values is not None else None

            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        if self.pp_config.is_tail:
            hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )


def load_model(checkpoint):
    from llama_models import LlamaForCausalLM
    if 'llama' in checkpoint.lower():
        model = LlamaForCausalLM.from_pretrained(checkpoint, low_cpu_mem_usage=True, torch_dtype=torch.float16)
    return model


from pydantic import BaseModel
class BatchTask(BaseModel):
    batch_id: str
    request_ids: List[str]
    max_tokens: int
    batch_size: int
    input_len: int
    # plain_texts: List[str]
    prompt_lengths: List[int]
    stopped: bool
    # input_ids: torch.Tensor
    # attention_mask: torch.Tensor


def make_attention_mask(prompt_lengths):
    max_length = max(prompt_lengths)
    attention_mask = torch.zeros((len(prompt_lengths), max_length), dtype=torch.int64)
    for i, length in enumerate(prompt_lengths):
        attention_mask[i, max_length - length:] = 1
    return attention_mask


class ModelRunner:

    def __init__(self, checkpoint, rank, world_size, low_bit, max_num_seqs):

        import sys
        self.pp_config = PPConfig(rank, world_size)

        start = time.perf_counter()
        model = load_model(checkpoint)
        end = time.perf_counter()
        logger.info(f"Time to load weights: {end - start:.2f}s")
        from ipex_llm import optimize_model

        model = optimize_model(model, low_bit=low_bit)

        model = model.to(torch.float16).to(f'xpu:{rank}')
        self.model = model
        self.rank = rank
        self.world_size = world_size
        self.pre_rank = (self.rank - 1) % self.world_size
        self.next_rank = (self.rank + 1) % self.world_size
        self.hidden_size = self.model.config.hidden_size

        self.max_num_seqs = max_num_seqs
        self.on_going_batches = [None] * self.world_size
        self.input_ids_dict = {}
        # self.attention_mask_dict = {}
        self.past_key_values_dict = {}
        self.tokens = {}
        self.token_times = {}
        self.dtype = torch.float16

        self.waiting_requests = asyncio.Queue()
        self.send_buff = None
        self.dict_lock = threading.Lock()

    # def generate(self, input_ids=None, max_tokens=5, attention_mask=None):
    #     times = []
    #     with torch.no_grad():
    #         _input_ids = None
    #         _past_key_values = None
    #         bs = input_ids.shape[0]
    #         output_ids = input_ids.clone()
    #         for i in range(max_tokens):
    #             start = time.perf_counter()
    #             if _input_ids is None:
    #                 _input_ids = input_ids
    #             if self.rank == 0:
    #                 outputs = self.model(input_ids=_input_ids, attention_mask=attention_mask, past_key_values=_past_key_values, use_cache=True)
    #             else:
    #                 inputs_embeds = torch.empty(_input_ids.shape + (self.hidden_size,) , device=f'xpu:{self.rank}', dtype=torch.float32)
    #                 dist.recv(inputs_embeds, src=self.pre_rank)
    #                 outputs = self.model(inputs_embeds=inputs_embeds, attention_mask=attention_mask, past_key_values=_past_key_values, use_cache=True)

    #             if self.rank == self.world_size - 1:
    #                 logits = outputs.logits
    #                 next_ids = torch.argmax(logits[:, -1:, :], dim=-1)
    #                 assert next_ids.shape == (bs, 1)
    #                 dist.broadcast(next_ids, src=self.rank)
    #             else:
    #                 dist.send(outputs.last_hidden_state, dst=self.next_rank)
    #                 next_ids = torch.empty((bs, 1), device=f'xpu:{self.rank}', dtype=torch.int64)
    #                 dist.broadcast(next_ids, src=self.world_size - 1)

    #             _input_ids = next_ids
    #             output_ids = torch.cat([output_ids, next_ids], dim=-1)
    #             _past_key_values = outputs.past_key_values
    #             end = time.perf_counter()
    #             times.append(end - start)

    #     if self.rank == 0:
    #         logger.info(f"first token latency: {times[0]}, rest token avg latency: {np.mean(times[1:])}")
    #     return output_ids

    def model_step(self, input, cur_batch):
        if cur_batch is None or cur_batch.stopped or input is None:
            return None

        cur_id = cur_batch.batch_id
        _past_key_values = self.past_key_values_dict.get(cur_id, None)
        # attention_mask = self.attention_mask_dict[cur_id]
        attention_mask = make_attention_mask(cur_batch.prompt_lengths)

        if self.rank == 0:
            input_ids = input
            inputs_embeds = None
        else:
            input_ids = None
            inputs_embeds = input
        output = self.model(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            past_key_values=_past_key_values,
            use_cache=True
        )
        self.past_key_values_dict[cur_id] = output.past_key_values
        if not self.pp_config.is_tail:
            return output.last_hidden_state
        else:
            # logger.info(f"logits: {output.logits.shape}")
            return output.logits

    def is_initialized(self):
        return True

    async def add_request(self, tokenizer):
        request_ids, prompt_requests = [], []
        for _ in range(self.max_num_seqs):
            if self.waiting_requests.empty():
                break

            tmp_result = await self.waiting_requests.get()
            # logger.info(tmp_result)
            request_id, prompt_request = tmp_result
            request_ids.append(request_id)
            prompt_requests.append(prompt_request)

        plain_texts = [req.prompt for req in prompt_requests]
        inputs = tokenizer(plain_texts, return_tensors="pt", padding=True)
        input_ids = inputs.input_ids.to(f'xpu:{self.rank}')
        attention_mask = inputs.attention_mask.to(f'xpu:{self.rank}')
        new_batch = BatchTask(
            batch_id="batch_" + str(uuid.uuid4()),
            request_ids=request_ids,
            max_tokens=max([req.n_predict for req in prompt_requests]),
            batch_size=input_ids.size(0),
            input_len=input_ids.size(1),
            prompt_lengths=[sum(attention_mask[i, :]) for i in range(input_ids.size(0))],
            stopped=False,
            # plain_texts=plain_texts,
            # input_ids=input_ids,
            # attention_mask=attention_mask,
        )

        self.input_ids_dict[new_batch.batch_id] = input_ids
        self.token_times[new_batch.batch_id] = [time.perf_counter()]
        # self.attention_mask_dict[new_batch.batch_id] = attention_mask

        return new_batch

    def clear_batch(self, cur_id):
        self.input_ids_dict.pop(cur_id, None)
        self.tokens.pop(cur_id, None)
        self.token_times.pop(cur_id, None)
        # self.attention_mask_dict.pop(cur_id, None)
        self.past_key_values_dict.pop(cur_id, None)
        # torch.xpu.empty_cache()

    async def process_step(self, tokenizer, result_dict):
        cur_batch = None

        if self.rank == 0:
            if self.on_going_batches[0] is not None:
                cur_batch = self.on_going_batches[0]
                cur_input = None

            if cur_batch is None:
                if not self.waiting_requests.empty():
                    # await asyncio.sleep(0.01)
                    cur_batch = await self.add_request(tokenizer)
                    cur_input = self.input_ids_dict[cur_batch.batch_id]
                else:
                    cur_batch = None
                    cur_input = None

            if (cur_batch is not None) and (not cur_batch.stopped) and (cur_input is None):
                cur_id = cur_batch.batch_id
                next_ids = torch.empty((cur_batch.batch_size, 1,), device=f'xpu:{self.rank}', dtype=torch.int64)
                # logger.info(f"rank: {self.rank}, recv: {next_ids.shape}")
                dist.recv(next_ids, src=self.pre_rank)

                if self.tokens.get(cur_id, None) is None:
                    self.tokens[cur_id] = []

                if len(next_ids.shape) == 1:
                    next_ids = next_ids.unsqueeze(0)
                self.tokens[cur_id].append(next_ids)
                self.token_times[cur_id].append(time.perf_counter())
                # self.input_ids_dict[cur_id] += next_ids
                cur_input = next_ids
                # cur_batch.input_len += 1
                cur_batch.input_len = 1
                cur_batch.prompt_lengths = [x + 1 for x in cur_batch.prompt_lengths]
                if len(self.tokens[cur_id]) >= cur_batch.max_tokens:
                    # Finish a batch
                    # logger.info(self.tokens[cur_id])
                    outputs = torch.cat(self.tokens[cur_id], dim=1)
                    outputs = outputs.cpu()
                    output_strs = tokenizer.batch_decode(outputs, skip_special_tokens=False)
                    for request_id, output_str in zip(cur_batch.request_ids, output_strs):
                        with self.dict_lock:
                            result_dict[request_id] = output_str

                    cur_times = self.token_times[cur_id]
                    first_token = cur_times[1] - cur_times[0]
                    next_token = (cur_times[-1] - cur_times[1]) / (len(self.tokens[cur_id]) - 1)
                    logger.info(f"First token latency: {first_token}, next token latency: {next_token}")
                    self.clear_batch(cur_id)
                    cur_batch.stopped = True
            else:
                if (cur_batch is not None) and cur_batch.stopped:
                    cur_batch = None

            if self.send_buff is not None:
                # logger.info(f"rank: {self.rank}, send: {self.send_buff.shape}")
                dist.send(self.send_buff, dst=self.next_rank)
            dist.broadcast_object_list([cur_batch], src=0)

        else:
            batch_list = [None]
            dist.broadcast_object_list(batch_list, src=0)

            cur_batch = batch_list[0]
            cur_input = None

            if self.send_buff is not None:
                # logger.info(f"rank: {self.rank}, send: {self.send_buff.shape}")
                dist.send(self.send_buff, dst=self.next_rank)

            if cur_batch is not None:
                if cur_batch.stopped:
                    self.clear_batch(cur_batch.batch_id)
                else:
                    cur_len = cur_batch.input_len
                    cur_input = torch.empty((cur_batch.batch_size, cur_len, self.hidden_size,), device=f'xpu:{self.rank}', dtype=self.dtype)
                    # logger.info(f"rank: {self.rank}, recv: {cur_input.shape}")
                    dist.recv(cur_input, src=self.pre_rank)

                # if self.attention_mask_dict.get(cur_batch.batch_id, None) is None:
                #     self.attention_mask_dict[cur_batch.batch_id] = make_attention_mask(cur_batch.prompt_lengths)

        # if self.rank == 0:
        #     logger.info(f"rank: {self.rank}, {batch_list}")

        output = self.model_step(cur_input, cur_batch)
        if output is not None and self.rank == self.world_size - 1:
            output = torch.argmax(output[:, -1:, :], dim=-1)

        if output is not None:
            # dist.send(output, dst=self.next_rank)
            self.send_buff = output
        else:
            self.send_buff = None
        if self.rank == 0:
            self.on_going_batches[:-1] = self.on_going_batches[1:]
            self.on_going_batches[self.world_size - 1] = cur_batch
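For reference, `make_attention_mask` above rebuilds a left-padded mask from the tracked prompt lengths on every step; a small standalone check (example lengths only) shows the layout it produces:

```python
# Illustration of make_attention_mask from pipeline_models.py, with example lengths.
import torch

def make_attention_mask(prompt_lengths):
    max_length = max(prompt_lengths)
    attention_mask = torch.zeros((len(prompt_lengths), max_length), dtype=torch.int64)
    for i, length in enumerate(prompt_lengths):
        attention_mask[i, max_length - length:] = 1
    return attention_mask

print(make_attention_mask([3, 5]))
# tensor([[0, 0, 1, 1, 1],   <- shorter prompt is left-padded
#         [1, 1, 1, 1, 1]])
```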
148  python/llm/example/GPU/Pipeline-Parallel-FastAPI/pipeline_serving.py  Normal file

@@ -0,0 +1,148 @@
from pipeline_models import ModelRunner
import torch.nn.parallel
import torch.distributed as dist
import os
import intel_extension_for_pytorch as ipex

import oneccl_bindings_for_pytorch

from transformers.utils import logging
logger = logging.get_logger(__name__)

os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = '29501'

backend = 'ccl'
dist.init_process_group(backend)
my_rank = dist.get_rank()
my_size = dist.get_world_size()
device = f"xpu:{my_rank}"
logger.info(f"rank: {my_rank}, size: {my_size}")

import time
from transformers import AutoTokenizer, AutoConfig, LlamaTokenizer
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import uvicorn
import asyncio, uuid
from typing import Dict, List, Optional
import argparse


def get_int_from_env(env_keys, default):
    """Returns the first positive env value found in the `env_keys` list or the default."""
    for e in env_keys:
        val = int(os.environ.get(e, -1))
        if val >= 0:
            return val
    return int(default)


class PromptRequest(BaseModel):
    prompt: str
    n_predict: int = 32


empty_req = PromptRequest(prompt="", n_predict=0)

app = FastAPI()
global tokenizer
global local_model

request_queue: asyncio.Queue = asyncio.Queue()
result_dict: Dict[str, str] = {}
local_rank = my_rank
max_num_seqs = get_int_from_env(["MAX_NUM_SEQS"], "16")


@app.post("/generate/")
async def generate(prompt_request: PromptRequest):
    request_id = str(uuid.uuid4())
    await local_model.waiting_requests.put((request_id, prompt_request))
    while True:
        if request_id in result_dict:
            with local_model.dict_lock:
                output_str = result_dict[request_id]
            if len(output_str) == 0:
                logger.info(f"Empty result for request {request_id}")
                # await asyncio.sleep(0.1)
                # continue
            result_dict.pop(request_id)
            return {"generated_text": output_str}
        await asyncio.sleep(0)


def generate_text(prompt: List[str], n_predict=32):
    while prompt[-1] == "":
        prompt = prompt[:-1]
    if isinstance(n_predict, list):
        n_predict = max(n_predict)

    inputs = tokenizer(prompt, return_tensors="pt", padding=True)
    input_ids = inputs.input_ids.to(f'xpu:{local_rank}')
    print(inputs)
    attention_mask = inputs.attention_mask.to(f'xpu:{local_rank}')
    output = local_model.generate(input_ids,
                                  max_tokens=n_predict,
                                  # attention_mask=attention_mask,
                                  # max_new_tokens=n_predict,
                                  # min_new_tokens=n_predict,
                                  # do_sample=False,
                                  # use_cache=True
                                  )
    torch.xpu.synchronize()

    return output


async def process_requests(local_model, result_dict):
    while True:
        await asyncio.sleep(0)
        await local_model.process_step(tokenizer, result_dict)


@app.on_event("startup")
async def startup_event():
    asyncio.create_task(process_requests(local_model, result_dict))


async def main():
    parser = argparse.ArgumentParser(description='Predict Tokens using fastapi by leveraging Pipeline Parallel')
    parser.add_argument('--repo-id-or-model-path', type=str, default="meta-llama/Llama-2-7b-chat-hf",
                        help='The huggingface repo id for the Llama2 (e.g. `meta-llama/Llama-2-7b-chat-hf`, `meta-llama/Llama-2-13b-chat-hf` and `meta-llama/Llama-2-70b-chat-hf`) to be downloaded'
                             ', or the path to the huggingface checkpoint folder')
    parser.add_argument('--low-bit', type=str, default='sym_int4',
                        help='The quantization type the model will convert to.')
    parser.add_argument('--port', type=int, default=8000,
                        help='The port number on which the server will run.')
    parser.add_argument('--max-num-seqs', type=int, default=8,
                        help='Max num sequences in a batch.')

    args = parser.parse_args()
    model_path = args.repo_id_or_model_path
    low_bit = args.low_bit
    max_num_seqs = args.max_num_seqs

    # serialize model initialization so that we do not run out of CPU memory
    for i in range(my_size):
        if my_rank == i:
            logger.info("start model initialization")
            global local_model
            local_model = ModelRunner(model_path, my_rank, my_size, low_bit, max_num_seqs)
            logger.info("model initialized")
        dist.barrier()
    # Load tokenizer
    global tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, padding_side='left')
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    if local_rank == 0:
        config = uvicorn.Config(app=app, host="0.0.0.0", port=args.port)
        server = uvicorn.Server(config)
        await server.serve()
    else:
        while True:
            await asyncio.sleep(0)
            await local_model.process_step(tokenizer, result_dict)


if __name__ == "__main__":
    asyncio.run(main())
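Since rank 0 drains up to `max_num_seqs` waiting requests into a single batch inside `add_request`, firing several requests concurrently is the easiest way to exercise the dynamic batching. A hedged client sketch (worker count, prompts, and host are illustrative):

```python
# Concurrent client sketch; prompts, worker count, and host/port are illustrative.
from concurrent.futures import ThreadPoolExecutor
import requests

def ask(prompt):
    r = requests.post("http://localhost:8000/generate/",
                      json={"prompt": prompt, "n_predict": 32}, timeout=300)
    r.raise_for_status()
    return r.json()["generated_text"]

prompts = [f"Tell me fact #{i} about Intel GPUs." for i in range(8)]
with ThreadPoolExecutor(max_workers=8) as pool:
    for text in pool.map(ask, prompts):
        print(text)
```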
							
								
								
									
11  python/llm/example/GPU/Pipeline-Parallel-FastAPI/run.sh  Normal file

@@ -0,0 +1,11 @@
source /opt/intel/oneapi/setvars.sh
export no_proxy=localhost
export FI_PROVIDER=tcp
export OMP_NUM_THREADS=8

export USE_XETLA=OFF
export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=2
export TORCH_LLM_ALLREDUCE=0

export MODEL_PATH=YOUR_MODEL_PATH
CCL_ZE_IPC_EXCHANGE=sockets torchrun --standalone --nnodes=1 --nproc-per-node 2 pipeline_serving.py --repo-id-or-model-path $MODEL_PATH --low-bit fp8