LLM: Refactor Pipeline-Parallel-FastAPI example (#11319)

Initial refactor for the Pipeline-Parallel-FastAPI example
Xiangyu Tian 2024-06-25 13:30:36 +08:00 committed by GitHub
parent 34c15d3a10
commit 8ddae22cfb
7 changed files with 147 additions and 705 deletions

View file

@@ -22,7 +22,7 @@ pip install mpi4py fastapi uvicorn openai
pip install gradio # for gradio web UI
conda install -c conda-forge -y gperftools=2.10 # to enable tcmalloc
-pip install transformers==4.31.0 # for llama2 models
+pip install transformers==4.37.0
```

### 2. Run pipeline parallel serving on multiple GPUs
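A quick way to confirm the bumped dependency took effect (an optional check, not part of the original example):

```python
# Optional sanity check: the install step above pins transformers==4.37.0.
import transformers
print(transformers.__version__)  # expected: 4.37.0
```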

View file

@@ -30,7 +30,6 @@ def perform_request(session, url, payload, headers):
    start_time = time.perf_counter()
    with session.post(url, json=payload, headers=headers, stream=True) as response:
        response.raise_for_status()
-
        first_token_time = None
        last_token_time = 0
        first_token_inference_time = None
@@ -38,21 +37,29 @@ def perform_request(session, url, payload, headers):
        next_token_time = []
        i = 0
        for line in response.iter_lines():
            token_time = time.perf_counter() - start_time
            if line:
-                data = line.decode("utf-8").strip()
-                i = i + 1
-                try:
-                    json_data = json.loads(data)
-                    if json_data["message"] is not None:
-                        if first_token_time is None:
-                            first_token_time = token_time
-                        else:
-                            next_token_time.append(token_time - last_token_time)
-                        last_token_time = token_time
-                except json.JSONDecodeError:
-                    pass
+                data = line.decode('utf-8').strip()
+                if data.startswith('data: '):
+                    data = data[len('data: '):]
+                i = i + 1
+                try:
+                    json_data = json.loads(data)
+                    if 'choices' in json_data and len(json_data['choices']) > 0:
+                        choice = json_data['choices'][0]
+                        if 'finish_reason' in choice and (choice['finish_reason'] == 'length' or choice['finish_reason'] == 'stop'):
+                            if 'first_token_time' in choice and isinstance(choice['first_token_time'], float):
+                                first_token_inference_time = choice['first_token_time']
+                            if 'rest_token_time' in choice and isinstance(choice['rest_token_time'], float):
+                                next_token_inference_time = choice['rest_token_time']
+                        else:
+                            if first_token_time is None:
+                                first_token_time = token_time
+                            else:
+                                next_token_time.append(token_time - last_token_time)
+                            last_token_time = token_time
+                except json.JSONDecodeError:
+                    pass
        end_time = time.perf_counter()
        return (
            first_token_time,
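For reference, the rewritten loop parses OpenAI-style completion chunks (`data: {...}` lines carrying a `choices` list) instead of the old `{"message": ...}` payloads. A minimal, self-contained sketch of handling one such line; the sample string is illustrative, not captured from a real server:

```python
import json

# Illustrative SSE line in the OpenAI completions format the parser above expects.
sample = 'data: {"choices": [{"text": " world", "finish_reason": null}]}'

if sample.startswith('data: '):
    chunk = json.loads(sample[len('data: '):])
    choice = chunk['choices'][0]
    if choice.get('finish_reason') in ('length', 'stop'):
        print('stream finished')       # the final chunk may also carry timing fields
    else:
        print('token chunk:', choice.get('text'))
```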
@@ -76,11 +83,11 @@ def extend_list_to_length(lst, target_length):
def benchmark(
    llm_urls,
    prompt,
+    num_warmup_requests,
    num_requests,
    max_concurrent_requests,
    max_tokens,
    prompt_length,
-    is_warmup=False,
):
    headers = {"Content-Type": "application/json"}
@@ -92,6 +99,8 @@ def benchmark(
    next_token_inference_times = []
    cur_url_index = 0

+    num_requests = num_requests + num_warmup_requests
+
    with requests.Session() as session:
        with ThreadPoolExecutor(max_workers=max_concurrent_requests) as executor:
            llm_url = llm_urls[cur_url_index]
@@ -101,8 +110,17 @@
            cur_len = len(cur_llm_urls)

            payload = {
+                "model": "Meta-Llama-3-8B-Instruct",
                "prompt": prompt,
-                "n_predict": max_tokens,
+                "max_tokens": max_tokens,
+                "stream": True,
+                # for vllm openai api server
+                "ignore_eos": True,
+                "n": 1,
+                "best_of": 1,
+                "use_beam_search": False,
+                "temperature": 0.0,
+                "top_p": 1.0,
            }
            futures = [
                executor.submit(
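To exercise the new request shape by hand, a minimal streaming client could look like the sketch below; the URL, port, and model name mirror the values used in this example and may differ for other deployments:

```python
import requests

payload = {
    "model": "Meta-Llama-3-8B-Instruct",
    "prompt": "What is AI?",
    "max_tokens": 32,
    "stream": True,
}
# Stream completions from the OpenAI-compatible endpoint started in this example.
with requests.post("http://localhost:8000/v1/completions",
                   json=payload, stream=True) as response:
    response.raise_for_status()
    for line in response.iter_lines():
        if line:
            print(line.decode("utf-8"))
```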
@@ -115,14 +133,13 @@
                for index in range(num_requests)
            ]

-            start_time = time.perf_counter()
-            if is_warmup:
-                phase = "Warm Up"
-            else:
-                phase = "Benchmarking"
+            phase = "Benchmarking"
            with tqdm(total=num_requests, desc=phase, unit="req", ncols=100) as pbar:
+                cur_index = 0
                for future in concurrent.futures.as_completed(futures):
+                    if cur_index == num_warmup_requests:
+                        start_time = time.perf_counter()
                    try:
                        (
                            first_token_latency,
@@ -131,21 +148,21 @@
                            first_token_inference_time,
                            next_token_inference_time,
                        ) = future.result()
-                        first_token_latencies.append(first_token_latency)
-                        next_token_latencies.append(next_token_latency)
-                        total_responce_times.append(total_responce_time)
-                        if first_token_inference_time:
-                            first_token_inference_times.append(
-                                first_token_inference_time
-                            )
-                        if next_token_inference_time:
-                            next_token_inference_times.append(next_token_inference_time)
+                        cur_index = cur_index + 1
+                        if cur_index > num_warmup_requests:
+                            first_token_latencies.append(first_token_latency)
+                            next_token_latencies.append(next_token_latency)
+                            total_responce_times.append(total_responce_time)
+                            if first_token_inference_time:
+                                first_token_inference_times.append(
+                                    first_token_inference_time
+                                )
+                            if next_token_inference_time:
+                                next_token_inference_times.append(next_token_inference_time)
                    except Exception as e:
                        print(f"Request failed: {e}")
                    pbar.update(1)

-    if is_warmup:
-        return
-
    total_time = time.perf_counter() - start_time
    log_file = f"{max_concurrent_requests}.log"
@@ -174,9 +191,6 @@
            )
            p90_first_token_latency = np.percentile(first_token_latencies, 90)
            p95_first_token_latency = np.percentile(first_token_latencies, 95)
-            # average_first_token_inference_latency = np.mean(
-            #     first_token_inference_times
-            # )
            print(
                f"Average first token latency: {average_first_token_latency * 1000} milliseconds.",
                file=file,
@@ -189,10 +203,6 @@
                f"P95 first token latency: {p95_first_token_latency * 1000} milliseconds.",
                file=file,
            )
-            # print(
-            #     f"Average first token inference latency: {average_first_token_inference_latency * 1000} milliseconds.",
-            #     file=file,
-            # )
            print(file=file)

        if next_token_latencies:
@@ -201,9 +211,6 @@
            )
            p90_next_token_latency = np.percentile(next_token_latencies, 90)
            p95_next_token_latency = np.percentile(next_token_latencies, 95)
-            # average_next_token_inference_latency = np.mean(
-            #     next_token_inference_times
-            # )
            print(
                f"Average next token latency: {average_next_token_latency * 1000} milliseconds.",
                file=file,
@@ -216,14 +223,10 @@
                f"P95 next token latency: {p95_next_token_latency * 1000} milliseconds.",
                file=file,
            )
-            # print(
-            #     f"Average next token inference latency: {average_next_token_inference_latency * 1000} milliseconds.",
-            #     file=file,
-            # )
            print(file=file)


-LLM_URLS = [f"http://localhost:{PORT}/generate_stream/" for PORT in [8000]]
+LLM_URLS = [f"http://localhost:{PORT}/v1/completions" for PORT in [8000]]

parser = argparse.ArgumentParser(description="Set prompt length.")
parser.add_argument(
@@ -254,17 +257,6 @@ MAX_TOKENS = args.max_new_tokens

for MAX_CONCURRENT_REQUESTS in args.max_concurrent_requests:
    NUM_WARMUP = 5 * MAX_CONCURRENT_REQUESTS
-    NUM_REQUESTS = 10 * MAX_CONCURRENT_REQUESTS
+    NUM_REQUESTS = 30 * MAX_CONCURRENT_REQUESTS

-    # warm up
-    benchmark(
-        LLM_URLS,
-        PROMPT,
-        NUM_WARMUP,
-        MAX_CONCURRENT_REQUESTS,
-        MAX_TOKENS,
-        PROMPT_LENGTH,
-        is_warmup=True,
-    )
-    benchmark(LLM_URLS, PROMPT, NUM_REQUESTS, MAX_CONCURRENT_REQUESTS, MAX_TOKENS, PROMPT_LENGTH)
+    benchmark(LLM_URLS, PROMPT, NUM_WARMUP, NUM_REQUESTS, MAX_CONCURRENT_REQUESTS, MAX_TOKENS, PROMPT_LENGTH)
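With the warm-up folded into a single `benchmark()` call, the request accounting works out as follows (a worked example assuming 4 concurrent requests, matching the run scripts):

```python
# Sketch of the request accounting after the refactor.
max_concurrent_requests = 4
num_warmup = 5 * max_concurrent_requests      # 20 warm-up completions, excluded from stats
num_requests = 30 * max_concurrent_requests   # 120 measured completions
total_submitted = num_warmup + num_requests   # 140 requests sent by one benchmark() call
print(num_warmup, num_requests, total_submitted)
```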

View file

@@ -1,328 +0,0 @@
from transformers.modeling_utils import PreTrainedModel
from transformers.models.llama.modeling_llama import LlamaConfig, LlamaDecoderLayer, LlamaRMSNorm, LlamaPreTrainedModel
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from torch import nn
import torch
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from typing import List, Optional, Tuple, Union, Iterator
from transformers.utils import logging
logger = logging.get_logger(__name__)
import numpy as np
import time
from transformers import AutoTokenizer, AutoConfig
import torch.distributed as dist
from pipeline_models import (
_make_causal_mask, _expand_mask, DummyLayer, PPConfig,
PipelineBaseModel,
)
class LlamaModel(LlamaPreTrainedModel):
"""
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
Args:
config: LlamaConfig
"""
def __init__(self, config: LlamaConfig):
super().__init__(config)
self.config = config
# pp modification
self.pp_config = PPConfig(pp_rank=dist.get_rank(), pp_world_size=dist.get_world_size())
nr_slices = self.pp_config.pp_world_size
# self.config.num_hidden_layers = 8
slice_size = (self.config.num_hidden_layers + nr_slices -
1) // nr_slices
self.layer_start = slice_size * self.pp_config.pp_rank
self.layer_end = self.layer_start + min(slice_size,
self.config.num_hidden_layers - self.layer_start)
self.num_layers = self.layer_end - self.layer_start
layers = []
for i in range(self.config.num_hidden_layers):
if i < self.layer_start or i >= self.layer_end:
layers.append(DummyLayer())
else:
layers.append(LlamaDecoderLayer(config))
self.layers = nn.ModuleList(layers)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
if self.pp_config.is_head:
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
if self.pp_config.is_tail:
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def get_input_embeddings(self):
return self.embed_tokens
def set_input_embeddings(self, value):
self.embed_tokens = value
# Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask = None
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape,
inputs_embeds.dtype,
device=inputs_embeds.device,
past_key_values_length=past_key_values_length,
)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
inputs_embeds.device
)
combined_attention_mask = (
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
)
return combined_attention_mask
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# retrieve input_ids and inputs_embeds for pp
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
elif input_ids is not None:
assert self.pp_config.is_head, "input_ids is only supported on the head stage"
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
assert not self.pp_config.is_head, "inputs_embeds is only supported on the tail stage"
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
seq_length_with_past = seq_length
past_key_values_length = 0
if past_key_values is not None:
past_key_values_length = past_key_values[0][0].shape[2]
seq_length_with_past = seq_length_with_past + past_key_values_length
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
# embed positions
if attention_mask is None:
attention_mask = torch.ones(
(batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
)
attention_mask = self._prepare_decoder_attention_mask(
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
)
hidden_states = inputs_embeds
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = () if use_cache else None
for idx in range(self.num_layers):
decoder_layer = self.layers[self.layer_start + idx]
if output_hidden_states:
all_hidden_states += (hidden_states,)
past_key_value = past_key_values[idx] if past_key_values is not None else None
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
if output_attentions:
all_self_attns += (layer_outputs[1],)
if self.pp_config.is_tail:
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
class LlamaForCausalLM(LlamaPreTrainedModel):
def __init__(self, config: LlamaConfig):
super().__init__(config=config)
self.config = config
self.pp_config = PPConfig(pp_rank=dist.get_rank(), pp_world_size=dist.get_world_size())
self.model = LlamaModel(config)
self.pretraining_tp = config.pretraining_tp
self.vocab_size = config.vocab_size
if self.pp_config.is_tail:
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
def get_input_embeddings(self):
return self.model.embed_tokens
def set_input_embeddings(self, value):
self.model.embed_tokens = value
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def set_decoder(self, decoder):
self.model = decoder
def get_decoder(self):
return self.model
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
if self.pp_config.is_tail:
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
return outputs
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
):
if past_key_values:
input_ids = input_ids[:, -1:]
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -1].unsqueeze(-1)
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
model_inputs.update(
{
"position_ids": position_ids,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask,
}
)
return model_inputs
@staticmethod
def _reorder_cache(past_key_values, beam_idx):
reordered_past = ()
for layer_past in past_key_values:
reordered_past += (
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
)
return reordered_past

View file

@@ -1,15 +1,15 @@
-from torch import nn
import torch
import torch.distributed as dist
from typing import List, Optional, Tuple, Union, Iterator
import time
-from transformers import AutoTokenizer, AutoConfig
+from transformers.cache_utils import Cache
from transformers.utils import logging
-from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
import numpy as np
import asyncio, uuid
import threading
+from pydantic import BaseModel

logger = logging.get_logger(__name__)
@@ -23,227 +23,15 @@ class PPConfig:
        self.is_head = self.pp_rank == 0
        self.is_tail = self.pp_rank == self.pp_world_size - 1
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(
input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
):
"""
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
mask_cond = torch.arange(mask.size(-1), device=device)
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
if past_key_values_length > 0:
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
# Copied from transformers.models.bart.modeling_bart._expand_mask
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask = None
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape,
inputs_embeds.dtype,
device=inputs_embeds.device,
past_key_values_length=past_key_values_length,
)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
inputs_embeds.device
)
combined_attention_mask = (
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
)
return combined_attention_mask
class DummyLayer(nn.Module):
pass
class PipelineBaseModel(nn.Module):
def __init__(self, config):
self.pp_config = PPConfig(pp_rank=dist.get_rank(), pp_world_size=dist.get_world_size())
nr_slices = self.pp_config.pp_world_size
# self.config.num_hidden_layers = 8
slice_size = (self.config.num_hidden_layers + nr_slices -
1) // nr_slices
self.layer_start = slice_size * self.pp_config.pp_rank
self.layer_end = self.layer_start + min(slice_size,
self.config.num_hidden_layers - self.layer_start)
self.num_layers = self.layer_end - self.layer_start
# Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask = None
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape,
inputs_embeds.dtype,
device=inputs_embeds.device,
past_key_values_length=past_key_values_length,
)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
inputs_embeds.device
)
combined_attention_mask = (
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
)
return combined_attention_mask
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
elif input_ids is not None:
assert self.pp_config.is_head, "input_ids is only supported on the head stage"
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
assert not self.pp_config.is_head, "inputs_embeds is only supported on the tail stage"
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
seq_length_with_past = seq_length
past_key_values_length = 0
if past_key_values is not None:
past_key_values_length = past_key_values[0][0].shape[2]
seq_length_with_past = seq_length_with_past + past_key_values_length
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
# embed positions
if attention_mask is None:
attention_mask = torch.ones(
(batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
)
attention_mask = self._prepare_decoder_attention_mask(
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
)
hidden_states = inputs_embeds
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = () if use_cache else None
for idx in range(self.num_layers):
decoder_layer = self.layers[self.layer_start + idx]
if output_hidden_states:
all_hidden_states += (hidden_states,)
past_key_value = past_key_values[idx] if past_key_values is not None else None
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
if output_attentions:
all_self_attns += (layer_outputs[1],)
if self.pp_config.is_tail:
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
def load_model(checkpoint):
from llama_models import LlamaForCausalLM
if 'llama' in checkpoint.lower():
model = LlamaForCausalLM.from_pretrained(checkpoint, low_cpu_mem_usage=True, torch_dtype=torch.float16)
return model
from pydantic import BaseModel
class BatchTask(BaseModel):
    batch_id: str
    request_ids: List[str]
    max_tokens: int
    batch_size: int
    input_len: int
-    # plain_texts: List[str]
    prompt_lengths: List[int]
    stopped: bool
-    # input_ids: torch.Tensor
-    # attention_mask: torch.Tensor
def make_attention_mask(prompt_lengths):
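For illustration, the trimmed `BatchTask` schema above can be instantiated like this; the class is repeated only to keep the sketch self-contained, and the field values are hypothetical:

```python
import uuid
from typing import List
from pydantic import BaseModel

class BatchTask(BaseModel):  # same fields as the schema above
    batch_id: str
    request_ids: List[str]
    max_tokens: int
    batch_size: int
    input_len: int
    prompt_lengths: List[int]
    stopped: bool

task = BatchTask(batch_id=str(uuid.uuid4()), request_ids=["req-0"], max_tokens=32,
                 batch_size=1, input_len=5, prompt_lengths=[5], stopped=False)
print(task.batch_id, task.stopped)
```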
@@ -256,19 +44,14 @@ def make_attention_mask(prompt_lengths):

class ModelRunner:
    def __init__(self, checkpoint, rank, world_size, low_bit, max_num_seqs):
-        import sys
        self.pp_config = PPConfig(rank, world_size)

        start = time.perf_counter()
-        model = load_model(checkpoint)
+        model = self.load_model(checkpoint, rank, world_size, low_bit)
        end = time.perf_counter()
        logger.info(f"Time to load weights: {end - start:.2f}s")

-        from ipex_llm import optimize_model
-        model = optimize_model(model, low_bit=low_bit)
-        model = model.to(torch.float16).to(f'xpu:{rank}')
        self.model = model
        self.rank = rank
        self.world_size = world_size
@@ -295,44 +78,63 @@ class ModelRunner:
        self.is_finish = {}
        self.model_name = checkpoint
+        self.layer_start = 0

-    # def generate(self, input_ids=None, max_tokens=5, attention_mask=None):
-    #     times = []
-    #     with torch.no_grad():
-    #         _input_ids = None
-    #         _past_key_values = None
-    #         bs = input_ids.shape[0]
-    #         output_ids = input_ids.clone()
-    #         for i in range(max_tokens):
-    #             start = time.perf_counter()
-    #             if _input_ids is None:
-    #                 _input_ids = input_ids
-    #             if self.rank == 0:
-    #                 outputs = self.model(input_ids=_input_ids, attention_mask=attention_mask, past_key_values=_past_key_values, use_cache=True)
-    #             else:
-    #                 inputs_embeds = torch.empty(_input_ids.shape + (self.hidden_size,) , device=f'xpu:{self.rank}', dtype=torch.float32)
-    #                 dist.recv(inputs_embeds, src=self.pre_rank)
-    #                 outputs = self.model(inputs_embeds=inputs_embeds, attention_mask=attention_mask, past_key_values=_past_key_values, use_cache=True)
-    #             if self.rank == self.world_size - 1:
-    #                 logits = outputs.logits
-    #                 next_ids = torch.argmax(logits[:, -1:, :], dim=-1)
-    #                 assert next_ids.shape == (bs, 1)
-    #                 dist.broadcast(next_ids, src=self.rank)
-    #             else:
-    #                 dist.send(outputs.last_hidden_state, dst=self.next_rank)
-    #                 next_ids = torch.empty((bs, 1), device=f'xpu:{self.rank}', dtype=torch.int64)
-    #                 dist.broadcast(next_ids, src=self.world_size - 1)
-    #             _input_ids = next_ids
-    #             output_ids = torch.cat([output_ids, next_ids], dim=-1)
-    #             _past_key_values = outputs.past_key_values
-    #             end = time.perf_counter()
-    #             times.append(end - start)
-    #     if self.rank == 0:
-    #         logger.info(f"first token latency: {times[0]}, rest token avg latecy: {np.mean(times[1:])}")
-    #     return output_ids
+    def load_model(self, model_path, my_rank, my_size, low_bit='sym_int4'):
+        device = f"xpu:{my_rank}"
+        from ipex_llm.transformers import AutoModelForCausalLM
+        model = AutoModelForCausalLM.from_pretrained(model_path,
+                                                     load_in_low_bit=low_bit,
+                                                     torch_dtype=torch.float16,
+                                                     optimize_model=True,
+                                                     trust_remote_code=True,
+                                                     use_cache=True,
+                                                     pipeline_parallel_stages=my_size).eval()
+        # print(model)
+
+        # config_class = type(model.config).__name__
+        # if config_class == 'ChatGLMConfig':
+        #     model.config.num_hidden_layers = model.config.num_layers
+        #     nr_slices = my_size
+        #     slice_size = (model.config.num_layers + nr_slices - 1) // nr_slices
+        #     layer_start = slice_size * my_rank
+        #     layer_end = layer_start + min(slice_size, model.config.num_layers - layer_start)
+        #     for i in range(model.config.num_layers):
+        #         if i < layer_start or i >= layer_end:
+        #             model.transformer.encoder.layers[i] = Dummy_DecoderLayer()
+        #         else:
+        #             pass
+        #             # align layer_idx and len(past_key_values), otherwise abnormal output
+        #             # model._modules['encoder'].layers[i].self_attention.layer_idx = i - layer_start
+        #             # model.transformer.encoder.layers[i].self_attention.layer_idx = i - layer_start
+        #     if my_rank != 0:
+        #         model.transformer.embedding = DummyLayer()
+        #     if my_rank != my_size - 1:
+        #         model.transformer.output_layer = DummyLayer()
+        # else:
+        #     nr_slices = my_size
+        #     slice_size = (model.config.num_hidden_layers + nr_slices - 1) // nr_slices
+        #     layer_start = slice_size * my_rank
+        #     layer_end = layer_start + min(slice_size, model.config.num_hidden_layers - layer_start)
+        #     for i in range(model.config.num_hidden_layers):
+        #         if i < layer_start or i >= layer_end:
+        #             model._modules['model'].layers[i] = Dummy_DecoderLayer()
+        #         else:
+        #             # align layer_idx and len(past_key_values), otherwise abnormal output
+        #             model._modules['model'].layers[i].self_attn.layer_idx = i - layer_start
+        #     if my_rank != 0:
+        #         model._modules['model'].embed_tokens = DummyLayer()
+        #     if my_rank != my_size - 1:
+        #         model._modules['model'].norm = DummyLayer()
+        #         model._modules['lm_head'] = DummyLayer()
+        # model = model.to(f'xpu:{my_rank}')
+        return model

    def model_step(self, input, cur_batch):
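The (currently commented-out) manual partitioning above and the cache sizing in `model_step` rely on the same ceil-division split of decoder layers across pipeline stages; a small worked example of that formula:

```python
# Worked example of the layer-slicing formula used in this file.
num_hidden_layers = 32  # e.g. an 8B-class Llama model
world_size = 2          # number of pipeline stages (NUM_GPUS)

slice_size = (num_hidden_layers + world_size - 1) // world_size  # 16
for rank in range(world_size):
    layer_start = slice_size * rank
    layer_end = layer_start + min(slice_size, num_hidden_layers - layer_start)
    print(f"rank {rank}: layers [{layer_start}, {layer_end})")
# rank 0: layers [0, 16)
# rank 1: layers [16, 32)
```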
@@ -341,7 +143,6 @@ class ModelRunner:
        cur_id = cur_batch.batch_id
        _past_key_values = self.past_key_values_dict.get(cur_id, None)
-        # attention_mask = self.attention_mask_dict[cur_id]
        attention_mask = make_attention_mask(cur_batch.prompt_lengths)

        if self.rank == 0:
@@ -350,18 +151,33 @@
        else:
            input_ids = None
            inputs_embeds = input
-        # logger.info(f"{self.rank}, {_past_key_values}")
        output = self.model(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            past_key_values=_past_key_values,
-            use_cache=True
+            use_cache=True,
+            output_hidden_states=True,
        )
-        self.past_key_values_dict[cur_id] = output.past_key_values
-        if not self.pp_config.is_tail:
-            return output.last_hidden_state
+        use_legacy_cache = not isinstance(output.past_key_values, Cache)
+        if use_legacy_cache and self.rank > 0:
+            if output.past_key_values[0] is None:
+                _past_key_values = list(output.past_key_values)
+                slice_size = (self.model.config.num_hidden_layers + self.world_size - 1) // self.world_size
+                layer_start = slice_size * self.rank
+                _past_key_values[0] = [torch.empty_like(output.past_key_values[layer_start][0])]
+                _past_key_values = tuple(_past_key_values)
+            else:
+                _past_key_values = output.past_key_values
+        else:
+            _past_key_values = output.past_key_values
+        self.past_key_values_dict[cur_id] = _past_key_values
+
+        if not self.pp_config.is_tail:
+            return output.hidden_states[-1]
        else:
-            # logger.info(f"logits: {output.logits.shape}")
            return output.logits
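A minimal illustration of the `Cache` check introduced above, assuming transformers 4.37 as pinned in the README: newer model code returns a `Cache` object, while a tuple of per-layer tensors is treated as the legacy format:

```python
from transformers.cache_utils import Cache, DynamicCache

legacy_cache = ((None, None),)   # tuple-style ("legacy") past_key_values
modern_cache = DynamicCache()    # Cache subclass returned by newer model code

print(not isinstance(legacy_cache, Cache))  # True  -> legacy handling branch above
print(not isinstance(modern_cache, Cache))  # False -> cache is stored as-is
```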
@@ -376,7 +192,6 @@ class ModelRunner:
                break
            tmp_result = await self.waiting_requests.get()
-            # logger.info(tmp_result)
            request_id, prompt_request = tmp_result
            request_ids.append(request_id)
            prompt_requests.append(prompt_request)
@@ -393,14 +208,10 @@
            input_len=input_ids.size(1),
            prompt_lengths=[sum(attention_mask[i,:]) for i in range(input_ids.size(0))],
            stopped=False,
-            # plain_texts=plain_texts,
-            # input_ids=input_ids,
-            # attention_mask=attention_mask,
        )

        self.input_ids_dict[new_batch.batch_id] = input_ids
        self.token_times[new_batch.batch_id] = [time.perf_counter()]
-        # self.attention_mask_dict[new_batch.batch_id] = attention_mask

        return new_batch
@@ -409,7 +220,6 @@
        self.input_ids_dict.pop(cur_id, None)
        self.tokens.pop(cur_id, None)
        self.token_times.pop(cur_id, None)
-        # self.attention_mask_dict.pop(cur_id, None)
        self.past_key_values_dict.pop(cur_id, None)
        # torch.xpu.empty_cache()
@@ -448,9 +258,7 @@
                next_ids = next_ids.unsqueeze(0)
            self.tokens[cur_id].append(next_ids)
            self.token_times[cur_id].append(time.perf_counter())
-            # self.input_ids_dict[cur_id] += next_ids
            cur_input = next_ids
-            # cur_batch.input_len += 1
            cur_batch.input_len = 1
            cur_batch.prompt_lengths = [x + 1 for x in cur_batch.prompt_lengths]
@@ -462,9 +270,10 @@
                if self.streamer.get(request_id, None) is None:
                    self.streamer[request_id] = asyncio.Queue()

-                if next_ids[index].int() == tokenizer.eos_token_id:
-                    remain = 0
-                    self.is_finish[request_id] = True
+                # Currently ignore eos for benchmark
+                # if next_ids[index].int() == tokenizer.eos_token_id:
+                #     remain = 0
+                #     self.is_finish[request_id] = True

                if self.token_cache.get(request_id, None) is None:
                    self.token_cache[request_id] = []
@@ -533,12 +342,6 @@
                cur_input = torch.empty((cur_batch.batch_size, cur_len, self.hidden_size,), device=f'xpu:{self.rank}', dtype=self.dtype)
                # logger.info(f"rank: {self.rank}, recv: {cur_input.shape}")
                dist.recv(cur_input, src=self.pre_rank)
-            # if self.attention_mask_dict.get(cur_batch.batch_id, None) is None:
-            #     self.attention_mask_dict[cur_batch.batch_id] = make_attention_mask(cur_batch.prompt_lengths)
-
-            # if self.rank == 0:
-            #     logger.info(f"rank: {self.rank}, {batch_list}")
            output = self.model_step(cur_input, cur_batch)

            if output is not None and self.rank == self.world_size - 1:
@@ -576,4 +379,4 @@ def _is_chinese_char(cp):
    ):  #
        return True

    return False

View file

@@ -3,19 +3,16 @@ import torch.nn.parallel
import torch.distributed as dist
import os
-import ipex_llm
from ipex_llm.utils.common import invalidInputError
+from ipex_llm.transformers import init_pipeline_parallel
import oneccl_bindings_for_pytorch
import json
from transformers.utils import logging
logger = logging.get_logger(__name__)

-os.environ['MASTER_ADDR'] = '127.0.0.1'
-os.environ['MASTER_PORT'] = '29501'
-backend = 'ccl'
-dist.init_process_group(backend)
+init_pipeline_parallel()
my_rank = dist.get_rank()
my_size = dist.get_world_size()
device = f"xpu:{my_rank}"
@@ -146,7 +143,7 @@ async def completion_stream_generator(local_model, delta_text_queue, request_id)
            if remain == 0:
                choice_data = CompletionResponseStreamChoice(
                    index=index,
-                    text=None,
+                    text="",
                    logprobs=None,
                    finish_reason="length")
                chunk = CompletionStreamResponse(
@@ -171,7 +168,6 @@ async def generator(local_model, delta_text_queue, request_id):
            break
        else:
            await asyncio.sleep(0)
-    # streamer_dict.pop(request_id, None)
    local_model.streamer.pop(request_id, None)
@@ -282,29 +278,6 @@ async def create_completion(request: CompletionRequest):
    return result


-def generate_text(prompt: List[str], n_predict = 32):
-    while prompt[-1] == "":
-        prompt = prompt[:-1]
-    if isinstance(n_predict, list):
-        n_predict = max(n_predict)
-
-    inputs = tokenizer(prompt, return_tensors="pt", padding=True)
-    input_ids = inputs.input_ids.to(f'xpu:{local_rank}')
-    print(inputs)
-    attention_mask = inputs.attention_mask.to(f'xpu:{local_rank}')
-    output = local_model.generate(input_ids,
-                                  max_tokens=n_predict,
-                                  # attention_mask=attention_mask,
-                                  # max_new_tokens=n_predict,
-                                  # min_new_tokens=n_predict,
-                                  # do_sample=False,
-                                  # use_cache=True
-                                  )
-    torch.xpu.synchronize()
-    return output
-
-
async def process_requests(local_model, result_dict):
    while True:
        await asyncio.sleep(0)

View file

@@ -14,6 +14,6 @@ export TORCH_LLM_ALLREDUCE=0
export MODEL_PATH=YOUR_MODEL_PATH
export NUM_GPUS=2
-export BIGDL_QUANTIZE_KV_CACHE=1
+export IPEX_LLM_QUANTIZE_KV_CACHE=1

CCL_ZE_IPC_EXCHANGE=sockets torchrun --standalone --nnodes=1 --nproc-per-node $NUM_GPUS pipeline_serving.py --repo-id-or-model-path $MODEL_PATH --low-bit fp8 --max-num-seqs 4
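Once the server is up, the `openai` client installed in step 1 can talk to the refactored `/v1/completions` endpoint. A minimal sketch; the base URL, placeholder API key, and model name are assumptions based on the port and model used elsewhere in this example:

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="none")
response = client.completions.create(
    model="Meta-Llama-3-8B-Instruct",
    prompt="What is AI?",
    max_tokens=32,
)
print(response.choices[0].text)
```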

View file

@@ -64,7 +64,9 @@ class Dummy_DecoderLayer(nn.Module):
        self.input_layernorm = DummyLayer()
        self.mlp = Dummy_MLPLayer()

-    def forward(self, hidden_states, past_key_value=None, use_cache=False, **kwargs):
+    def forward(self, hidden_states, *args, **kwargs):
+        past_key_value = kwargs.get('past_key_value', None)
+        use_cache = kwargs.get('use_cache', False)
        outputs = (hidden_states,)
        if use_cache:
            outputs += (past_key_value,)
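The looser signature matters because newer transformers versions may pass extra positional arguments into decoder layers; a stand-alone sketch of the same pass-through pattern (the class name here is illustrative, not the library's):

```python
import torch
from torch import nn

class PassThroughLayer(nn.Module):
    """Skips computation on pipeline stages that do not own this decoder layer."""
    def forward(self, hidden_states, *args, **kwargs):
        past_key_value = kwargs.get('past_key_value', None)
        use_cache = kwargs.get('use_cache', False)
        outputs = (hidden_states,)
        if use_cache:
            outputs += (past_key_value,)
        return outputs

layer = PassThroughLayer()
hidden = torch.zeros(1, 4, 8)
out = layer(hidden, None, use_cache=True)  # the extra positional arg is absorbed by *args
print(out[0].shape, out[1])                # torch.Size([1, 4, 8]) None
```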