Optimizations for Pipeline Parallel Serving (#11702)

Xiangyu Tian 2024-08-02 12:06:59 +08:00 committed by GitHub
parent 8d1e0bd2f4
commit 1baa3efe0e
2 changed files with 86 additions and 55 deletions


@@ -250,7 +250,7 @@ async def generate_stream(inputs_request: InputsRequest):
request_id = str(uuid.uuid4()) + "stream"
await local_model.waiting_requests.put((request_id, inputs_request))
while True:
await asyncio.sleep(0)
await asyncio.sleep(0.1)
cur_streamer = local_model.streamer.get(request_id, None)
if cur_streamer is not None:
if inputs_request.req_type == 'completion':
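The hunk above relaxes the polling interval in generate_stream from asyncio.sleep(0) to asyncio.sleep(0.1), so the handler stops busy-spinning while it waits for the worker to register the per-request streamer. A minimal sketch of that pattern, assuming a local_model object with a streamer dict as in the serving script (the helper name is illustrative):

import asyncio

async def wait_for_streamer(local_model, request_id, poll_interval=0.1):
    # Hypothetical helper mirroring the loop in generate_stream: poll until the
    # pipeline worker has registered an asyncio.Queue for this request id.
    while True:
        cur_streamer = local_model.streamer.get(request_id, None)
        if cur_streamer is not None:
            return cur_streamer
        # Sleeping for 0.1 s instead of 0 yields the event loop for a real
        # interval, so other requests and the worker loop get scheduled.
        await asyncio.sleep(poll_interval)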


@@ -25,7 +25,7 @@ import torch.distributed as dist
import os
import time
import numpy as np
from typing import Callable, List, Optional, Union, Tuple
from typing import Callable, List, Optional, Union, Tuple, Any
from types import SimpleNamespace
import transformers
from transformers import GenerationConfig, LogitsProcessorList, StoppingCriteriaList
@@ -37,6 +37,7 @@ logger = logging.getLogger(__name__)
import asyncio
import uuid
import threading
import pickle
try:
from pydantic import BaseModel
except ImportError:
@@ -513,6 +514,8 @@ class PPModelWorker:
self.max_prefilled_seqs = max_prefilled_seqs
self.partial_output_dict = {}
self.stream_tasks = {}
def load_model(self, model_path, world_size, low_bit='sym_int4'):
from ipex_llm.transformers import AutoModelForCausalLM, AutoModel
try:
@@ -683,13 +686,11 @@ class PPModelWorker:
_output = output[0]
if _output.dtype != self.dtype:
_output = _output.to(self.dtype)
return _output, cur_batch
else:
if cur_batch.partial_prefilling > 0 and \
cur_batch.prefilled_index == cur_batch.batch_size:
_output = self.partial_output_dict.pop(cur_id, None)
cur_batch.partial_prefilling = 0
return _output, cur_batch
else:
_output = torch.argmax(output.logits[:, -1:, :], dim=-1)
return _output, cur_batch
@@ -738,14 +739,71 @@ class PPModelWorker:
self.is_finish.pop(cur_id, None)
self.partial_output_dict.pop(cur_id, None)
async def wait_stream_output(self, cur_id):
cur_task = self.stream_tasks.pop(cur_id, None)
if cur_task is not None:
await cur_task
def get_printable_text(self, cur_text, request_id):
if cur_text.endswith("\n"):
printable_text = cur_text[self.print_len[request_id]:]
self.token_cache[request_id] = []
self.print_len[request_id] = 0
elif len(cur_text) > 0 and _is_chinese_char(ord(cur_text[-1])):
printable_text = cur_text[self.print_len[request_id]:]
self.print_len[request_id] += len(printable_text)
self.token_cache[request_id] = []
self.print_len[request_id] = 0
else:
r_index = cur_text.rfind(" ") + 1
if r_index > self.print_len[request_id]:
printable_text = cur_text[self.print_len[request_id]: r_index]
self.token_cache[request_id] = self.token_cache[request_id][-1:]
self.print_len[request_id] = 0
else:
printable_text = cur_text[self.print_len[request_id]: r_index]
return printable_text
async def stream_output(self, cur_batch, tokenizer, next_ids):
cur_id = cur_batch.batch_id
cur_cached_ids = []
_stream_tasks = []
for index, request_id in enumerate(cur_batch.request_ids):
if not self.is_finish.get(request_id, False):
if self.token_cache.get(request_id, None) is None:
self.token_cache[request_id] = []
self.print_len[request_id] = 0
self.token_cache[request_id].extend(next_ids[index].tolist())
cur_cached_ids.append(self.token_cache[request_id])
for index, request_id in enumerate(cur_batch.request_ids):
if not self.is_finish.get(request_id, False):
remain = cur_batch.max_tokens - len(self.tokens[cur_id])
if self.streamer.get(request_id, None) is None:
self.streamer[request_id] = asyncio.Queue()
# Currently ignore eos for benchmark
# if next_ids[index].int() == tokenizer.eos_token_id:
# remain = 0
# self.is_finish[request_id] = True
cur_text = tokenizer.decode(self.token_cache[request_id])
printable_text = self.get_printable_text(cur_text, request_id)
if remain > 0:
_stream_tasks.append(self.streamer[request_id].put((remain, printable_text)))
else:
printable_text = printable_text + cur_text[self.print_len[request_id]:]
self.token_cache.pop(request_id, None)
self.print_len.pop(request_id, None)
_stream_tasks.append(self.streamer[request_id].put((remain, printable_text)))
await asyncio.gather(*_stream_tasks)
async def process_step(self, tokenizer, result_dict):
cur_batch = None
torch.xpu.synchronize(self.device)
if self.rank == 0:
if self.send_buff is not None:
# logger.info(f"send {self.rank} {self.send_buff.shape}")
dist.send(self.send_buff, dst=self.next_rank)
if self.on_going_batches[0] is not None:
cur_batch = self.on_going_batches[0]
cur_input = None
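On the consumer side, each streamer entry is an asyncio.Queue that the new stream_output coroutine fills with (remain, printable_text) tuples, where remain reaching zero signals the final chunk for a request. A rough sketch of a consumer loop, assuming this tuple contract (the helper name is illustrative):

import asyncio

async def drain_streamer(streamer: asyncio.Queue) -> str:
    # Illustrative consumer for the (remain, printable_text) tuples that
    # stream_output puts on a request's queue.
    pieces = []
    while True:
        remain, printable_text = await streamer.get()
        pieces.append(printable_text)
        if remain <= 0:
            # remain <= 0 means max_tokens has been reached and the final
            # flushed text has been delivered for this request.
            break
    return "".join(pieces)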
@@ -773,13 +831,13 @@ class PPModelWorker:
# logger.info(f"recv {self.rank} {next_ids.shape}")
dist.recv(next_ids, src=self.pre_rank)
torch.xpu.synchronize(self.device)
if cur_batch.partial_prefilling > 0:
cur_input = self.input_ids_dict[cur_batch.batch_id]
else:
if self.tokens.get(cur_id, None) is None:
self.tokens[cur_id] = []
if len(next_ids.shape) == 1:
next_ids = next_ids.unsqueeze(0)
self.tokens[cur_id].append(next_ids)
@@ -788,44 +846,14 @@ class PPModelWorker:
cur_batch.input_len = 1
cur_batch.prompt_lengths = [x + 1 for x in cur_batch.prompt_lengths]
for index, request_id in enumerate(cur_batch.request_ids):
if not self.is_finish.get(request_id, False):
remain = cur_batch.max_tokens - len(self.tokens[cur_id])
if self.streamer.get(request_id, None) is None:
self.streamer[request_id] = asyncio.Queue()
# Currently ignore eos for benchmark
# if next_ids[index].int() == tokenizer.eos_token_id:
# remain = 0
# self.is_finish[request_id] = True
if self.token_cache.get(request_id, None) is None:
self.token_cache[request_id] = []
self.print_len[request_id] = 0
self.token_cache[request_id].extend(next_ids[index].tolist())
text = tokenizer.decode(self.token_cache[request_id])
if text.endswith("\n"):
printable_text = text[self.print_len[request_id]:]
self.token_cache[request_id] = []
self.print_len[request_id] = 0
elif len(text) > 0 and _is_chinese_char(ord(text[-1])):
printable_text = text[self.print_len[request_id]:]
self.print_len[request_id] += len(printable_text)
else:
r_index = text.rfind(" ") + 1
printable_text = text[self.print_len[request_id]: r_index]
self.print_len[request_id] += len(printable_text)
if remain > 0:
await self.streamer[request_id].put((remain, printable_text))
else:
printable_text = printable_text + text[self.print_len[request_id]:]
self.token_cache.pop(request_id, None)
self.print_len.pop(request_id, None)
await self.streamer[request_id].put((remain, printable_text))
pre_task = self.stream_tasks.get(cur_id)
if pre_task is not None:
await pre_task
del self.stream_tasks[cur_id]
cur_task = asyncio.create_task(
self.stream_output(cur_batch, tokenizer, next_ids)
)
self.stream_tasks[cur_id] = cur_task
if len(self.tokens[cur_id]) >= cur_batch.max_tokens:
# Finish a batch
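The decode-and-stream work that used to run inline in process_step is now wrapped in an asyncio task: the previous task for the same batch id is awaited first (keeping chunks ordered), then stream_output is launched in the background so detokenization overlaps with the next pipeline step. A condensed sketch of that scheduling pattern, with illustrative names:

import asyncio

async def schedule_stream_output(stream_tasks: dict, cur_id, stream_coro):
    # Finish the previous streaming task for this batch id before starting a
    # new one, so per-request chunks stay in order.
    pre_task = stream_tasks.pop(cur_id, None)
    if pre_task is not None:
        await pre_task
    # Run the new stream_output coroutine in the background instead of
    # awaiting it inline; process_step can move on to the next model step.
    stream_tasks[cur_id] = asyncio.create_task(stream_coro)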
@@ -841,6 +869,7 @@ class PPModelWorker:
next_token = (cur_times[-1] - cur_times[1]) / (len(self.tokens[cur_id]) - 1)
logger.info(f"First token latency: {first_token}, "
f"next token latency: {next_token}")
await self.wait_stream_output(cur_id)
self.clear_batch(cur_id)
cur_batch.stopped = True
else:
@@ -850,15 +879,12 @@ class PPModelWorker:
if cur_batch is not None:
cur_batch = self.prepare_batch(cur_batch)
dist.broadcast_object_list([cur_batch], src=0)
else:
await asyncio.sleep(0)
else:
if self.send_buff is not None:
# logger.info(f"send {self.rank} {self.send_buff.shape}")
dist.send(self.send_buff, dst=self.next_rank)
batch_list = [None]
dist.broadcast_object_list(batch_list, src=0)
cur_batch = batch_list[0]
cur_input = None
@@ -882,10 +908,15 @@ class PPModelWorker:
)
# logger.info(f"recv {self.rank} {cur_input.shape}")
dist.recv(cur_input, src=self.pre_rank)
torch.xpu.synchronize(self.device)
output, cur_batch = self.model_step(cur_input, cur_batch)
self.send_buff = output
torch.xpu.synchronize(self.device)
if self.send_buff is not None:
self.send_buff.wait()
if output is not None:
self.send_buff = dist.isend(output, dst=self.next_rank)
if self.rank == 0:
self.on_going_batches[:-1] = self.on_going_batches[1:]
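The last hunk replaces the blocking dist.send that used to start each iteration with a non-blocking dist.isend issued right after model_step: the request handle is kept in send_buff, and the next iteration waits on it before sending again, so pipeline communication overlaps with compute. A small sketch of the pattern, assuming an initialized torch.distributed process group:

import torch
import torch.distributed as dist

def isend_output(prev_req, output, next_rank):
    # Wait for the previous asynchronous send to finish before reusing the
    # channel, then issue a new non-blocking send and return its handle.
    if prev_req is not None:
        prev_req.wait()
    if output is None:
        return None
    return dist.isend(output, dst=next_rank)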