Optimizations for Pipeline Parallel Serving (#11702)

Xiangyu Tian 2024-08-02 12:06:59 +08:00 committed by GitHub
parent 8d1e0bd2f4
commit 1baa3efe0e
2 changed files with 86 additions and 55 deletions

Changed file 1 of 2

@@ -250,7 +250,7 @@ async def generate_stream(inputs_request: InputsRequest):
     request_id = str(uuid.uuid4()) + "stream"
     await local_model.waiting_requests.put((request_id, inputs_request))
     while True:
-        await asyncio.sleep(0)
+        await asyncio.sleep(0.1)
         cur_streamer = local_model.streamer.get(request_id, None)
         if cur_streamer is not None:
             if inputs_request.req_type == 'completion':
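The only change in this file is the polling interval: the handler loops until the worker has created the per-request streamer queue, and sleeping 0.1 s instead of 0 s keeps that loop from monopolizing the event loop while it waits. A minimal self-contained sketch of the same consume-from-a-late-created-queue pattern (the names streamer, fake_worker and the chunk values are illustrative, not taken from the commit):

import asyncio
import uuid

streamer = {}  # request_id -> asyncio.Queue, created late by the worker


async def fake_worker(request_id):
    # Hypothetical stand-in for the worker side: it creates the queue only
    # after some delay, then pushes (remaining_tokens, text) chunks into it.
    await asyncio.sleep(0.3)
    queue = asyncio.Queue()
    streamer[request_id] = queue
    for remain, text in [(2, "Hello "), (1, "world"), (0, "!\n")]:
        await queue.put((remain, text))


async def generate_stream_sketch():
    request_id = str(uuid.uuid4()) + "stream"
    asyncio.create_task(fake_worker(request_id))
    while True:
        # 0.1 s instead of 0 s: yield to the event loop without spinning
        # while the worker has not created the queue yet.
        await asyncio.sleep(0.1)
        cur_streamer = streamer.get(request_id, None)
        if cur_streamer is not None:
            while True:
                remain, text = await cur_streamer.get()
                print(text, end="", flush=True)
                if remain == 0:
                    return


asyncio.run(generate_stream_sketch())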

Changed file 2 of 2

@@ -25,7 +25,7 @@ import torch.distributed as dist
 import os
 import time
 import numpy as np
-from typing import Callable, List, Optional, Union, Tuple
+from typing import Callable, List, Optional, Union, Tuple, Any
 from types import SimpleNamespace
 import transformers
 from transformers import GenerationConfig, LogitsProcessorList, StoppingCriteriaList
@@ -37,6 +37,7 @@ logger = logging.getLogger(__name__)
 import asyncio
 import uuid
 import threading
+import pickle
 try:
     from pydantic import BaseModel
 except ImportError:
@@ -513,6 +514,8 @@ class PPModelWorker:
         self.max_prefilled_seqs = max_prefilled_seqs
         self.partial_output_dict = {}
+
+        self.stream_tasks = {}
 
     def load_model(self, model_path, world_size, low_bit='sym_int4'):
         from ipex_llm.transformers import AutoModelForCausalLM, AutoModel
         try:
@@ -683,16 +686,14 @@ class PPModelWorker:
                 _output = output[0]
                 if _output.dtype != self.dtype:
                     _output = _output.to(self.dtype)
-                return _output, cur_batch
             else:
                 if cur_batch.partial_prefilling > 0 and \
                         cur_batch.prefilled_index == cur_batch.batch_size:
                     _output = self.partial_output_dict.pop(cur_id, None)
                     cur_batch.partial_prefilling = 0
-                    return _output, cur_batch
                 else:
                     _output = torch.argmax(output.logits[:, -1:, :], dim=-1)
-                    return _output, cur_batch
+            return _output, cur_batch
 
     def is_initialized(self):
         return True
@@ -738,14 +739,71 @@ class PPModelWorker:
         self.is_finish.pop(cur_id, None)
         self.partial_output_dict.pop(cur_id, None)
 
+    async def wait_stream_output(self, cur_id):
+        cur_task = self.stream_tasks.pop(cur_id, None)
+        if cur_task is not None:
+            await cur_task
+
+    def get_printable_text(self, cur_text, request_id):
+        if cur_text.endswith("\n"):
+            printable_text = cur_text[self.print_len[request_id]:]
+            self.token_cache[request_id] = []
+            self.print_len[request_id] = 0
+        elif len(cur_text) > 0 and _is_chinese_char(ord(cur_text[-1])):
+            printable_text = cur_text[self.print_len[request_id]:]
+            self.print_len[request_id] += len(printable_text)
+            self.token_cache[request_id] = []
+            self.print_len[request_id] = 0
+        else:
+            r_index = cur_text.rfind(" ") + 1
+            if r_index > self.print_len[request_id]:
+                printable_text = cur_text[self.print_len[request_id]: r_index]
+                self.token_cache[request_id] = self.token_cache[request_id][-1:]
+                self.print_len[request_id] = 0
+            else:
+                printable_text = cur_text[self.print_len[request_id]: r_index]
+        return printable_text
+
+    async def stream_output(self, cur_batch, tokenizer, next_ids):
+        cur_id = cur_batch.batch_id
+        cur_cached_ids = []
+        _stream_tasks = []
+        for index, request_id in enumerate(cur_batch.request_ids):
+            if not self.is_finish.get(request_id, False):
+                if self.token_cache.get(request_id, None) is None:
+                    self.token_cache[request_id] = []
+                    self.print_len[request_id] = 0
+                self.token_cache[request_id].extend(next_ids[index].tolist())
+                cur_cached_ids.append(self.token_cache[request_id])
+
+        for index, request_id in enumerate(cur_batch.request_ids):
+            if not self.is_finish.get(request_id, False):
+                remain = cur_batch.max_tokens - len(self.tokens[cur_id])
+
+                if self.streamer.get(request_id, None) is None:
+                    self.streamer[request_id] = asyncio.Queue()
+
+                # Currently ignore eos for benchmark
+                # if next_ids[index].int() == tokenizer.eos_token_id:
+                #     remain = 0
+                #     self.is_finish[request_id] = True
+
+                cur_text = tokenizer.decode(self.token_cache[request_id])
+                printable_text = self.get_printable_text(cur_text, request_id)
+
+                if remain > 0:
+                    _stream_tasks.append(self.streamer[request_id].put((remain, printable_text)))
+                else:
+                    printable_text = printable_text + cur_text[self.print_len[request_id]:]
+                    self.token_cache.pop(request_id, None)
+                    self.print_len.pop(request_id, None)
+                    _stream_tasks.append(self.streamer[request_id].put((remain, printable_text)))
+
+        await asyncio.gather(*_stream_tasks)
+
     async def process_step(self, tokenizer, result_dict):
         cur_batch = None
-        torch.xpu.synchronize(self.device)
 
         if self.rank == 0:
-            if self.send_buff is not None:
-                # logger.info(f"send {self.rank} {self.send_buff.shape}")
-                dist.send(self.send_buff, dst=self.next_rank)
-
             if self.on_going_batches[0] is not None:
                 cur_batch = self.on_going_batches[0]
                 cur_input = None
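The new get_printable_text/stream_output helpers reproduce the word-boundary buffering that was previously done inline in process_step: decoded text is only emitted up to the last completed line, word, or CJK character, so the stream never exposes half-decoded subword pieces. A standalone sketch of that buffering idea, using plain string pieces in place of tokenizer.decode and dropping the CJK and partial-prefill handling (all names here are illustrative, not from the commit):

# Sketch of the word-boundary buffering behind get_printable_text.
# "decode" is emulated by joining string pieces; names are illustrative.
token_cache = []   # pieces decoded so far for one request
print_len = 0      # number of characters already emitted for that request


def printable_chunk(pieces):
    """Return only the part of the decoded text that is safe to emit now."""
    global token_cache, print_len
    text = "".join(pieces)
    if text.endswith("\n"):
        chunk = text[print_len:]            # full line: flush and reset
        token_cache, print_len = [], 0
    else:
        r_index = text.rfind(" ") + 1       # last completed word boundary
        chunk = text[print_len:r_index]     # hold back the trailing partial word
        print_len += len(chunk)
    return chunk


for piece in ["Hel", "lo ", "wor", "ld\n"]:
    token_cache.append(piece)
    out = printable_chunk(token_cache)
    if out:
        print(repr(out))    # prints 'Hello ' and then 'world\n'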
@@ -773,13 +831,13 @@ class PPModelWorker:
# logger.info(f"recv {self.rank} {next_ids.shape}") # logger.info(f"recv {self.rank} {next_ids.shape}")
dist.recv(next_ids, src=self.pre_rank) dist.recv(next_ids, src=self.pre_rank)
torch.xpu.synchronize(self.device)
if cur_batch.partial_prefilling > 0: if cur_batch.partial_prefilling > 0:
cur_input = self.input_ids_dict[cur_batch.batch_id] cur_input = self.input_ids_dict[cur_batch.batch_id]
else: else:
if self.tokens.get(cur_id, None) is None: if self.tokens.get(cur_id, None) is None:
self.tokens[cur_id] = [] self.tokens[cur_id] = []
if len(next_ids.shape) == 1: if len(next_ids.shape) == 1:
next_ids = next_ids.unsqueeze(0) next_ids = next_ids.unsqueeze(0)
self.tokens[cur_id].append(next_ids) self.tokens[cur_id].append(next_ids)
@@ -788,44 +846,14 @@ class PPModelWorker:
                 cur_batch.input_len = 1
                 cur_batch.prompt_lengths = [x + 1 for x in cur_batch.prompt_lengths]
 
-                for index, request_id in enumerate(cur_batch.request_ids):
-                    if not self.is_finish.get(request_id, False):
-                        remain = cur_batch.max_tokens - len(self.tokens[cur_id])
-
-                        if self.streamer.get(request_id, None) is None:
-                            self.streamer[request_id] = asyncio.Queue()
-
-                        # Currently ignore eos for benchmark
-                        # if next_ids[index].int() == tokenizer.eos_token_id:
-                        #     remain = 0
-                        #     self.is_finish[request_id] = True
-
-                        if self.token_cache.get(request_id, None) is None:
-                            self.token_cache[request_id] = []
-                            self.print_len[request_id] = 0
-                        self.token_cache[request_id].extend(next_ids[index].tolist())
-                        text = tokenizer.decode(self.token_cache[request_id])
-
-                        if text.endswith("\n"):
-                            printable_text = text[self.print_len[request_id]:]
-                            self.token_cache[request_id] = []
-                            self.print_len[request_id] = 0
-                        elif len(text) > 0 and _is_chinese_char(ord(text[-1])):
-                            printable_text = text[self.print_len[request_id]:]
-                            self.print_len[request_id] += len(printable_text)
-                        else:
-                            r_index = text.rfind(" ") + 1
-                            printable_text = text[self.print_len[request_id]: r_index]
-                            self.print_len[request_id] += len(printable_text)
-
-                        if remain > 0:
-                            await self.streamer[request_id].put((remain, printable_text))
-                        else:
-                            printable_text = printable_text + text[self.print_len[request_id]:]
-                            self.token_cache.pop(request_id, None)
-                            self.print_len.pop(request_id, None)
-                            await self.streamer[request_id].put((remain, printable_text))
-
+                pre_task = self.stream_tasks.get(cur_id)
+                if pre_task is not None:
+                    await pre_task
+                    del self.stream_tasks[cur_id]
+                cur_task = asyncio.create_task(
+                    self.stream_output(cur_batch, tokenizer, next_ids)
+                )
+                self.stream_tasks[cur_id] = cur_task
 
             if len(self.tokens[cur_id]) >= cur_batch.max_tokens:
                 # Finish a batch
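Here the per-request detokenization and queue writes are moved off the critical path: each step schedules stream_output as a background asyncio task keyed by batch id, and the next step only waits for the previous task of that same batch before scheduling a new one, so streaming of step N can overlap the model's step N+1. A small self-contained sketch of that overlap pattern with the same bookkeeping (stream_output_sketch, model_step_sketch and the timings are made up for illustration):

import asyncio

stream_tasks = {}   # batch_id -> pending asyncio.Task, as in the diff above


async def stream_output_sketch(batch_id, step, token):
    # Stands in for detokenization plus the per-request queue puts.
    await asyncio.sleep(0.05)
    print(f"batch {batch_id}: streamed step {step} -> {token}")


async def model_step_sketch(step):
    # Stands in for the pipeline forward pass of one decode step.
    await asyncio.sleep(0.1)
    return f"token-{step}"


async def run_batch(batch_id, num_steps=3):
    for step in range(num_steps):
        token = await model_step_sketch(step)
        pre_task = stream_tasks.get(batch_id)
        if pre_task is not None:
            await pre_task              # step N-1 streaming must be done first
            del stream_tasks[batch_id]
        # Schedule streaming of this step without blocking the next forward.
        stream_tasks[batch_id] = asyncio.create_task(
            stream_output_sketch(batch_id, step, token)
        )
    last_task = stream_tasks.pop(batch_id, None)
    if last_task is not None:
        await last_task                 # mirrors wait_stream_output at batch end


asyncio.run(run_batch("batch-0"))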
@@ -841,6 +869,7 @@ class PPModelWorker:
                 next_token = (cur_times[-1] - cur_times[1]) / (len(self.tokens[cur_id]) - 1)
                 logger.info(f"First token latency: {first_token}, "
                             f"next token latency: {next_token}")
+                await self.wait_stream_output(cur_id)
                 self.clear_batch(cur_id)
                 cur_batch.stopped = True
             else:
@@ -850,15 +879,12 @@ class PPModelWorker:
             if cur_batch is not None:
                 cur_batch = self.prepare_batch(cur_batch)
-            dist.broadcast_object_list([cur_batch], src=0)
+                dist.broadcast_object_list([cur_batch], src=0)
+            else:
+                await asyncio.sleep(0)
         else:
-            if self.send_buff is not None:
-                # logger.info(f"send {self.rank} {self.send_buff.shape}")
-                dist.send(self.send_buff, dst=self.next_rank)
-
             batch_list = [None]
             dist.broadcast_object_list(batch_list, src=0)
 
             cur_batch = batch_list[0]
             cur_input = None
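Rank 0 drives the schedule: it decides the batch metadata, and every other rank blocks in broadcast_object_list until that object arrives, which is what keeps all pipeline stages stepping in lockstep. A minimal two-process sketch of that handshake on the Gloo backend (the address, port and batch fields are placeholders for illustration only):

import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank, world_size):
    # Two-process sketch: rank 0 decides the batch object, the other rank
    # receives it through the same collective call.
    dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29512",
                            rank=rank, world_size=world_size)
    if rank == 0:
        cur_batch = {"batch_id": "b0", "max_tokens": 8}   # stands in for cur_batch
        dist.broadcast_object_list([cur_batch], src=0)
    else:
        batch_list = [None]
        dist.broadcast_object_list(batch_list, src=0)
        cur_batch = batch_list[0]
    print(f"rank {rank} sees batch: {cur_batch}")
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)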
@@ -882,10 +908,15 @@ class PPModelWorker:
                 )
                 # logger.info(f"recv {self.rank} {cur_input.shape}")
                 dist.recv(cur_input, src=self.pre_rank)
+                torch.xpu.synchronize(self.device)
 
         output, cur_batch = self.model_step(cur_input, cur_batch)
-        self.send_buff = output
+        torch.xpu.synchronize(self.device)
+        if self.send_buff is not None:
+            self.send_buff.wait()
+        if output is not None:
+            self.send_buff = dist.isend(output, dst=self.next_rank)
 
         if self.rank == 0:
             self.on_going_batches[:-1] = self.on_going_batches[1:]
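With this change each stage's output is sent to the next rank as a non-blocking dist.isend right after the forward pass, and the returned work handle is kept in self.send_buff so the next step only waits on it just before issuing the next send; together with the explicit torch.xpu.synchronize calls this lets communication of step N overlap the Python-side work of step N+1, instead of the old pattern of stashing the tensor and doing a blocking dist.send at the start of the following step. A runnable two-process sketch of the same deferred-send pattern on CPU/Gloo (ports, shapes and loop counts are placeholders; the commit itself applies this to XPU tensors between pipeline stages):

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank, world_size):
    # Deferred non-blocking send: keep the isend work handle and only wait on
    # it right before the next send, so the transfer of step N overlaps the
    # work of step N+1.
    dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29513",
                            rank=rank, world_size=world_size)
    if rank == 0:
        send_buff = None                               # previous isend work handle
        for step in range(3):
            output = torch.full((2, 2), float(step))   # stands in for model_step output
            if send_buff is not None:
                send_buff.wait()                       # step N-1 must be on the wire
            send_buff = dist.isend(output, dst=1)
        send_buff.wait()
    else:
        buf = torch.empty(2, 2)
        for step in range(3):
            dist.recv(buf, src=0)
            print(f"rank 1 received step {step}: value {buf[0, 0].item()}")
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)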