From 63e95698eb1148bcf0c73653902488a335d1dd14 Mon Sep 17 00:00:00 2001 From: ZehuaCao <47251317+Romanticoseu@users.noreply.github.com> Date: Fri, 24 May 2024 17:16:14 +0800 Subject: [PATCH] [LLM]Reopen autotp generate_stream (#11120) * reopen autotp generate_stream * fix style error * update --- .../GPU/Deepspeed-AutoTP-FastAPI/README.md | 64 +++- .../GPU/Deepspeed-AutoTP-FastAPI/serving.py | 302 ++++++++++++++---- ...tart-deepspeed-autotp-ipex-llm-serving.sh} | 3 +- .../llm/src/ipex_llm/transformers/streamer.py | 114 +++++++ 4 files changed, 414 insertions(+), 69 deletions(-) rename python/llm/example/GPU/Deepspeed-AutoTP-FastAPI/{run_llama2_7b_chat_hf_arc_2_card.sh => start-deepspeed-autotp-ipex-llm-serving.sh} (92%) create mode 100644 python/llm/src/ipex_llm/transformers/streamer.py diff --git a/python/llm/example/GPU/Deepspeed-AutoTP-FastAPI/README.md b/python/llm/example/GPU/Deepspeed-AutoTP-FastAPI/README.md index f99c6731..04f11c10 100644 --- a/python/llm/example/GPU/Deepspeed-AutoTP-FastAPI/README.md +++ b/python/llm/example/GPU/Deepspeed-AutoTP-FastAPI/README.md @@ -58,6 +58,8 @@ If you successfully run the serving, you can get output like this: We can use `curl` to test serving api +#### generate() + ```bash # Set http_proxy and https_proxy to null to ensure that requests are not forwarded by a proxy. export http_proxy= @@ -77,10 +79,68 @@ And you should get output like this: ```json { - "generated_text": "What is AI? Artificial intelligence (AI) refers to the development of computer systems able to perform tasks that would normally require human intelligence, such as visual perception, speech", - "generate_time": "0.45149803161621094s" + "index": 0, + "message": { + "role": "assistant", + "content": "\n\nArtificial intelligence (AI) is a branch of computer science that deals with the creation of intelligent machines that can perform tasks that typically " + }, + "finish_reason": "stop" } +``` +#### generate_stream() +```bash +# Set http_proxy and https_proxy to null to ensure that requests are not forwarded by a proxy. +export http_proxy= +export https_proxy= + +curl -X 'POST' \ + 'http://127.0.0.1:8000/generate_stream/' \ + -H 'accept: application/json' \ + -H 'Content-Type: application/json' \ + -d '{ + "prompt": "What is AI?", + "n_predict": 32 +}' +``` + +And you should get output like this: +```json +{"index": 0, "message": {"role": "assistant", "content": "\n"}, "finish_reason": null} +{"index": 1, "message": {"role": "assistant", "content": "\n"}, "finish_reason": null} +{"index": 2, "message": {"role": "assistant", "content": ""}, "finish_reason": null} +{"index": 3, "message": {"role": "assistant", "content": ""}, "finish_reason": null} +{"index": 4, "message": {"role": "assistant", "content": ""}, "finish_reason": null} +{"index": 5, "message": {"role": "assistant", "content": "Artificial "}, "finish_reason": null} +{"index": 6, "message": {"role": "assistant", "content": "intelligence "}, "finish_reason": null} +{"index": 7, "message": {"role": "assistant", "content": ""}, "finish_reason": null} +{"index": 8, "message": {"role": "assistant", "content": ""}, "finish_reason": null} +{"index": 9, "message": {"role": "assistant", "content": "(AI) "}, "finish_reason": null} +{"index": 10, "message": {"role": "assistant", "content": "is "}, "finish_reason": null} +{"index": 11, "message": {"role": "assistant", "content": "a "}, "finish_reason": null} +{"index": 12, "message": {"role": "assistant", "content": "branch "}, "finish_reason": null} +{"index": 13, "message": {"role": "assistant", "content": "of "}, "finish_reason": null} +{"index": 14, "message": {"role": "assistant", "content": "computer "}, "finish_reason": null} +{"index": 15, "message": {"role": "assistant", "content": "science "}, "finish_reason": null} +{"index": 16, "message": {"role": "assistant", "content": "that "}, "finish_reason": null} +{"index": 17, "message": {"role": "assistant", "content": ""}, "finish_reason": null} +{"index": 18, "message": {"role": "assistant", "content": "deals "}, "finish_reason": null} +{"index": 19, "message": {"role": "assistant", "content": "with "}, "finish_reason": null} +{"index": 20, "message": {"role": "assistant", "content": "the "}, "finish_reason": null} +{"index": 21, "message": {"role": "assistant", "content": "creation "}, "finish_reason": null} +{"index": 22, "message": {"role": "assistant", "content": "of "}, "finish_reason": null} +{"index": 23, "message": {"role": "assistant", "content": ""}, "finish_reason": null} +{"index": 24, "message": {"role": "assistant", "content": "intelligent "}, "finish_reason": null} +{"index": 25, "message": {"role": "assistant", "content": "machines "}, "finish_reason": null} +{"index": 26, "message": {"role": "assistant", "content": "that "}, "finish_reason": null} +{"index": 27, "message": {"role": "assistant", "content": "can "}, "finish_reason": null} +{"index": 28, "message": {"role": "assistant", "content": "perform "}, "finish_reason": null} +{"index": 29, "message": {"role": "assistant", "content": "tasks "}, "finish_reason": null} +{"index": 30, "message": {"role": "assistant", "content": "that "}, "finish_reason": null} +{"index": 31, "message": {"role": "assistant", "content": "typically "}, "finish_reason": null} +{"index": 32, "message": {"role": "assistant", "content": null}, "finish_reason": "length"} + + ``` **Important**: The first token latency is much larger than rest token latency, you could use [our benchmark tool](https://github.com/intel-analytics/ipex-llm/blob/main/python/llm/dev/benchmark/README.md) to obtain more details about first and rest token latency. diff --git a/python/llm/example/GPU/Deepspeed-AutoTP-FastAPI/serving.py b/python/llm/example/GPU/Deepspeed-AutoTP-FastAPI/serving.py index bf8b64e8..eab52364 100644 --- a/python/llm/example/GPU/Deepspeed-AutoTP-FastAPI/serving.py +++ b/python/llm/example/GPU/Deepspeed-AutoTP-FastAPI/serving.py @@ -18,17 +18,25 @@ import os import torch import transformers import time +import json import argparse import torch.distributed as dist from fastapi import FastAPI, HTTPException +from fastapi.responses import StreamingResponse + from pydantic import BaseModel import uvicorn +from threading import Thread +from ipex_llm.transformers.streamer import BatchTextIteratorStreamer + import asyncio, uuid +from collections import deque from typing import Dict, List, Optional from transformers.utils import logging + logger = logging.get_logger(__name__) from ipex_llm.utils.benchmark_util import BenchmarkWrapper @@ -42,17 +50,31 @@ def get_int_from_env(env_keys, default): return val return int(default) + global max_num_seqs global max_num_batched_tokens -local_rank = get_int_from_env(["LOCAL_RANK","PMI_RANK"], "0") -world_size = get_int_from_env(["WORLD_SIZE","PMI_SIZE"], "1") +local_rank = get_int_from_env(["LOCAL_RANK", "PMI_RANK"], "0") +world_size = get_int_from_env(["WORLD_SIZE", "PMI_SIZE"], "1") os.environ["RANK"] = str(local_rank) os.environ["WORLD_SIZE"] = str(world_size) os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", "29500") global model, tokenizer + +class PromptRequest(BaseModel): + prompt: str + n_predict: int = 32 + + +rest_req_deque = deque(maxlen=128) +request_queue: asyncio.Queue = asyncio.Queue() +result_dict: Dict[str, str] = {} +streamer_dict = {} +empty_req = PromptRequest(prompt="", n_predict=0) + + def load_model(model_path, low_bit): from ipex_llm import optimize_model @@ -61,7 +83,9 @@ def load_model(model_path, low_bit): import time import argparse - from transformers import AutoModelForCausalLM # export AutoModelForCausalLM from transformers so that deepspeed use it + from transformers import ( + AutoModelForCausalLM, + ) # export AutoModelForCausalLM from transformers so that deepspeed use it from transformers import LlamaTokenizer, AutoTokenizer import deepspeed from deepspeed.accelerator.cpu_accelerator import CPU_Accelerator @@ -73,12 +97,14 @@ def load_model(model_path, low_bit): current_accel = CPU_Accelerator() set_accelerator(current_accel) global model, tokenizer - model = AutoModelForCausalLM.from_pretrained(model_path, - device_map={"": "cpu"}, - low_cpu_mem_usage=True, - torch_dtype=torch.float16, - trust_remote_code=True, - use_cache=True) + model = AutoModelForCausalLM.from_pretrained( + model_path, + device_map={"": "cpu"}, + low_cpu_mem_usage=True, + torch_dtype=torch.float16, + trust_remote_code=True, + use_cache=True, + ) model = deepspeed.init_inference( model, @@ -89,70 +115,171 @@ def load_model(model_path, low_bit): # Use IPEX-LLM `optimize_model` to convert the model into optimized low bit format # Convert the rest of the model into float16 to reduce allreduce traffic - model = optimize_model(model.module.to(f'cpu'), low_bit=low_bit).to(torch.float16) + model = optimize_model(model.module.to(f"cpu"), low_bit=low_bit).to(torch.float16) # Next, use XPU as accelerator to speed up inference current_accel = XPU_Accelerator() set_accelerator(current_accel) # Move model back to xpu - model = model.to(f'xpu:{local_rank}') + model = model.to(f"xpu:{local_rank}") model = BenchmarkWrapper(model) - # Modify backend related settings + # Modify backend related settings if world_size > 1: get_accelerator().set_device(local_rank) dist_backend = get_accelerator().communication_backend_name() import deepspeed.comm.comm + deepspeed.comm.comm.cdb = None from deepspeed.comm.comm import init_distributed + init_distributed() # Load tokenizer - tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, padding_side='left') + tokenizer = AutoTokenizer.from_pretrained( + model_path, trust_remote_code=True, padding_side="left" + ) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token -def generate_text(prompt: List[str], n_predict = 32): + +async def generate_stream_gate(prompt: List[str], n_predict=32, request_ids=[]): while prompt[-1] == "": prompt = prompt[:-1] if isinstance(n_predict, list): n_predict = max(n_predict) - + inputs = tokenizer(prompt, return_tensors="pt", padding=True) - input_ids = inputs.input_ids.to(f'xpu:{local_rank}') - # print(input_ids) - attention_mask = inputs.attention_mask.to(f'xpu:{local_rank}') - output = model.generate(input_ids, - attention_mask=attention_mask, - max_new_tokens=n_predict, - use_cache=True) - torch.xpu.synchronize() - return output + input_ids = inputs.input_ids.to(f"xpu:{local_rank}") + attention_mask = inputs.attention_mask.to(f"xpu:{local_rank}") + for request_id in request_ids: + if request_id not in streamer_dict: + streamer_dict[request_id] = asyncio.Queue() -class PromptRequest(BaseModel): - prompt: str - n_predict: int = 32 + streamer = BatchTextIteratorStreamer( + tokenizer=tokenizer, + timeout=600, + skip_prompt=True, + skip_special_tokens=True, + batch_size=len(prompt), + ) + + generated_kwargs = dict( + max_new_tokens=n_predict, + min_new_tokens=n_predict, + streamer=streamer, + attention_mask=attention_mask, + do_sample=False, + ) + + def model_generate(): + model.generate(input_ids, **generated_kwargs) + torch.xpu.empty_cache() + torch.xpu.synchronize() + + t1 = Thread(target=model_generate) + t1.start() + + stopped = False + + async def put_item(queue, item): + await queue.put(item) + + for i in range(n_predict): + tasks = [] + try: + output_token = next(streamer) + except StopIteration: + stopped = True + for index, request_id in enumerate(request_ids): + task = asyncio.create_task( + put_item( + streamer_dict[request_id], + (0 if stopped else n_predict - 1 - i, output_token[index]), + ) + ) + tasks.append(task) + await asyncio.gather(*tasks) + if stopped: + break -empty_req = PromptRequest(prompt="", n_predict=0) app = FastAPI() -from collections import deque -rest_req_deque = deque(maxlen=128) -request_queue: asyncio.Queue = asyncio.Queue() -result_dict: Dict[str, str] = {} + +async def stream_generator(token_queue, request_id): + index = 0 + while True: + if not token_queue.empty(): + remain, token = await token_queue.get() + response = { + "index": index, + "message": {"role": "assistant", "content": token}, + "finish_reason": None, + } + yield json.dumps(response) + "\n" + index = index + 1 + if remain == 0: + response = { + "index": index, + "message": {"role": "assistant", "content": None}, + "finish_reason": "length", + } + yield json.dumps(response) + "\n" + break + else: + await asyncio.sleep(0) + streamer_dict.pop(request_id, None) + + +async def generator(token_queue, request_id): + while True: + if not token_queue.empty(): + remain, token = await token_queue.get() + yield token + if remain == 0: + break + else: + await asyncio.sleep(0) + streamer_dict.pop(request_id, None) + @app.post("/generate/") async def generate(prompt_request: PromptRequest): request_id = str(uuid.uuid4()) await request_queue.put((request_id, prompt_request)) while True: - await asyncio.sleep(0.1) - if request_id in result_dict: - output_str = result_dict.pop(request_id) - return {"generated_text": output_str} + await asyncio.sleep(0) + if request_id in streamer_dict: + output_str = [] + token_queue = streamer_dict[request_id] + async for item in generator(token_queue, request_id): + output_str.append(item) + + return { + "index": 0, + "message": { + "role": "assistant", + "content": "".join(output_str), + }, + "finish_reason": "stop", + } + + +@app.post("/generate_stream/") +async def generate_stream(prompt_request: PromptRequest): + request_id = str(uuid.uuid4()) + "stream" + await request_queue.put((request_id, prompt_request)) + while True: + await asyncio.sleep(0) + if request_id in streamer_dict: + token_queue = streamer_dict[request_id] + + return StreamingResponse( + stream_generator(token_queue, request_id), media_type="application/json" + ) async def process_requests(): @@ -164,7 +291,9 @@ async def process_requests(): while rest_req_deque: request_id, rest_request = rest_req_deque.popleft() prompt = rest_request.prompt - cur_prompt_len = tokenizer(prompt_request.prompt, return_tensors="pt").input_ids.size(1) + cur_prompt_len = tokenizer( + prompt_request.prompt, return_tensors="pt" + ).input_ids.size(1) cur_batched_tokens += cur_prompt_len if cur_batched_tokens > max_num_batched_tokens: cur_batched_tokens -= cur_prompt_len @@ -179,9 +308,9 @@ async def process_requests(): if request_queue.empty(): break request_id, prompt_request = await request_queue.get() - # import pdb - # pdb.set_trace() - cur_prompt_len = tokenizer(prompt_request.prompt, return_tensors="pt").input_ids.size(1) + cur_prompt_len = tokenizer( + prompt_request.prompt, return_tensors="pt" + ).input_ids.size(1) cur_batched_tokens += cur_prompt_len if cur_batched_tokens > max_num_batched_tokens: cur_batched_tokens -= cur_prompt_len @@ -193,21 +322,28 @@ async def process_requests(): if local_rank == 0 and prompt_requests: object_list = prompt_requests if len(object_list) < max_num_seqs: - object_list = object_list + [empty_req] * (max_num_seqs - len(object_list)) - logger.info(f"Running: {len(prompt_requests)}, Pending: {request_queue.qsize()}") + object_list = object_list + [empty_req] * ( + max_num_seqs - len(object_list) + ) + logger.info( + f"Running: {len(prompt_requests)}, Pending: {request_queue.qsize()}" + ) dist.broadcast_object_list(object_list, src=0) + start_time = time.time() - outputs = generate_text([req.prompt for req in object_list], [req.n_predict for req in object_list]) + await generate_stream_gate( + [req.prompt for req in object_list], + [req.n_predict for req in object_list], + request_ids, + ) + generate_time = time.time() - start_time - outputs = outputs.cpu() - output_strs = tokenizer.batch_decode(outputs, skip_special_tokens=True) - output_strs = output_strs[:len(prompt_requests)] - for request_id, output_str in zip(request_ids, output_strs): - result_dict[request_id] = output_str - logger.info(f"First token latency: {model.first_cost}, next token latency: {model.rest_cost_mean}, generate time: {generate_time}") + logger.info( + f"First token latency: {model.first_cost}, next token latency: {model.rest_cost_mean}, generate time: {generate_time}" + ) - await asyncio.sleep(0.1) + await asyncio.sleep(0) @app.on_event("startup") @@ -216,19 +352,44 @@ async def startup_event(): asyncio.create_task(process_requests()) -if __name__ == "__main__": - parser = argparse.ArgumentParser(description='Predict Tokens using fastapi by leveraging DeepSpeed-AutoTP') - parser.add_argument('--repo-id-or-model-path', type=str, default="meta-llama/Llama-2-7b-chat-hf", - help='The huggingface repo id for the Llama2 (e.g. `meta-llama/Llama-2-7b-chat-hf`, `meta-llama/Llama-2-13b-chat-hf` and `meta-llama/Llama-2-70b-chat-hf`) to be downloaded' - ', or the path to the huggingface checkpoint folder') - parser.add_argument('--low-bit', type=str, default='sym_int4', - help='The quantization type the model will convert to.') - parser.add_argument('--port', type=int, default=8000, - help='The port number on which the server will run.') - parser.add_argument('--max-num-batched-tokens', type=int, default=4096, - help='Max tokens can be batched by this service.') - parser.add_argument('--max-num-seqs', type=int, default=8, - help='Max requests can be batched by this service.') +async def main(): + parser = argparse.ArgumentParser( + description="Predict Tokens using fastapi by leveraging DeepSpeed-AutoTP" + ) + parser.add_argument( + "--repo-id-or-model-path", + type=str, + default="meta-llama/Llama-2-7b-chat-hf", + help="The huggingface repo id for the Llama2 (e.g. `meta-llama/Llama-2-7b-chat-hf`, `meta-llama/Llama-2-13b-chat-hf` and `meta-llama/Llama-2-70b-chat-hf`) to be downloaded" + ", or the path to the huggingface checkpoint folder", + ) + parser.add_argument( + "--low-bit", + type=str, + default="sym_int4", + help="The quantization type the model will convert to.", + ) + parser.add_argument( + "--port", + type=int, + default=8000, + help="The port number on which the server will run.", + ) + parser.add_argument( + "--max-num-batched-tokens", + type=int, + default=4096, + help="Max tokens can be batched by this service.", + ) + parser.add_argument( + "--max-num-seqs", + type=int, + default=8, + help="Max requests can be batched by this service.", + ) + + global max_num_seqs + global max_num_batched_tokens args = parser.parse_args() model_path = args.repo_id_or_model_path @@ -236,10 +397,21 @@ if __name__ == "__main__": max_num_seqs = args.max_num_seqs max_num_batched_tokens = args.max_num_batched_tokens load_model(model_path, low_bit) + + config = uvicorn.Config(app=app, host="0.0.0.0", port=args.port) + server = uvicorn.Server(config) + if local_rank == 0: - uvicorn.run(app, host="0.0.0.0", port=args.port) + await server.serve() else: while True: object_list = [None] * max_num_seqs dist.broadcast_object_list(object_list, src=0) - output = generate_text([req.prompt for req in object_list], [req.n_predict for req in object_list]) + await generate_stream_gate( + [req.prompt for req in object_list], + [req.n_predict for req in object_list], + ) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/python/llm/example/GPU/Deepspeed-AutoTP-FastAPI/run_llama2_7b_chat_hf_arc_2_card.sh b/python/llm/example/GPU/Deepspeed-AutoTP-FastAPI/start-deepspeed-autotp-ipex-llm-serving.sh similarity index 92% rename from python/llm/example/GPU/Deepspeed-AutoTP-FastAPI/run_llama2_7b_chat_hf_arc_2_card.sh rename to python/llm/example/GPU/Deepspeed-AutoTP-FastAPI/start-deepspeed-autotp-ipex-llm-serving.sh index 86e737cb..c3d3bd85 100644 --- a/python/llm/example/GPU/Deepspeed-AutoTP-FastAPI/run_llama2_7b_chat_hf_arc_2_card.sh +++ b/python/llm/example/GPU/Deepspeed-AutoTP-FastAPI/start-deepspeed-autotp-ipex-llm-serving.sh @@ -33,5 +33,4 @@ export TORCH_LLM_ALLREDUCE=0 export WORLD_SIZE=2 mpirun -np $NUM_GPUS --prepend-rank \ - python serving.py --repo-id-or-model-path YOUR_REPO_ID_OR_MODEL_PATH --low-bit 'sym_int4' --port 8000 --max-num-seqs 8 --max-num-batched-tokens 8192 - + python serving.py --repo-id-or-model-path YOUR_REPO_ID_OR_MODEL_PATH --low-bit 'fp8' --port 8000 --max-num-seqs 8 --max-num-batched-tokens 8192 diff --git a/python/llm/src/ipex_llm/transformers/streamer.py b/python/llm/src/ipex_llm/transformers/streamer.py new file mode 100644 index 00000000..5d1b457c --- /dev/null +++ b/python/llm/src/ipex_llm/transformers/streamer.py @@ -0,0 +1,114 @@ +# +# Copyright 2016 The BigDL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Some parts of this file is adapted from +# https://github.com/huggingface/transformers/blob/main/src/transformers/generation/streamers.py +# + +from typing import Optional, List + +import torch +from transformers import TextIteratorStreamer + + +class BatchTextIteratorStreamer(TextIteratorStreamer): + """ + A specialized version of TextIteratorStreamer that handles text streams in batches, providing + an efficient way to process large volumes of text data asynchronously. This class is designed + to aggregate multiple texts into batches, making it ideal for applications that need to + perform batch operations on streamed text data, such as bulk text processing or machine + learning model inference in an interactive environment. + + Parameters: + tokenizer (`AutoTokenizer`): + The tokenized used to decode the tokens. + skip_prompt (`bool`, *optional*, defaults to `False`): + Whether to skip the prompt to `.generate()` or not. + timeout (`float`, *optional*): + The timeout for the text queue. If `None`, the queue will + block indefinitely. Useful to handle exceptions + in `.generate()`, when it is called in a separate thread. + decode_kwargs (`dict`, *optional*): + Additional keyword arguments to pass to the tokenizer's `decode` method. + batch_size(`int`) + The size of the batches to process. This parameter must be specified and + determines how many texts are processed together as a single batch. + """ + + def __init__( + self, + batch_size: int, + tokenizer: "AutoTokenizer", + skip_prompt: bool = False, + timeout: Optional[float] = None, + **decode_kwargs + ): + super().__init__(tokenizer, skip_prompt, timeout, **decode_kwargs) + self.batch_size = batch_size + self.token_cache = [[] for _ in range(batch_size)] + self.print_len = [0 for _ in range(batch_size)] + self.generate_exception = None + + def put(self, value): + if len(value.shape) != 2: + value = torch.reshape( + value, (self.batch_size, value.shape[0] // self.batch_size) + ) + + if self.skip_prompt and self.next_tokens_are_prompt: + self.next_tokens_are_prompt = False + return + + printable_texts = list() + for idx in range(self.batch_size): + self.token_cache[idx].extend(value[idx].tolist()) + text = self.tokenizer.decode(self.token_cache[idx], **self.decode_kwargs) + + if text.endswith("\n"): + printable_text = text[self.print_len[idx]:] + self.token_cache[idx] = [] + self.print_len[idx] = 0 + # If the last token is a CJK character, we print the characters. + elif len(text) > 0 and self._is_chinese_char(ord(text[-1])): + printable_text = text[self.print_len[idx]:] + self.print_len[idx] += len(printable_text) + else: + printable_text = text[self.print_len[idx]:text.rfind(" ") + 1] + self.print_len[idx] += len(printable_text) + printable_texts.append(printable_text) + + self.on_finalized_text(printable_texts) + + def end(self): + printable_texts = list() + for idx in range(self.batch_size): + if len(self.token_cache[idx]) > 0: + text = self.tokenizer.decode( + self.token_cache[idx], **self.decode_kwargs + ) + printable_text = text[self.print_len[idx]:] + self.token_cache[idx] = [] + self.print_len[idx] = 0 + else: + printable_text = "" + printable_texts.append(printable_text) + + self.next_tokens_are_prompt = True + self.on_finalized_text(printable_texts, stream_end=True) + + def on_finalized_text(self, texts: List[str], stream_end: bool = False): + self.text_queue.put(texts, timeout=self.timeout) + if stream_end: + self.text_queue.put(self.stop_signal, timeout=self.timeout)