Optimizations for Pipeline Parallel Serving (#11702)

parent 8d1e0bd2f4
commit 1baa3efe0e

2 changed files with 86 additions and 55 deletions
@@ -250,7 +250,7 @@ async def generate_stream(inputs_request: InputsRequest):
    request_id = str(uuid.uuid4()) + "stream"
    await local_model.waiting_requests.put((request_id, inputs_request))
    while True:
        await asyncio.sleep(0)
        await asyncio.sleep(0.1)
        cur_streamer = local_model.streamer.get(request_id, None)
        if cur_streamer is not None:
            if inputs_request.req_type == 'completion':
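Note: the change above replaces the zero-delay spin in `generate_stream` with a 100 ms poll. Below is a minimal, hypothetical sketch of that polling pattern, assuming a `local_model` that exposes `waiting_requests` (an asyncio.Queue of pending requests) and `streamer` (a dict of per-request asyncio.Queues of (remaining_tokens, text) chunks); it is not the repository's exact handler.

import asyncio
import uuid

async def stream_tokens(local_model, inputs_request):
    request_id = str(uuid.uuid4()) + "stream"
    await local_model.waiting_requests.put((request_id, inputs_request))
    while True:
        # Poll every 100 ms instead of spinning with sleep(0), so the event
        # loop is not saturated while the request waits to be scheduled.
        await asyncio.sleep(0.1)
        cur_streamer = local_model.streamer.get(request_id, None)
        if cur_streamer is None:
            continue
        while True:
            remain, text = await cur_streamer.get()
            if text:
                yield text
            if remain <= 0:
                # The worker pushes the final chunk with remain == 0.
                local_model.streamer.pop(request_id, None)
                return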
@@ -25,7 +25,7 @@ import torch.distributed as dist
import os
import time
import numpy as np
from typing import Callable, List, Optional, Union, Tuple
from typing import Callable, List, Optional, Union, Tuple, Any
from types import SimpleNamespace
import transformers
from transformers import GenerationConfig, LogitsProcessorList, StoppingCriteriaList
@@ -37,6 +37,7 @@ logger = logging.getLogger(__name__)
import asyncio
import uuid
import threading
import pickle
try:
    from pydantic import BaseModel
except ImportError:
@@ -513,6 +514,8 @@ class PPModelWorker:
        self.max_prefilled_seqs = max_prefilled_seqs
        self.partial_output_dict = {}

        self.stream_tasks = {}

    def load_model(self, model_path, world_size, low_bit='sym_int4'):
        from ipex_llm.transformers import AutoModelForCausalLM, AutoModel
        try:
@@ -683,13 +686,11 @@ class PPModelWorker:
            _output = output[0]
            if _output.dtype != self.dtype:
                _output = _output.to(self.dtype)
            return _output, cur_batch
        else:
            if cur_batch.partial_prefilling > 0 and \
               cur_batch.prefilled_index == cur_batch.batch_size:
                _output = self.partial_output_dict.pop(cur_id, None)
                cur_batch.partial_prefilling = 0
                return _output, cur_batch
            else:
                _output = torch.argmax(output.logits[:, -1:, :], dim=-1)
        return _output, cur_batch
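Note: in `model_step` above, outputs from partially prefilled sub-batches are held in `partial_output_dict` and only popped once `prefilled_index` reaches `batch_size`. The sketch below is a simplified, hypothetical illustration of that accumulation step, not the repository's exact code.

import torch

def accumulate_partial_output(self, cur_id, chunk_logits):
    # Greedy next-token ids for the sub-batch that was just prefilled.
    next_ids = torch.argmax(chunk_logits[:, -1:, :], dim=-1)
    prev = self.partial_output_dict.get(cur_id, None)
    # Stack sub-batch results along the batch dimension until the whole
    # batch has been prefilled; model_step then pops the combined tensor.
    self.partial_output_dict[cur_id] = (
        next_ids if prev is None else torch.cat([prev, next_ids], dim=0)
    )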
@@ -738,14 +739,71 @@ class PPModelWorker:
        self.is_finish.pop(cur_id, None)
        self.partial_output_dict.pop(cur_id, None)

    async def wait_stream_output(self, cur_id):
        cur_task = self.stream_tasks.pop(cur_id, None)
        if cur_task is not None:
            await cur_task

    def get_printable_text(self, cur_text, request_id):
        if cur_text.endswith("\n"):
            printable_text = cur_text[self.print_len[request_id]:]
            self.token_cache[request_id] = []
            self.print_len[request_id] = 0
        elif len(cur_text) > 0 and _is_chinese_char(ord(cur_text[-1])):
            printable_text = cur_text[self.print_len[request_id]:]
            self.print_len[request_id] += len(printable_text)
            self.token_cache[request_id] = []
            self.print_len[request_id] = 0
        else:
            r_index = cur_text.rfind(" ") + 1
            if r_index > self.print_len[request_id]:
                printable_text = cur_text[self.print_len[request_id]: r_index]
                self.token_cache[request_id] = self.token_cache[request_id][-1:]
                self.print_len[request_id] = 0
            else:
                printable_text = cur_text[self.print_len[request_id]: r_index]
        return printable_text

    async def stream_output(self, cur_batch, tokenizer, next_ids):
        cur_id = cur_batch.batch_id
        cur_cached_ids = []
        _stream_tasks = []
        for index, request_id in enumerate(cur_batch.request_ids):
            if not self.is_finish.get(request_id, False):
                if self.token_cache.get(request_id, None) is None:
                    self.token_cache[request_id] = []
                    self.print_len[request_id] = 0
                self.token_cache[request_id].extend(next_ids[index].tolist())
                cur_cached_ids.append(self.token_cache[request_id])

        for index, request_id in enumerate(cur_batch.request_ids):
            if not self.is_finish.get(request_id, False):
                remain = cur_batch.max_tokens - len(self.tokens[cur_id])

                if self.streamer.get(request_id, None) is None:
                    self.streamer[request_id] = asyncio.Queue()

                # Currently ignore eos for benchmark
                # if next_ids[index].int() == tokenizer.eos_token_id:
                #     remain = 0
                #     self.is_finish[request_id] = True

                cur_text = tokenizer.decode(self.token_cache[request_id])
                printable_text = self.get_printable_text(cur_text, request_id)

                if remain > 0:
                    _stream_tasks.append(self.streamer[request_id].put((remain, printable_text)))
                else:
                    printable_text = printable_text + cur_text[self.print_len[request_id]:]
                    self.token_cache.pop(request_id, None)
                    self.print_len.pop(request_id, None)
                    _stream_tasks.append(self.streamer[request_id].put((remain, printable_text)))
        await asyncio.gather(*_stream_tasks)

    async def process_step(self, tokenizer, result_dict):
        cur_batch = None

        torch.xpu.synchronize(self.device)
        if self.rank == 0:
            if self.send_buff is not None:
                # logger.info(f"send {self.rank} {self.send_buff.shape}")
                dist.send(self.send_buff, dst=self.next_rank)

            if self.on_going_batches[0] is not None:
                cur_batch = self.on_going_batches[0]
                cur_input = None
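Note: `get_printable_text` factors the incremental detokenization logic (previously inlined in `process_step`) into one helper: token ids are cached per request, the cache is re-decoded each step, and only text up to a stable boundary is emitted. A simplified, standalone sketch of that idea (omitting the CJK branch and cache trimming) is shown below; it is illustrative, not the repository's code.

def printable_suffix(cur_text: str, print_len: int) -> tuple[str, int]:
    if cur_text.endswith("\n"):
        # A newline closes the pending text; everything after print_len is safe.
        return cur_text[print_len:], 0
    r_index = cur_text.rfind(" ") + 1
    if r_index > print_len:
        # Emit up to the last complete word; hold the tail for the next step.
        return cur_text[print_len:r_index], r_index
    return "", print_len

# Example: nothing is emitted until a word boundary appears.
text, new_len = printable_suffix("Hello wor", 0)      # -> ("Hello ", 6)
text, new_len = printable_suffix("Hello world!", 6)   # -> ("", 6): "world!" is held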
@@ -773,13 +831,13 @@ class PPModelWorker:

                # logger.info(f"recv {self.rank} {next_ids.shape}")
                dist.recv(next_ids, src=self.pre_rank)
                torch.xpu.synchronize(self.device)

                if cur_batch.partial_prefilling > 0:
                    cur_input = self.input_ids_dict[cur_batch.batch_id]
                else:
                    if self.tokens.get(cur_id, None) is None:
                        self.tokens[cur_id] = []

                    if len(next_ids.shape) == 1:
                        next_ids = next_ids.unsqueeze(0)
                    self.tokens[cur_id].append(next_ids)
@@ -788,44 +846,14 @@ class PPModelWorker:
                    cur_batch.input_len = 1
                    cur_batch.prompt_lengths = [x + 1 for x in cur_batch.prompt_lengths]

                    for index, request_id in enumerate(cur_batch.request_ids):

                        if not self.is_finish.get(request_id, False):
                            remain = cur_batch.max_tokens - len(self.tokens[cur_id])

                            if self.streamer.get(request_id, None) is None:
                                self.streamer[request_id] = asyncio.Queue()

                            # Currently ignore eos for benchmark
                            # if next_ids[index].int() == tokenizer.eos_token_id:
                            #     remain = 0
                            #     self.is_finish[request_id] = True

                            if self.token_cache.get(request_id, None) is None:
                                self.token_cache[request_id] = []
                                self.print_len[request_id] = 0
                            self.token_cache[request_id].extend(next_ids[index].tolist())

                            text = tokenizer.decode(self.token_cache[request_id])
                            if text.endswith("\n"):
                                printable_text = text[self.print_len[request_id]:]
                                self.token_cache[request_id] = []
                                self.print_len[request_id] = 0
                            elif len(text) > 0 and _is_chinese_char(ord(text[-1])):
                                printable_text = text[self.print_len[request_id]:]
                                self.print_len[request_id] += len(printable_text)
                            else:
                                r_index = text.rfind(" ") + 1
                                printable_text = text[self.print_len[request_id]: r_index]
                                self.print_len[request_id] += len(printable_text)

                            if remain > 0:
                                await self.streamer[request_id].put((remain, printable_text))
                            else:
                                printable_text = printable_text + text[self.print_len[request_id]:]
                                self.token_cache.pop(request_id, None)
                                self.print_len.pop(request_id, None)
                                await self.streamer[request_id].put((remain, printable_text))
                    pre_task = self.stream_tasks.get(cur_id)
                    if pre_task is not None:
                        await pre_task
                        del self.stream_tasks[cur_id]
                    cur_task = asyncio.create_task(
                        self.stream_output(cur_batch, tokenizer, next_ids)
                    )
                    self.stream_tasks[cur_id] = cur_task

                    if len(self.tokens[cur_id]) >= cur_batch.max_tokens:
                        # Finish a batch
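Note: this hunk carries the core optimization of the commit: instead of awaiting each `streamer.put()` inline inside `process_step`, the worker awaits the previous stream task for the batch and then schedules `stream_output` as a background asyncio task, so detokenization overlaps with the next pipeline step. Below is a standalone, hypothetical sketch of that task-chaining pattern, with illustrative names.

import asyncio

class OrderedStreamer:
    def __init__(self):
        self.stream_tasks = {}          # batch_id -> pending asyncio.Task

    async def submit(self, batch_id, coro):
        prev = self.stream_tasks.pop(batch_id, None)
        if prev is not None:
            await prev                  # preserve per-batch chunk ordering
        # Fire-and-forget: the next forward pass overlaps with decoding.
        self.stream_tasks[batch_id] = asyncio.create_task(coro)

    async def drain(self, batch_id):
        task = self.stream_tasks.pop(batch_id, None)
        if task is not None:
            await task                  # mirrors wait_stream_output() before cleanup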
@@ -841,6 +869,7 @@ class PPModelWorker:
                        next_token = (cur_times[-1] - cur_times[1]) / (len(self.tokens[cur_id]) - 1)
                        logger.info(f"First token latency: {first_token}, "
                                    f"next token latency: {next_token}")
                        await self.wait_stream_output(cur_id)
                        self.clear_batch(cur_id)
                        cur_batch.stopped = True
            else:
@@ -850,15 +879,12 @@ class PPModelWorker:
            if cur_batch is not None:
                cur_batch = self.prepare_batch(cur_batch)
                dist.broadcast_object_list([cur_batch], src=0)
            else:
                await asyncio.sleep(0)

        else:
            if self.send_buff is not None:
                # logger.info(f"send {self.rank} {self.send_buff.shape}")
                dist.send(self.send_buff, dst=self.next_rank)

            batch_list = [None]
            dist.broadcast_object_list(batch_list, src=0)

            cur_batch = batch_list[0]
            cur_input = None
@@ -882,10 +908,15 @@ class PPModelWorker:
                        )
                    # logger.info(f"recv {self.rank} {cur_input.shape}")
                    dist.recv(cur_input, src=self.pre_rank)
                    torch.xpu.synchronize(self.device)

        output, cur_batch = self.model_step(cur_input, cur_batch)

        self.send_buff = output
        torch.xpu.synchronize(self.device)
        if self.send_buff is not None:
            self.send_buff.wait()
        if output is not None:
            self.send_buff = dist.isend(output, dst=self.next_rank)

        if self.rank == 0:
            self.on_going_batches[:-1] = self.on_going_batches[1:]
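Note: the send path switches from a blocking `dist.send` issued at the start of the next step to a non-blocking `dist.isend` right after `model_step`, waiting on the previous step's handle before issuing the next send. A hedged sketch of that pattern (names are illustrative, not the repository's API):

import torch
import torch.distributed as dist

def send_stage_output(state, output: torch.Tensor, next_rank: int):
    # Finish the previous asynchronous send before reusing the slot.
    if state.send_handle is not None:
        state.send_handle.wait()
        state.send_handle = None
    if output is not None:
        # Keep a reference to the tensor until the send completes.
        state.send_tensor = output
        state.send_handle = dist.isend(output, dst=next_rank)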