Optimizations for Pipeline Parallel Serving (#11702)
parent 8d1e0bd2f4
commit 1baa3efe0e
2 changed files with 86 additions and 55 deletions
@@ -250,7 +250,7 @@ async def generate_stream(inputs_request: InputsRequest):
     request_id = str(uuid.uuid4()) + "stream"
     await local_model.waiting_requests.put((request_id, inputs_request))
     while True:
-        await asyncio.sleep(0)
+        await asyncio.sleep(0.1)
         cur_streamer = local_model.streamer.get(request_id, None)
         if cur_streamer is not None:
             if inputs_request.req_type == 'completion':
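Note: the only functional change in this hunk is the poll interval. `asyncio.sleep(0)` merely yields to the event loop, so the handler spins while waiting for the worker to register a streamer; `asyncio.sleep(0.1)` backs off for 100 ms between checks. A minimal sketch of such a polling consumer (illustrative only; `local_model.streamer` and the `(remain, printable_text)` queue items follow the worker-side code in this commit, everything else is assumed):

import asyncio

async def drain_streamer(local_model, request_id):
    # Poll until the worker has registered an asyncio.Queue for this request.
    while True:
        await asyncio.sleep(0.1)  # back off instead of busy-spinning the event loop
        cur_streamer = local_model.streamer.get(request_id, None)
        if cur_streamer is not None:
            break
    # Queue items are (remain, printable_text); remain <= 0 marks the final chunk.
    while True:
        remain, text = await cur_streamer.get()
        yield text
        if remain <= 0:
            return

In the real endpoint an async generator like this would feed the streaming HTTP response.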
@@ -25,7 +25,7 @@ import torch.distributed as dist
 import os
 import time
 import numpy as np
-from typing import Callable, List, Optional, Union, Tuple
+from typing import Callable, List, Optional, Union, Tuple, Any
 from types import SimpleNamespace
 import transformers
 from transformers import GenerationConfig, LogitsProcessorList, StoppingCriteriaList
@@ -37,6 +37,7 @@ logger = logging.getLogger(__name__)
 import asyncio
 import uuid
 import threading
+import pickle
 try:
     from pydantic import BaseModel
 except ImportError:
@@ -513,6 +514,8 @@ class PPModelWorker:
         self.max_prefilled_seqs = max_prefilled_seqs
         self.partial_output_dict = {}
+
+        self.stream_tasks = {}

     def load_model(self, model_path, world_size, low_bit='sym_int4'):
         from ipex_llm.transformers import AutoModelForCausalLM, AutoModel
         try:
@@ -683,16 +686,14 @@ class PPModelWorker:
             _output = output[0]
             if _output.dtype != self.dtype:
                 _output = _output.to(self.dtype)
-            return _output, cur_batch
         else:
             if cur_batch.partial_prefilling > 0 and \
                     cur_batch.prefilled_index == cur_batch.batch_size:
                 _output = self.partial_output_dict.pop(cur_id, None)
                 cur_batch.partial_prefilling = 0
-                return _output, cur_batch
             else:
                 _output = torch.argmax(output.logits[:, -1:, :], dim=-1)
         return _output, cur_batch

     def is_initialized(self):
         return True
@@ -738,14 +739,71 @@ class PPModelWorker:
         self.is_finish.pop(cur_id, None)
         self.partial_output_dict.pop(cur_id, None)

+    async def wait_stream_output(self, cur_id):
+        cur_task = self.stream_tasks.pop(cur_id, None)
+        if cur_task is not None:
+            await cur_task
+
+    def get_printable_text(self, cur_text, request_id):
+        if cur_text.endswith("\n"):
+            printable_text = cur_text[self.print_len[request_id]:]
+            self.token_cache[request_id] = []
+            self.print_len[request_id] = 0
+        elif len(cur_text) > 0 and _is_chinese_char(ord(cur_text[-1])):
+            printable_text = cur_text[self.print_len[request_id]:]
+            self.print_len[request_id] += len(printable_text)
+            self.token_cache[request_id] = []
+            self.print_len[request_id] = 0
+        else:
+            r_index = cur_text.rfind(" ") + 1
+            if r_index > self.print_len[request_id]:
+                printable_text = cur_text[self.print_len[request_id]: r_index]
+                self.token_cache[request_id] = self.token_cache[request_id][-1:]
+                self.print_len[request_id] = 0
+            else:
+                printable_text = cur_text[self.print_len[request_id]: r_index]
+        return printable_text
+
+    async def stream_output(self, cur_batch, tokenizer, next_ids):
+        cur_id = cur_batch.batch_id
+        cur_cached_ids = []
+        _stream_tasks = []
+        for index, request_id in enumerate(cur_batch.request_ids):
+            if not self.is_finish.get(request_id, False):
+                if self.token_cache.get(request_id, None) is None:
+                    self.token_cache[request_id] = []
+                    self.print_len[request_id] = 0
+                self.token_cache[request_id].extend(next_ids[index].tolist())
+                cur_cached_ids.append(self.token_cache[request_id])
+
+        for index, request_id in enumerate(cur_batch.request_ids):
+            if not self.is_finish.get(request_id, False):
+                remain = cur_batch.max_tokens - len(self.tokens[cur_id])
+
+                if self.streamer.get(request_id, None) is None:
+                    self.streamer[request_id] = asyncio.Queue()
+
+                # Currently ignore eos for benchmark
+                # if next_ids[index].int() == tokenizer.eos_token_id:
+                #     remain = 0
+                #     self.is_finish[request_id] = True
+
+                cur_text = tokenizer.decode(self.token_cache[request_id])
+                printable_text = self.get_printable_text(cur_text, request_id)
+
+                if remain > 0:
+                    _stream_tasks.append(self.streamer[request_id].put((remain, printable_text)))
+                else:
+                    printable_text = printable_text + cur_text[self.print_len[request_id]:]
+                    self.token_cache.pop(request_id, None)
+                    self.print_len.pop(request_id, None)
+                    _stream_tasks.append(self.streamer[request_id].put((remain, printable_text)))
+        await asyncio.gather(*_stream_tasks)
+
     async def process_step(self, tokenizer, result_dict):
         cur_batch = None
+        torch.xpu.synchronize(self.device)
         if self.rank == 0:
-            if self.send_buff is not None:
-                # logger.info(f"send {self.rank} {self.send_buff.shape}")
-                dist.send(self.send_buff, dst=self.next_rank)
-
             if self.on_going_batches[0] is not None:
                 cur_batch = self.on_going_batches[0]
                 cur_input = None
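Note: `get_printable_text` factors out the incremental detokenization that used to be inlined in `process_step`. The whole token cache is re-decoded each step and `print_len` records how much of the decoded string has already been emitted, flushing at newlines, CJK characters, and word boundaries. A simplified standalone toy of the `print_len` bookkeeping (hypothetical helper, no real tokenizer; the CJK and cache-trimming branches are omitted):

# Toy illustration of the print_len bookkeeping; the real code re-decodes
# self.token_cache[request_id] with the HF tokenizer on every step.
print_len = 0

def printable_part(decoded_so_far: str) -> str:
    """Return only the newly 'safe to print' suffix of the decoded text."""
    global print_len
    if decoded_so_far.endswith("\n"):
        out = decoded_so_far[print_len:]
        print_len = 0  # the token cache would be cleared here as well
    else:
        r_index = decoded_so_far.rfind(" ") + 1  # stop at the last word boundary
        out = decoded_so_far[print_len:r_index]
        print_len += len(out)
    return out

# Prints '' (word not finished yet), then 'Hello ', then 'world\n'.
for step in ["Hello", "Hello wor", "Hello world\n"]:
    print(repr(printable_part(step)))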
@@ -773,13 +831,13 @@ class PPModelWorker:

                 # logger.info(f"recv {self.rank} {next_ids.shape}")
                 dist.recv(next_ids, src=self.pre_rank)
+                torch.xpu.synchronize(self.device)

                 if cur_batch.partial_prefilling > 0:
                     cur_input = self.input_ids_dict[cur_batch.batch_id]
                 else:
                     if self.tokens.get(cur_id, None) is None:
                         self.tokens[cur_id] = []

                     if len(next_ids.shape) == 1:
                         next_ids = next_ids.unsqueeze(0)
                     self.tokens[cur_id].append(next_ids)
@@ -788,44 +846,14 @@ class PPModelWorker:
                     cur_batch.input_len = 1
                     cur_batch.prompt_lengths = [x + 1 for x in cur_batch.prompt_lengths]

-                    for index, request_id in enumerate(cur_batch.request_ids):
-                        if not self.is_finish.get(request_id, False):
-                            remain = cur_batch.max_tokens - len(self.tokens[cur_id])
-
-                            if self.streamer.get(request_id, None) is None:
-                                self.streamer[request_id] = asyncio.Queue()
-
-                            # Currently ignore eos for benchmark
-                            # if next_ids[index].int() == tokenizer.eos_token_id:
-                            #     remain = 0
-                            #     self.is_finish[request_id] = True
-
-                            if self.token_cache.get(request_id, None) is None:
-                                self.token_cache[request_id] = []
-                                self.print_len[request_id] = 0
-                            self.token_cache[request_id].extend(next_ids[index].tolist())
-
-                            text = tokenizer.decode(self.token_cache[request_id])
-                            if text.endswith("\n"):
-                                printable_text = text[self.print_len[request_id]:]
-                                self.token_cache[request_id] = []
-                                self.print_len[request_id] = 0
-                            elif len(text) > 0 and _is_chinese_char(ord(text[-1])):
-                                printable_text = text[self.print_len[request_id]:]
-                                self.print_len[request_id] += len(printable_text)
-                            else:
-                                r_index = text.rfind(" ") + 1
-                                printable_text = text[self.print_len[request_id]: r_index]
-                                self.print_len[request_id] += len(printable_text)
-
-                            if remain > 0:
-                                await self.streamer[request_id].put((remain, printable_text))
-                            else:
-                                printable_text = printable_text + text[self.print_len[request_id]:]
-                                self.token_cache.pop(request_id, None)
-                                self.print_len.pop(request_id, None)
-                                await self.streamer[request_id].put((remain, printable_text))
-
+                    pre_task = self.stream_tasks.get(cur_id)
+                    if pre_task is not None:
+                        await pre_task
+                        del self.stream_tasks[cur_id]
+                    cur_task = asyncio.create_task(
+                        self.stream_output(cur_batch, tokenizer, next_ids)
+                    )
+                    self.stream_tasks[cur_id] = cur_task

                     if len(self.tokens[cur_id]) >= cur_batch.max_tokens:
                         # Finish a batch
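Note: the per-request decode-and-enqueue work is no longer awaited inline; it is wrapped in `stream_output` and launched with `asyncio.create_task`, and the previous task for the same batch is awaited first so that `token_cache`/`print_len` are never touched by two tasks at once. A minimal sketch of that overlap pattern, with hypothetical `compute_step`/`stream_step` stand-ins for `model_step` and `stream_output`:

import asyncio

stream_tasks = {}

async def stream_step(batch_id, next_ids):
    # Stand-in for tokenizer.decode + streamer.put in the real stream_output().
    await asyncio.sleep(0.01)

async def compute_step(batch_id):
    # Stand-in for the next model_step() of the pipeline.
    await asyncio.sleep(0.01)

async def pipeline_loop(batch_id, steps):
    for next_ids in steps:
        prev = stream_tasks.get(batch_id)
        if prev is not None:
            await prev  # flush streaming of the previous step before reusing state
            del stream_tasks[batch_id]
        # Hand the new tokens to a background task and keep computing; the next
        # pipeline step overlaps with detokenization and queue puts.
        stream_tasks[batch_id] = asyncio.create_task(stream_step(batch_id, next_ids))
        await compute_step(batch_id)

asyncio.run(pipeline_loop("batch-0", [[1], [2], [3]]))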
@@ -841,6 +869,7 @@ class PPModelWorker:
                         next_token = (cur_times[-1] - cur_times[1]) / (len(self.tokens[cur_id]) - 1)
                         logger.info(f"First token latency: {first_token}, "
                                     f"next token latency: {next_token}")
+                        await self.wait_stream_output(cur_id)
                         self.clear_batch(cur_id)
                         cur_batch.stopped = True
                     else:
@@ -850,15 +879,12 @@ class PPModelWorker:
             if cur_batch is not None:
                 cur_batch = self.prepare_batch(cur_batch)
                 dist.broadcast_object_list([cur_batch], src=0)
+            else:
+                await asyncio.sleep(0)

         else:
-            if self.send_buff is not None:
-                # logger.info(f"send {self.rank} {self.send_buff.shape}")
-                dist.send(self.send_buff, dst=self.next_rank)

             batch_list = [None]
             dist.broadcast_object_list(batch_list, src=0)

             cur_batch = batch_list[0]
             cur_input = None
@@ -882,10 +908,15 @@ class PPModelWorker:
                 )
                 # logger.info(f"recv {self.rank} {cur_input.shape}")
                 dist.recv(cur_input, src=self.pre_rank)
+                torch.xpu.synchronize(self.device)

         output, cur_batch = self.model_step(cur_input, cur_batch)

-        self.send_buff = output
+        torch.xpu.synchronize(self.device)
+        if self.send_buff is not None:
+            self.send_buff.wait()
+        if output is not None:
+            self.send_buff = dist.isend(output, dst=self.next_rank)

         if self.rank == 0:
             self.on_going_batches[:-1] = self.on_going_batches[1:]
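Note: the blocking `dist.send` that used to sit at the top of `process_step` is replaced by a non-blocking `dist.isend` issued right after `model_step`; the previous request is waited on only just before a new one is posted, and the added `torch.xpu.synchronize` calls ensure the kernels producing or consuming the exchanged tensors have finished before the P2P ops run. A minimal two-process sketch of the isend/wait-before-reuse pattern (gloo backend and CPU tensors here purely to keep it self-contained; the serving code uses XPU tensors):

import os
import torch
import torch.distributed as dist

def run(rank: int, world_size: int = 2):
    # Rank 0 keeps posting isend requests and only waits on the previous one
    # right before replacing it, mirroring the new process_step behaviour.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    send_req = None
    for step in range(4):
        if rank == 0:
            output = torch.full((2, 2), float(step))
            if send_req is not None:
                send_req.wait()                   # previous transfer finished
            send_req = dist.isend(output, dst=1)  # overlaps with the next step
        else:
            buf = torch.empty(2, 2)
            dist.recv(buf, src=0)
    if send_req is not None:
        send_req.wait()
    dist.destroy_process_group()

For experimentation this could be launched with, e.g., `torch.multiprocessing.spawn(run, args=(2,), nprocs=2)`.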