[LLM]Reopen autotp generate_stream (#11120)

* reopen autotp generate_stream

* fix style error

* update
ZehuaCao 2024-05-24 17:16:14 +08:00 committed by GitHub
parent 1dc680341b
commit 63e95698eb
4 changed files with 414 additions and 69 deletions

View file

@@ -58,6 +58,8 @@ If you successfully run the serving, you can get output like this:
We can use `curl` to test the serving API.

#### generate()
```bash
# Set http_proxy and https_proxy to null to ensure that requests are not forwarded by a proxy.
export http_proxy=
@@ -77,10 +79,68 @@ And you should get output like this:
```json
{
  "index": 0,
  "message": {
    "role": "assistant",
    "content": "\n\nArtificial intelligence (AI) is a branch of computer science that deals with the creation of intelligent machines that can perform tasks that typically "
  },
  "finish_reason": "stop"
}
```
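
The same request can also be issued from Python; below is a minimal client sketch (not part of the commit) that assumes the server above is already running on `127.0.0.1:8000`:
```python
# Minimal client sketch for the /generate/ endpoint (assumes the server is running locally).
import requests

resp = requests.post(
    "http://127.0.0.1:8000/generate/",
    json={"prompt": "What is AI?", "n_predict": 32},
)
resp.raise_for_status()
result = resp.json()
print(result["message"]["content"])
```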
#### generate_stream()
```bash
# Set http_proxy and https_proxy to null to ensure that requests are not forwarded by a proxy.
export http_proxy=
export https_proxy=
curl -X 'POST' \
'http://127.0.0.1:8000/generate_stream/' \
-H 'accept: application/json' \
-H 'Content-Type: application/json' \
-d '{
"prompt": "What is AI?",
"n_predict": 32
}'
```
And you should get output like this:
```json
{"index": 0, "message": {"role": "assistant", "content": "\n"}, "finish_reason": null}
{"index": 1, "message": {"role": "assistant", "content": "\n"}, "finish_reason": null}
{"index": 2, "message": {"role": "assistant", "content": ""}, "finish_reason": null}
{"index": 3, "message": {"role": "assistant", "content": ""}, "finish_reason": null}
{"index": 4, "message": {"role": "assistant", "content": ""}, "finish_reason": null}
{"index": 5, "message": {"role": "assistant", "content": "Artificial "}, "finish_reason": null}
{"index": 6, "message": {"role": "assistant", "content": "intelligence "}, "finish_reason": null}
{"index": 7, "message": {"role": "assistant", "content": ""}, "finish_reason": null}
{"index": 8, "message": {"role": "assistant", "content": ""}, "finish_reason": null}
{"index": 9, "message": {"role": "assistant", "content": "(AI) "}, "finish_reason": null}
{"index": 10, "message": {"role": "assistant", "content": "is "}, "finish_reason": null}
{"index": 11, "message": {"role": "assistant", "content": "a "}, "finish_reason": null}
{"index": 12, "message": {"role": "assistant", "content": "branch "}, "finish_reason": null}
{"index": 13, "message": {"role": "assistant", "content": "of "}, "finish_reason": null}
{"index": 14, "message": {"role": "assistant", "content": "computer "}, "finish_reason": null}
{"index": 15, "message": {"role": "assistant", "content": "science "}, "finish_reason": null}
{"index": 16, "message": {"role": "assistant", "content": "that "}, "finish_reason": null}
{"index": 17, "message": {"role": "assistant", "content": ""}, "finish_reason": null}
{"index": 18, "message": {"role": "assistant", "content": "deals "}, "finish_reason": null}
{"index": 19, "message": {"role": "assistant", "content": "with "}, "finish_reason": null}
{"index": 20, "message": {"role": "assistant", "content": "the "}, "finish_reason": null}
{"index": 21, "message": {"role": "assistant", "content": "creation "}, "finish_reason": null}
{"index": 22, "message": {"role": "assistant", "content": "of "}, "finish_reason": null}
{"index": 23, "message": {"role": "assistant", "content": ""}, "finish_reason": null}
{"index": 24, "message": {"role": "assistant", "content": "intelligent "}, "finish_reason": null}
{"index": 25, "message": {"role": "assistant", "content": "machines "}, "finish_reason": null}
{"index": 26, "message": {"role": "assistant", "content": "that "}, "finish_reason": null}
{"index": 27, "message": {"role": "assistant", "content": "can "}, "finish_reason": null}
{"index": 28, "message": {"role": "assistant", "content": "perform "}, "finish_reason": null}
{"index": 29, "message": {"role": "assistant", "content": "tasks "}, "finish_reason": null}
{"index": 30, "message": {"role": "assistant", "content": "that "}, "finish_reason": null}
{"index": 31, "message": {"role": "assistant", "content": "typically "}, "finish_reason": null}
{"index": 32, "message": {"role": "assistant", "content": null}, "finish_reason": "length"}
```
**Important**: The first token latency is much larger than the rest token latency; you can use [our benchmark tool](https://github.com/intel-analytics/ipex-llm/blob/main/python/llm/dev/benchmark/README.md) to obtain more details about first and rest token latency.
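
For reference, a minimal Python client sketch (not part of the commit) for the streaming endpoint: it reads the newline-delimited JSON chunks emitted by `/generate_stream/`, assuming the server is running on `127.0.0.1:8000`:
```python
# Minimal streaming client sketch: consume newline-delimited JSON chunks.
import json

import requests

response = requests.post(
    "http://127.0.0.1:8000/generate_stream/",
    json={"prompt": "What is AI?", "n_predict": 32},
    stream=True,
)
for line in response.iter_lines():
    if not line:
        continue
    chunk = json.loads(line)
    if chunk["finish_reason"] == "length":
        break  # the final chunk carries no content
    print(chunk["message"]["content"], end="", flush=True)
print()
```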

View file

@@ -18,17 +18,25 @@ import os
import torch
import transformers
import time
import json
import argparse
import torch.distributed as dist
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
import uvicorn
from threading import Thread
from ipex_llm.transformers.streamer import BatchTextIteratorStreamer
import asyncio, uuid
from collections import deque
from typing import Dict, List, Optional
from transformers.utils import logging

logger = logging.get_logger(__name__)

from ipex_llm.utils.benchmark_util import BenchmarkWrapper
@@ -42,17 +50,31 @@ def get_int_from_env(env_keys, default):
            return val
    return int(default)


global max_num_seqs
global max_num_batched_tokens

local_rank = get_int_from_env(["LOCAL_RANK", "PMI_RANK"], "0")
world_size = get_int_from_env(["WORLD_SIZE", "PMI_SIZE"], "1")
os.environ["RANK"] = str(local_rank)
os.environ["WORLD_SIZE"] = str(world_size)
os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", "29500")

global model, tokenizer


class PromptRequest(BaseModel):
    prompt: str
    n_predict: int = 32


rest_req_deque = deque(maxlen=128)
request_queue: asyncio.Queue = asyncio.Queue()
result_dict: Dict[str, str] = {}
streamer_dict = {}
empty_req = PromptRequest(prompt="", n_predict=0)


def load_model(model_path, low_bit):
    from ipex_llm import optimize_model
@@ -61,7 +83,9 @@ def load_model(model_path, low_bit):
    import time
    import argparse
    from transformers import (
        AutoModelForCausalLM,
    )  # export AutoModelForCausalLM from transformers so that deepspeed uses it
    from transformers import LlamaTokenizer, AutoTokenizer
    import deepspeed
    from deepspeed.accelerator.cpu_accelerator import CPU_Accelerator
@@ -73,12 +97,14 @@ def load_model(model_path, low_bit):
    current_accel = CPU_Accelerator()
    set_accelerator(current_accel)

    global model, tokenizer
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        device_map={"": "cpu"},
        low_cpu_mem_usage=True,
        torch_dtype=torch.float16,
        trust_remote_code=True,
        use_cache=True,
    )

    model = deepspeed.init_inference(
        model,
@@ -89,70 +115,171 @@ def load_model(model_path, low_bit):
    # Use IPEX-LLM `optimize_model` to convert the model into optimized low bit format
    # Convert the rest of the model into float16 to reduce allreduce traffic
    model = optimize_model(model.module.to(f"cpu"), low_bit=low_bit).to(torch.float16)

    # Next, use XPU as accelerator to speed up inference
    current_accel = XPU_Accelerator()
    set_accelerator(current_accel)

    # Move model back to xpu
    model = model.to(f"xpu:{local_rank}")
    model = BenchmarkWrapper(model)

    # Modify backend related settings
    if world_size > 1:
        get_accelerator().set_device(local_rank)
        dist_backend = get_accelerator().communication_backend_name()
        import deepspeed.comm.comm
        deepspeed.comm.comm.cdb = None
        from deepspeed.comm.comm import init_distributed
        init_distributed()

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        model_path, trust_remote_code=True, padding_side="left"
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token


async def generate_stream_gate(prompt: List[str], n_predict=32, request_ids=[]):
    while prompt[-1] == "":
        prompt = prompt[:-1]
    if isinstance(n_predict, list):
        n_predict = max(n_predict)
    inputs = tokenizer(prompt, return_tensors="pt", padding=True)
    input_ids = inputs.input_ids.to(f"xpu:{local_rank}")
    attention_mask = inputs.attention_mask.to(f"xpu:{local_rank}")

    for request_id in request_ids:
        if request_id not in streamer_dict:
            streamer_dict[request_id] = asyncio.Queue()

    streamer = BatchTextIteratorStreamer(
        tokenizer=tokenizer,
        timeout=600,
        skip_prompt=True,
        skip_special_tokens=True,
        batch_size=len(prompt),
    )
    generated_kwargs = dict(
        max_new_tokens=n_predict,
        min_new_tokens=n_predict,
        streamer=streamer,
        attention_mask=attention_mask,
        do_sample=False,
    )

    def model_generate():
        model.generate(input_ids, **generated_kwargs)
        torch.xpu.empty_cache()
        torch.xpu.synchronize()

    t1 = Thread(target=model_generate)
    t1.start()

    stopped = False

    async def put_item(queue, item):
        await queue.put(item)

    for i in range(n_predict):
        tasks = []
        try:
            output_token = next(streamer)
        except StopIteration:
            stopped = True
        for index, request_id in enumerate(request_ids):
            task = asyncio.create_task(
                put_item(
                    streamer_dict[request_id],
                    (0 if stopped else n_predict - 1 - i, output_token[index]),
                )
            )
            tasks.append(task)
        await asyncio.gather(*tasks)
        if stopped:
            break


app = FastAPI()


async def stream_generator(token_queue, request_id):
    index = 0
    while True:
        if not token_queue.empty():
            remain, token = await token_queue.get()
            response = {
                "index": index,
                "message": {"role": "assistant", "content": token},
                "finish_reason": None,
            }
            yield json.dumps(response) + "\n"
            index = index + 1
            if remain == 0:
                response = {
                    "index": index,
                    "message": {"role": "assistant", "content": None},
                    "finish_reason": "length",
                }
                yield json.dumps(response) + "\n"
                break
        else:
            await asyncio.sleep(0)
    streamer_dict.pop(request_id, None)


async def generator(token_queue, request_id):
    while True:
        if not token_queue.empty():
            remain, token = await token_queue.get()
            yield token
            if remain == 0:
                break
        else:
            await asyncio.sleep(0)
    streamer_dict.pop(request_id, None)
@app.post("/generate/") @app.post("/generate/")
async def generate(prompt_request: PromptRequest): async def generate(prompt_request: PromptRequest):
request_id = str(uuid.uuid4()) request_id = str(uuid.uuid4())
await request_queue.put((request_id, prompt_request)) await request_queue.put((request_id, prompt_request))
while True: while True:
await asyncio.sleep(0.1) await asyncio.sleep(0)
if request_id in result_dict: if request_id in streamer_dict:
output_str = result_dict.pop(request_id) output_str = []
return {"generated_text": output_str} token_queue = streamer_dict[request_id]
async for item in generator(token_queue, request_id):
output_str.append(item)
return {
"index": 0,
"message": {
"role": "assistant",
"content": "".join(output_str),
},
"finish_reason": "stop",
}
@app.post("/generate_stream/")
async def generate_stream(prompt_request: PromptRequest):
request_id = str(uuid.uuid4()) + "stream"
await request_queue.put((request_id, prompt_request))
while True:
await asyncio.sleep(0)
if request_id in streamer_dict:
token_queue = streamer_dict[request_id]
return StreamingResponse(
stream_generator(token_queue, request_id), media_type="application/json"
)


async def process_requests():
@@ -164,7 +291,9 @@ async def process_requests():
        while rest_req_deque:
            request_id, rest_request = rest_req_deque.popleft()
            prompt = rest_request.prompt
            cur_prompt_len = tokenizer(
                prompt, return_tensors="pt"
            ).input_ids.size(1)
            cur_batched_tokens += cur_prompt_len
            if cur_batched_tokens > max_num_batched_tokens:
                cur_batched_tokens -= cur_prompt_len
@@ -179,9 +308,9 @@ async def process_requests():
            if request_queue.empty():
                break
            request_id, prompt_request = await request_queue.get()
            cur_prompt_len = tokenizer(
                prompt_request.prompt, return_tensors="pt"
            ).input_ids.size(1)
            cur_batched_tokens += cur_prompt_len
            if cur_batched_tokens > max_num_batched_tokens:
                cur_batched_tokens -= cur_prompt_len
@@ -193,21 +322,28 @@ async def process_requests():
        if local_rank == 0 and prompt_requests:
            object_list = prompt_requests
            if len(object_list) < max_num_seqs:
                object_list = object_list + [empty_req] * (
                    max_num_seqs - len(object_list)
                )
            logger.info(
                f"Running: {len(prompt_requests)}, Pending: {request_queue.qsize()}"
            )
            dist.broadcast_object_list(object_list, src=0)
            start_time = time.time()
            await generate_stream_gate(
                [req.prompt for req in object_list],
                [req.n_predict for req in object_list],
                request_ids,
            )
            generate_time = time.time() - start_time
            logger.info(
                f"First token latency: {model.first_cost}, next token latency: {model.rest_cost_mean}, generate time: {generate_time}"
            )

        await asyncio.sleep(0)


@app.on_event("startup")
@@ -216,19 +352,44 @@ async def startup_event():
    asyncio.create_task(process_requests())


async def main():
    parser = argparse.ArgumentParser(
        description="Predict Tokens using fastapi by leveraging DeepSpeed-AutoTP"
    )
    parser.add_argument(
        "--repo-id-or-model-path",
        type=str,
        default="meta-llama/Llama-2-7b-chat-hf",
        help="The huggingface repo id for the Llama2 (e.g. `meta-llama/Llama-2-7b-chat-hf`, `meta-llama/Llama-2-13b-chat-hf` and `meta-llama/Llama-2-70b-chat-hf`) to be downloaded"
        ", or the path to the huggingface checkpoint folder",
    )
    parser.add_argument(
        "--low-bit",
        type=str,
        default="sym_int4",
        help="The quantization type the model will convert to.",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=8000,
        help="The port number on which the server will run.",
    )
    parser.add_argument(
        "--max-num-batched-tokens",
        type=int,
        default=4096,
        help="Max tokens can be batched by this service.",
    )
    parser.add_argument(
        "--max-num-seqs",
        type=int,
        default=8,
        help="Max requests can be batched by this service.",
    )

    global max_num_seqs
    global max_num_batched_tokens

    args = parser.parse_args()
    model_path = args.repo_id_or_model_path
@@ -236,10 +397,21 @@ if __name__ == "__main__":
    max_num_seqs = args.max_num_seqs
    max_num_batched_tokens = args.max_num_batched_tokens
    load_model(model_path, low_bit)

    config = uvicorn.Config(app=app, host="0.0.0.0", port=args.port)
    server = uvicorn.Server(config)
    if local_rank == 0:
        await server.serve()
    else:
        while True:
            object_list = [None] * max_num_seqs
            dist.broadcast_object_list(object_list, src=0)
            await generate_stream_gate(
                [req.prompt for req in object_list],
                [req.n_predict for req in object_list],
            )


if __name__ == "__main__":
    asyncio.run(main())
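
The switch from `uvicorn.run(app, ...)` to `uvicorn.Config`/`uvicorn.Server` with `await server.serve()` appears intended to let rank 0 run the HTTP server inside the event loop created by `asyncio.run(main())`, so it cooperates with the other coroutines instead of starting its own loop. A minimal standalone sketch of that pattern (illustrative only; the route and names are placeholders, not part of the commit):
```python
# Sketch: run uvicorn inside an already-running asyncio event loop
# instead of uvicorn.run(), which creates its own loop.
import asyncio

import uvicorn
from fastapi import FastAPI

app = FastAPI()


@app.get("/ping")
async def ping():
    return {"status": "ok"}


async def main():
    config = uvicorn.Config(app=app, host="0.0.0.0", port=8000)
    server = uvicorn.Server(config)
    await server.serve()  # yields to other coroutines scheduled on this loop


if __name__ == "__main__":
    asyncio.run(main())
```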

View file

@@ -33,5 +33,4 @@ export TORCH_LLM_ALLREDUCE=0
export WORLD_SIZE=2
mpirun -np $NUM_GPUS --prepend-rank \
    python serving.py --repo-id-or-model-path YOUR_REPO_ID_OR_MODEL_PATH --low-bit 'fp8' --port 8000 --max-num-seqs 8 --max-num-batched-tokens 8192

View file

@@ -0,0 +1,114 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file is adapted from
# https://github.com/huggingface/transformers/blob/main/src/transformers/generation/streamers.py
#
from typing import Optional, List
import torch
from transformers import TextIteratorStreamer
class BatchTextIteratorStreamer(TextIteratorStreamer):
    """
    A specialized version of TextIteratorStreamer that handles text streams in batches, providing
    an efficient way to process large volumes of text data asynchronously. This class is designed
    to aggregate multiple texts into batches, making it ideal for applications that need to
    perform batch operations on streamed text data, such as bulk text processing or machine
    learning model inference in an interactive environment.

    Parameters:
        tokenizer (`AutoTokenizer`):
            The tokenizer used to decode the tokens.
        skip_prompt (`bool`, *optional*, defaults to `False`):
            Whether to skip the prompt to `.generate()` or not.
        timeout (`float`, *optional*):
            The timeout for the text queue. If `None`, the queue will block indefinitely.
            Useful to handle exceptions in `.generate()`, when it is called in a separate thread.
        decode_kwargs (`dict`, *optional*):
            Additional keyword arguments to pass to the tokenizer's `decode` method.
        batch_size (`int`):
            The size of the batches to process. This parameter must be specified and
            determines how many texts are processed together as a single batch.
    """
    def __init__(
        self,
        batch_size: int,
        tokenizer: "AutoTokenizer",
        skip_prompt: bool = False,
        timeout: Optional[float] = None,
        **decode_kwargs
    ):
        super().__init__(tokenizer, skip_prompt, timeout, **decode_kwargs)
        self.batch_size = batch_size
        self.token_cache = [[] for _ in range(batch_size)]
        self.print_len = [0 for _ in range(batch_size)]
        self.generate_exception = None

    def put(self, value):
        if len(value.shape) != 2:
            value = torch.reshape(
                value, (self.batch_size, value.shape[0] // self.batch_size)
            )

        if self.skip_prompt and self.next_tokens_are_prompt:
            self.next_tokens_are_prompt = False
            return

        printable_texts = list()
        for idx in range(self.batch_size):
            self.token_cache[idx].extend(value[idx].tolist())
            text = self.tokenizer.decode(self.token_cache[idx], **self.decode_kwargs)

            if text.endswith("\n"):
                printable_text = text[self.print_len[idx]:]
                self.token_cache[idx] = []
                self.print_len[idx] = 0
            # If the last token is a CJK character, we print the characters.
            elif len(text) > 0 and self._is_chinese_char(ord(text[-1])):
                printable_text = text[self.print_len[idx]:]
                self.print_len[idx] += len(printable_text)
            else:
                printable_text = text[self.print_len[idx]:text.rfind(" ") + 1]
                self.print_len[idx] += len(printable_text)
            printable_texts.append(printable_text)

        self.on_finalized_text(printable_texts)

    def end(self):
        printable_texts = list()
        for idx in range(self.batch_size):
            if len(self.token_cache[idx]) > 0:
                text = self.tokenizer.decode(
                    self.token_cache[idx], **self.decode_kwargs
                )
                printable_text = text[self.print_len[idx]:]
                self.token_cache[idx] = []
                self.print_len[idx] = 0
            else:
                printable_text = ""
            printable_texts.append(printable_text)

        self.next_tokens_are_prompt = True
        self.on_finalized_text(printable_texts, stream_end=True)

    def on_finalized_text(self, texts: List[str], stream_end: bool = False):
        self.text_queue.put(texts, timeout=self.timeout)
        if stream_end:
            self.text_queue.put(self.stop_signal, timeout=self.timeout)
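
For context, a usage sketch (not part of the commit) of how `BatchTextIteratorStreamer` is meant to be driven: `generate()` pushes token ids into it from a background thread while the consumer iterates over lists of decoded chunks, one entry per batch element. The checkpoint name is a placeholder; any locally available Hugging Face causal LM would do:
```python
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer

from ipex_llm.transformers.streamer import BatchTextIteratorStreamer

model_id = "gpt2"  # placeholder checkpoint, not the model used elsewhere in this example
tokenizer = AutoTokenizer.from_pretrained(model_id)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(model_id)

prompts = ["What is AI?", "What is deep learning?"]
inputs = tokenizer(prompts, return_tensors="pt", padding=True)

streamer = BatchTextIteratorStreamer(
    batch_size=len(prompts),
    tokenizer=tokenizer,
    skip_prompt=True,
    skip_special_tokens=True,
)

# generate() calls streamer.put() once per decoding step; run it in a thread so
# the main thread can consume the decoded batches as they are produced.
thread = Thread(
    target=model.generate,
    kwargs=dict(
        input_ids=inputs.input_ids,
        attention_mask=inputs.attention_mask,
        max_new_tokens=32,
        do_sample=False,
        streamer=streamer,
    ),
)
thread.start()

for batch_texts in streamer:  # each item is a list of text chunks, one per prompt
    print(batch_texts)
thread.join()
```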