# ipex-llm/python/llm/example/GPU/Pipeline-Parallel-FastAPI/pipeline_models.py
#
# Last touched by commit ce3f08b25a (Qiyuan Gong, 2024-06-04):
#   "Fix IPEX auto importer (#11192)"
#   * Fix ipex auto importer with Python builtins.
#   * Raise errors if the user imports ipex manually before importing
#     ipex_llm.  Do nothing if they import ipex after importing ipex_llm.
#   * Remove `import ipex` in examples.
from torch import nn
import torch
import torch.distributed as dist
from typing import List, Optional, Tuple, Union, Iterator
import time
from transformers import AutoTokenizer, AutoConfig
from transformers.utils import logging
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
import numpy as np
import asyncio, uuid
import threading
logger = logging.get_logger(__name__)
class PPConfig:
"""Configuration for ModelSlices."""
def __init__(self, pp_rank: int, pp_world_size: int) -> None:
self.pp_rank = pp_rank
self.pp_world_size = pp_world_size
self.is_head = self.pp_rank == 0
self.is_tail = self.pp_rank == self.pp_world_size - 1
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(
input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
):
"""
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
mask_cond = torch.arange(mask.size(-1), device=device)
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
if past_key_values_length > 0:
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
# Copied from transformers.models.bart.modeling_bart._expand_mask
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask = None
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape,
inputs_embeds.dtype,
device=inputs_embeds.device,
past_key_values_length=past_key_values_length,
)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
inputs_embeds.device
)
combined_attention_mask = (
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
)
return combined_attention_mask
class DummyLayer(nn.Module):
pass
class PipelineBaseModel(nn.Module):
def __init__(self, config):
self.pp_config = PPConfig(pp_rank=dist.get_rank(), pp_world_size=dist.get_world_size())
nr_slices = self.pp_config.pp_world_size
# self.config.num_hidden_layers = 8
slice_size = (self.config.num_hidden_layers + nr_slices -
1) // nr_slices
self.layer_start = slice_size * self.pp_config.pp_rank
self.layer_end = self.layer_start + min(slice_size,
self.config.num_hidden_layers - self.layer_start)
self.num_layers = self.layer_end - self.layer_start
# Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask = None
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape,
inputs_embeds.dtype,
device=inputs_embeds.device,
past_key_values_length=past_key_values_length,
)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
inputs_embeds.device
)
combined_attention_mask = (
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
)
return combined_attention_mask
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
elif input_ids is not None:
assert self.pp_config.is_head, "input_ids is only supported on the head stage"
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
assert not self.pp_config.is_head, "inputs_embeds is only supported on the tail stage"
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
seq_length_with_past = seq_length
past_key_values_length = 0
if past_key_values is not None:
past_key_values_length = past_key_values[0][0].shape[2]
seq_length_with_past = seq_length_with_past + past_key_values_length
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
# embed positions
if attention_mask is None:
attention_mask = torch.ones(
(batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
)
attention_mask = self._prepare_decoder_attention_mask(
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
)
hidden_states = inputs_embeds
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = () if use_cache else None
for idx in range(self.num_layers):
decoder_layer = self.layers[self.layer_start + idx]
if output_hidden_states:
all_hidden_states += (hidden_states,)
past_key_value = past_key_values[idx] if past_key_values is not None else None
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
if output_attentions:
all_self_attns += (layer_outputs[1],)
if self.pp_config.is_tail:
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
def load_model(checkpoint):
from llama_models import LlamaForCausalLM
if 'llama' in checkpoint.lower():
model = LlamaForCausalLM.from_pretrained(checkpoint, low_cpu_mem_usage=True, torch_dtype=torch.float16)
return model
from pydantic import BaseModel
class BatchTask(BaseModel):
batch_id: str
request_ids: List[str]
max_tokens: int
batch_size: int
input_len: int
# plain_texts: List[str]
prompt_lengths: List[int]
stopped: bool
# input_ids: torch.Tensor
# attention_mask: torch.Tensor
def make_attention_mask(prompt_lengths):
max_length = max(prompt_lengths)
attention_mask = torch.zeros((len(prompt_lengths), max_length), dtype=torch.int64)
for i, length in enumerate(prompt_lengths):
attention_mask[i, max_length - length:] = 1
return attention_mask
class ModelRunner:
def __init__(self, checkpoint, rank, world_size, low_bit, max_num_seqs):
import sys
self.pp_config = PPConfig(rank, world_size)
start = time.perf_counter()
model = load_model(checkpoint)
end = time.perf_counter()
logger.info(f"Time to load weights: {end - start:.2f}s")
from ipex_llm import optimize_model
model = optimize_model(model, low_bit=low_bit)
model = model.to(torch.float16).to(f'xpu:{rank}')
self.model = model
self.rank = rank
self.world_size = world_size
self.pre_rank = (self.rank - 1) % self.world_size
self.next_rank = (self.rank + 1) % self.world_size
self.hidden_size = self.model.config.hidden_size
self.max_num_seqs = max_num_seqs
self.on_going_batches = [None] * self.world_size
self.input_ids_dict = {}
# self.attention_mask_dict = {}
self.past_key_values_dict = {}
self.tokens = {}
self.token_times = {}
self.dtype = torch.float16
self.waiting_requests = asyncio.Queue()
self.send_buff = None
self.dict_lock = threading.Lock()
# def generate(self, input_ids=None, max_tokens=5, attention_mask=None):
# times = []
# with torch.no_grad():
# _input_ids = None
# _past_key_values = None
# bs = input_ids.shape[0]
# output_ids = input_ids.clone()
# for i in range(max_tokens):
# start = time.perf_counter()
# if _input_ids is None:
# _input_ids = input_ids
# if self.rank == 0:
# outputs = self.model(input_ids=_input_ids, attention_mask=attention_mask, past_key_values=_past_key_values, use_cache=True)
# else:
# inputs_embeds = torch.empty(_input_ids.shape + (self.hidden_size,) , device=f'xpu:{self.rank}', dtype=torch.float32)
# dist.recv(inputs_embeds, src=self.pre_rank)
# outputs = self.model(inputs_embeds=inputs_embeds, attention_mask=attention_mask, past_key_values=_past_key_values, use_cache=True)
# if self.rank == self.world_size - 1:
# logits = outputs.logits
# next_ids = torch.argmax(logits[:, -1:, :], dim=-1)
# assert next_ids.shape == (bs, 1)
# dist.broadcast(next_ids, src=self.rank)
# else:
# dist.send(outputs.last_hidden_state, dst=self.next_rank)
# next_ids = torch.empty((bs, 1), device=f'xpu:{self.rank}', dtype=torch.int64)
# dist.broadcast(next_ids, src=self.world_size - 1)
# _input_ids = next_ids
# output_ids = torch.cat([output_ids, next_ids], dim=-1)
# _past_key_values = outputs.past_key_values
# end = time.perf_counter()
# times.append(end - start)
# if self.rank == 0:
# logger.info(f"first token latency: {times[0]}, rest token avg latecy: {np.mean(times[1:])}")
# return output_ids
def model_step(self, input, cur_batch):
if cur_batch is None or cur_batch.stopped or input is None:
return None
cur_id = cur_batch.batch_id
_past_key_values = self.past_key_values_dict.get(cur_id, None)
# attention_mask = self.attention_mask_dict[cur_id]
attention_mask = make_attention_mask(cur_batch.prompt_lengths)
if self.rank == 0:
input_ids = input
inputs_embeds = None
else:
input_ids = None
inputs_embeds = input
output = self.model(
input_ids=input_ids,
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
past_key_values=_past_key_values,
use_cache=True
)
self.past_key_values_dict[cur_id] = output.past_key_values
if not self.pp_config.is_tail:
return output.last_hidden_state
else:
# logger.info(f"logits: {output.logits.shape}")
return output.logits
def is_initialized(self):
return True
async def add_request(self, tokenizer):
request_ids, prompt_requests = [], []
for _ in range(self.max_num_seqs):
if self.waiting_requests.empty():
break
tmp_result = await self.waiting_requests.get()
# logger.info(tmp_result)
request_id, prompt_request = tmp_result
request_ids.append(request_id)
prompt_requests.append(prompt_request)
plain_texts = [req.prompt for req in prompt_requests]
inputs = tokenizer(plain_texts, return_tensors="pt", padding=True)
input_ids = inputs.input_ids.to(f'xpu:{self.rank}')
attention_mask = inputs.attention_mask.to(f'xpu:{self.rank}')
new_batch = BatchTask(
batch_id="batch_" + str(uuid.uuid4()),
request_ids=request_ids,
max_tokens=max([req.n_predict for req in prompt_requests]),
batch_size=input_ids.size(0),
input_len=input_ids.size(1),
prompt_lengths=[sum(attention_mask[i,:]) for i in range(input_ids.size(0))],
stopped=False,
# plain_texts=plain_texts,
# input_ids=input_ids,
# attention_mask=attention_mask,
)
self.input_ids_dict[new_batch.batch_id] = input_ids
self.token_times[new_batch.batch_id] = [time.perf_counter()]
# self.attention_mask_dict[new_batch.batch_id] = attention_mask
return new_batch
def clear_batch(self, cur_id):
self.input_ids_dict.pop(cur_id, None)
self.tokens.pop(cur_id, None)
self.token_times.pop(cur_id, None)
# self.attention_mask_dict.pop(cur_id, None)
self.past_key_values_dict.pop(cur_id, None)
# torch.xpu.empty_cache()
async def process_step(self, tokenizer, result_dict):
cur_batch = None
if self.rank == 0:
if self.send_buff is not None:
# logger.info(f"rank: {self.rank}, send: {self.send_buff.shape}")
dist.send(self.send_buff, dst=self.next_rank)
if self.on_going_batches[0] is not None:
cur_batch = self.on_going_batches[0]
cur_input = None
if cur_batch is None:
if not self.waiting_requests.empty():
# await asyncio.sleep(0.01)
cur_batch = await self.add_request(tokenizer)
cur_input = self.input_ids_dict[cur_batch.batch_id]
else:
cur_batch = None
cur_input = None
if (cur_batch is not None) and (not cur_batch.stopped) and (cur_input is None):
cur_id = cur_batch.batch_id
next_ids = torch.empty((cur_batch.batch_size, 1,), device=f'xpu:{self.rank}', dtype=torch.int64)
# logger.info(f"rank: {self.rank}, recv: {next_ids.shape}")
dist.recv(next_ids, src=self.pre_rank)
if self.tokens.get(cur_id, None) is None:
self.tokens[cur_id] = []
if len(next_ids.shape) == 1:
next_ids = next_ids.unsqueeze(0)
self.tokens[cur_id].append(next_ids)
self.token_times[cur_id].append(time.perf_counter())
# self.input_ids_dict[cur_id] += next_ids
cur_input = next_ids
# cur_batch.input_len += 1
cur_batch.input_len = 1
cur_batch.prompt_lengths = [x + 1 for x in cur_batch.prompt_lengths]
if len(self.tokens[cur_id]) >= cur_batch.max_tokens:
# Finish a batch
# logger.info(self.tokens[cur_id])
outputs = torch.cat(self.tokens[cur_id], dim=1)
outputs = outputs.cpu()
output_strs = tokenizer.batch_decode(outputs, skip_special_tokens=False)
for request_id, output_str in zip(cur_batch.request_ids, output_strs):
with self.dict_lock:
result_dict[request_id] = output_str
cur_times = self.token_times[cur_id]
first_token = cur_times[1] - cur_times[0]
next_token = (cur_times[-1] - cur_times[1]) / (len(self.tokens[cur_id]) - 1)
logger.info(f"First token latency: {first_token}, next token latency: {next_token}")
self.clear_batch(cur_id)
cur_batch.stopped = True
else:
if (cur_batch is not None) and cur_batch.stopped:
cur_batch = None
if cur_batch is not None:
dist.broadcast_object_list([cur_batch], src=0)
else:
if self.send_buff is not None:
# logger.info(f"rank: {self.rank}, send: {self.send_buff.shape}")
dist.send(self.send_buff, dst=self.next_rank)
batch_list = [None]
dist.broadcast_object_list(batch_list, src=0)
cur_batch = batch_list[0]
cur_input = None
if cur_batch is not None:
if cur_batch.stopped:
self.clear_batch(cur_batch.batch_id)
else:
cur_len = cur_batch.input_len
cur_input = torch.empty((cur_batch.batch_size, cur_len, self.hidden_size,), device=f'xpu:{self.rank}', dtype=self.dtype)
# logger.info(f"rank: {self.rank}, recv: {cur_input.shape}")
dist.recv(cur_input, src=self.pre_rank)
# if self.attention_mask_dict.get(cur_batch.batch_id, None) is None:
# self.attention_mask_dict[cur_batch.batch_id] = make_attention_mask(cur_batch.prompt_lengths)
# if self.rank == 0:
# logger.info(f"rank: {self.rank}, {batch_list}")
output = self.model_step(cur_input, cur_batch)
if output is not None and self.rank == self.world_size - 1:
output = torch.argmax(output[:, -1:, :], dim=-1)
if output is not None:
# dist.send(output, dst=self.next_rank)
self.send_buff = output
else:
self.send_buff = None
if self.rank == 0:
self.on_going_batches[:-1] = self.on_going_batches[1:]
self.on_going_batches[self.world_size - 1] = cur_batch