Optimizations for Pipeline Parallel Serving (#11702)
This commit is contained in: parent 8d1e0bd2f4, commit 1baa3efe0e
2 changed files with 86 additions and 55 deletions
@@ -250,7 +250,7 @@ async def generate_stream(inputs_request: InputsRequest):
    request_id = str(uuid.uuid4()) + "stream"
    await local_model.waiting_requests.put((request_id, inputs_request))
    while True:
        await asyncio.sleep(0)
        await asyncio.sleep(0.1)
        cur_streamer = local_model.streamer.get(request_id, None)
        if cur_streamer is not None:
            if inputs_request.req_type == 'completion':

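The hunk above swaps the busy `await asyncio.sleep(0)` poll for a 0.1 s sleep while the handler waits for the worker to register a streamer queue for the request. A minimal sketch of that polling pattern, with hypothetical module-level `waiting_requests` and `streamer` objects standing in for `local_model`:

```python
import asyncio
import uuid

# Hypothetical stand-ins for local_model.waiting_requests / local_model.streamer.
waiting_requests: asyncio.Queue = asyncio.Queue()
streamer: dict = {}

async def generate_stream_sketch(prompt: str):
    request_id = str(uuid.uuid4()) + "stream"
    await waiting_requests.put((request_id, prompt))
    while True:
        # Sleep briefly instead of sleep(0): the handler still yields to the event
        # loop, but no longer spins at full CPU while the worker prepares the queue.
        await asyncio.sleep(0.1)
        cur_streamer = streamer.get(request_id, None)
        if cur_streamer is not None:
            while True:
                remain, printable_text = await cur_streamer.get()
                yield printable_text
                if remain <= 0:
                    return
```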
@@ -25,7 +25,7 @@ import torch.distributed as dist
import os
import time
import numpy as np
from typing import Callable, List, Optional, Union, Tuple
from typing import Callable, List, Optional, Union, Tuple, Any
from types import SimpleNamespace
import transformers
from transformers import GenerationConfig, LogitsProcessorList, StoppingCriteriaList

@@ -37,6 +37,7 @@ logger = logging.getLogger(__name__)
import asyncio
import uuid
import threading
import pickle
try:
    from pydantic import BaseModel
except ImportError:

@@ -513,6 +514,8 @@ class PPModelWorker:
        self.max_prefilled_seqs = max_prefilled_seqs
        self.partial_output_dict = {}

        self.stream_tasks = {}

    def load_model(self, model_path, world_size, low_bit='sym_int4'):
        from ipex_llm.transformers import AutoModelForCausalLM, AutoModel
        try:

@@ -683,16 +686,14 @@ class PPModelWorker:
            _output = output[0]
            if _output.dtype != self.dtype:
                _output = _output.to(self.dtype)
            return _output, cur_batch
        else:
            if cur_batch.partial_prefilling > 0 and \
                    cur_batch.prefilled_index == cur_batch.batch_size:
                _output = self.partial_output_dict.pop(cur_id, None)
                cur_batch.partial_prefilling = 0
                return _output, cur_batch
            else:
                _output = torch.argmax(output.logits[:, -1:, :], dim=-1)
                return _output, cur_batch
            return _output, cur_batch

    def is_initialized(self):
        return True

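For reference, the `torch.argmax(output.logits[:, -1:, :], dim=-1)` line above is plain greedy decoding: it picks the highest-scoring token at the last position of each sequence. A self-contained illustration with dummy logits:

```python
import torch

logits = torch.randn(2, 5, 32000)                     # (batch, seq_len, vocab_size) dummy logits
next_ids = torch.argmax(logits[:, -1:, :], dim=-1)    # greedy token at the last position
print(next_ids.shape)                                 # torch.Size([2, 1]): one new token per sequence
```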
@@ -738,14 +739,71 @@ class PPModelWorker:
        self.is_finish.pop(cur_id, None)
        self.partial_output_dict.pop(cur_id, None)

    async def wait_stream_output(self, cur_id):
        cur_task = self.stream_tasks.pop(cur_id, None)
        if cur_task is not None:
            await cur_task

    def get_printable_text(self, cur_text, request_id):
        if cur_text.endswith("\n"):
            printable_text = cur_text[self.print_len[request_id]:]
            self.token_cache[request_id] = []
            self.print_len[request_id] = 0
        elif len(cur_text) > 0 and _is_chinese_char(ord(cur_text[-1])):
            printable_text = cur_text[self.print_len[request_id]:]
            self.print_len[request_id] += len(printable_text)
            self.token_cache[request_id] = []
            self.print_len[request_id] = 0
        else:
            r_index = cur_text.rfind(" ") + 1
            if r_index > self.print_len[request_id]:
                printable_text = cur_text[self.print_len[request_id]: r_index]
                self.token_cache[request_id] = self.token_cache[request_id][-1:]
                self.print_len[request_id] = 0
            else:
                printable_text = cur_text[self.print_len[request_id]: r_index]
        return printable_text

    async def stream_output(self, cur_batch, tokenizer, next_ids):
        cur_id = cur_batch.batch_id
        cur_cached_ids = []
        _stream_tasks = []
        for index, request_id in enumerate(cur_batch.request_ids):
            if not self.is_finish.get(request_id, False):
                if self.token_cache.get(request_id, None) is None:
                    self.token_cache[request_id] = []
                    self.print_len[request_id] = 0
                self.token_cache[request_id].extend(next_ids[index].tolist())
                cur_cached_ids.append(self.token_cache[request_id])

        for index, request_id in enumerate(cur_batch.request_ids):
            if not self.is_finish.get(request_id, False):
                remain = cur_batch.max_tokens - len(self.tokens[cur_id])

                if self.streamer.get(request_id, None) is None:
                    self.streamer[request_id] = asyncio.Queue()

                # Currently ignore eos for benchmark
                # if next_ids[index].int() == tokenizer.eos_token_id:
                # remain = 0
                # self.is_finish[request_id] = True

                cur_text = tokenizer.decode(self.token_cache[request_id])
                printable_text = self.get_printable_text(cur_text, request_id)

                if remain > 0:
                    _stream_tasks.append(self.streamer[request_id].put((remain, printable_text)))
                else:
                    printable_text = printable_text + cur_text[self.print_len[request_id]:]
                    self.token_cache.pop(request_id, None)
                    self.print_len.pop(request_id, None)
                    _stream_tasks.append(self.streamer[request_id].put((remain, printable_text)))
        await asyncio.gather(*_stream_tasks)

    async def process_step(self, tokenizer, result_dict):
        cur_batch = None

        torch.xpu.synchronize(self.device)
        if self.rank == 0:
            if self.send_buff is not None:
                # logger.info(f"send {self.rank} {self.send_buff.shape}")
                dist.send(self.send_buff, dst=self.next_rank)

            if self.on_going_batches[0] is not None:
                cur_batch = self.on_going_batches[0]
                cur_input = None

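`get_printable_text` implements incremental detokenization: tokens are cached per request, decoded as a whole string, and only text up to the last completed word (or everything on a newline or CJK character) is emitted, so clients never see half-decoded words. A simplified, string-only sketch of the same rule (it drops the token-cache trimming and the CJK branch, and is not the new helper itself):

```python
def printable_delta(cur_text: str, print_len: int) -> tuple:
    """Return (text safe to emit now, updated print_len) for one decode step."""
    if cur_text.endswith("\n"):
        return cur_text[print_len:], len(cur_text)     # line finished: flush everything
    r_index = cur_text.rfind(" ") + 1                  # last space marks the last completed word
    if r_index > print_len:
        return cur_text[print_len:r_index], r_index    # emit up to the word boundary
    return "", print_len                               # nothing safe to emit yet

print(printable_delta("Hello wor", 0))        # ('Hello ', 6) -- "wor" is held back
print(printable_delta("Hello world\n", 6))    # ('world\n', 12)
```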
@@ -773,13 +831,13 @@ class PPModelWorker:
                # logger.info(f"recv {self.rank} {next_ids.shape}")
                dist.recv(next_ids, src=self.pre_rank)
                torch.xpu.synchronize(self.device)

                if cur_batch.partial_prefilling > 0:
                    cur_input = self.input_ids_dict[cur_batch.batch_id]
                else:
                    if self.tokens.get(cur_id, None) is None:
                        self.tokens[cur_id] = []

                    if len(next_ids.shape) == 1:
                        next_ids = next_ids.unsqueeze(0)
                    self.tokens[cur_id].append(next_ids)

@@ -788,44 +846,14 @@ class PPModelWorker:
                cur_batch.input_len = 1
                cur_batch.prompt_lengths = [x + 1 for x in cur_batch.prompt_lengths]

                for index, request_id in enumerate(cur_batch.request_ids):

                    if not self.is_finish.get(request_id, False):
                        remain = cur_batch.max_tokens - len(self.tokens[cur_id])

                        if self.streamer.get(request_id, None) is None:
                            self.streamer[request_id] = asyncio.Queue()

                        # Currently ignore eos for benchmark
                        # if next_ids[index].int() == tokenizer.eos_token_id:
                        # remain = 0
                        # self.is_finish[request_id] = True

                        if self.token_cache.get(request_id, None) is None:
                            self.token_cache[request_id] = []
                            self.print_len[request_id] = 0
                        self.token_cache[request_id].extend(next_ids[index].tolist())

                        text = tokenizer.decode(self.token_cache[request_id])
                        if text.endswith("\n"):
                            printable_text = text[self.print_len[request_id]:]
                            self.token_cache[request_id] = []
                            self.print_len[request_id] = 0
                        elif len(text) > 0 and _is_chinese_char(ord(text[-1])):
                            printable_text = text[self.print_len[request_id]:]
                            self.print_len[request_id] += len(printable_text)
                        else:
                            r_index = text.rfind(" ") + 1
                            printable_text = text[self.print_len[request_id]: r_index]
                            self.print_len[request_id] += len(printable_text)

                        if remain > 0:
                            await self.streamer[request_id].put((remain, printable_text))
                        else:
                            printable_text = printable_text + text[self.print_len[request_id]:]
                            self.token_cache.pop(request_id, None)
                            self.print_len.pop(request_id, None)
                            await self.streamer[request_id].put((remain, printable_text))
                pre_task = self.stream_tasks.get(cur_id)
                if pre_task is not None:
                    await pre_task
                    del self.stream_tasks[cur_id]
                cur_task = asyncio.create_task(
                    self.stream_output(cur_batch, tokenizer, next_ids)
                )
                self.stream_tasks[cur_id] = cur_task

                if len(self.tokens[cur_id]) >= cur_batch.max_tokens:
                    # Finish a batch

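The per-request streaming work that used to run inline in `process_step` (the large removed block above) is now wrapped in `stream_output` and scheduled with `asyncio.create_task`, awaiting the previous step's task first so queue order is preserved. A minimal sketch of that handoff pattern, using a hypothetical `stream_tasks` dict keyed by batch id:

```python
import asyncio

stream_tasks: dict = {}   # hypothetical: one in-flight streaming task per batch id

async def hand_off_streaming(cur_id, coro):
    pre_task = stream_tasks.get(cur_id)
    if pre_task is not None:
        await pre_task              # finish the previous step's detokenization/queue puts first
        del stream_tasks[cur_id]
    # Schedule this step's streaming in the background so the next pipeline step
    # (model forward plus dist send/recv) does not wait on tokenizer.decode and queue puts.
    stream_tasks[cur_id] = asyncio.create_task(coro)
```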
@@ -841,6 +869,7 @@ class PPModelWorker:
                    next_token = (cur_times[-1] - cur_times[1]) / (len(self.tokens[cur_id]) - 1)
                    logger.info(f"First token latency: {first_token}, "
                                f"next token latency: {next_token}")
                    await self.wait_stream_output(cur_id)
                    self.clear_batch(cur_id)
                    cur_batch.stopped = True
                else:

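A small worked example of the latency arithmetic logged above, with made-up timestamps. It assumes `cur_times` holds a start timestamp plus one timestamp per generated token, and that `first_token` (computed just above this hunk) is the gap between the start and the first decoded token:

```python
cur_times = [0.0, 0.5, 0.75, 1.0, 1.25]    # hypothetical per-step timestamps (seconds)
num_tokens = len(cur_times) - 1            # 4 generated tokens in this example
first_token = cur_times[1] - cur_times[0]                        # 0.5 s to the first token
next_token = (cur_times[-1] - cur_times[1]) / (num_tokens - 1)   # 0.75 / 3 = 0.25 s per token
print(first_token, next_token)             # 0.5 0.25
```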
@@ -850,15 +879,12 @@ class PPModelWorker:
            if cur_batch is not None:
                cur_batch = self.prepare_batch(cur_batch)
                dist.broadcast_object_list([cur_batch], src=0)
            else:
                await asyncio.sleep(0)

        else:
            if self.send_buff is not None:
                # logger.info(f"send {self.rank} {self.send_buff.shape}")
                dist.send(self.send_buff, dst=self.next_rank)

            batch_list = [None]
            dist.broadcast_object_list(batch_list, src=0)

            cur_batch = batch_list[0]
            cur_input = None

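`dist.broadcast_object_list` is what keeps the non-zero ranks in lockstep here: rank 0 broadcasts the prepared batch object (or `None` when there is nothing to schedule) and every other rank receives it into a one-element list. A short sketch of that handshake, assuming the process group has already been initialized:

```python
import torch.distributed as dist

def sync_current_batch(rank: int, cur_batch=None):
    # Rank 0 supplies the object; the other ranks pass a placeholder to be filled in.
    batch_list = [cur_batch] if rank == 0 else [None]
    dist.broadcast_object_list(batch_list, src=0)   # pickled on src, unpickled on every other rank
    return batch_list[0]
```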
@@ -882,10 +908,15 @@ class PPModelWorker:
                )
                # logger.info(f"recv {self.rank} {cur_input.shape}")
                dist.recv(cur_input, src=self.pre_rank)
                torch.xpu.synchronize(self.device)

        output, cur_batch = self.model_step(cur_input, cur_batch)

        self.send_buff = output
        torch.xpu.synchronize(self.device)
        if self.send_buff is not None:
            self.send_buff.wait()
        if output is not None:
            self.send_buff = dist.isend(output, dst=self.next_rank)

        if self.rank == 0:
            self.on_going_batches[:-1] = self.on_going_batches[1:]

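The last hunk replaces the blocking send of the pipeline activation with a non-blocking `dist.isend`, keeping the returned work handle in `self.send_buff` and calling `.wait()` on it before issuing the next send, so a buffer is never reused while still in flight. A runnable two-process sketch of that pattern on the gloo backend (hypothetical tensor shape and port, not the repository's code):

```python
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def worker(rank: int, world_size: int = 2):
    dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29501",
                            rank=rank, world_size=world_size)
    send_work = None
    for step in range(3):
        if rank == 0:
            if send_work is not None:
                send_work.wait()                      # previous non-blocking send must finish first
            payload = torch.full((4,), float(step))
            send_work = dist.isend(payload, dst=1)    # returns immediately with a work handle
        else:
            buf = torch.empty(4)
            dist.recv(buf, src=0)
            print(f"step {step}: received {buf.tolist()}")
    if send_work is not None:
        send_work.wait()
    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)
```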