ipex-llm/python/llm/src/ipex_llm/transformers/streamer.py

#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file are adapted from
# https://github.com/huggingface/transformers/blob/main/src/transformers/generation/streamers.py
#

from typing import Optional, List
import torch
from transformers import TextIteratorStreamer


class BatchTextIteratorStreamer(TextIteratorStreamer):
"""
A specialized version of TextIteratorStreamer that handles text streams in batches, providing
an efficient way to process large volumes of text data asynchronously. This class is designed
to aggregate multiple texts into batches, making it ideal for applications that need to
perform batch operations on streamed text data, such as bulk text processing or machine
learning model inference in an interactive environment.
Parameters:
tokenizer (`AutoTokenizer`):
            The tokenizer used to decode the tokens.
skip_prompt (`bool`, *optional*, defaults to `False`):
Whether to skip the prompt to `.generate()` or not.
timeout (`float`, *optional*):
The timeout for the text queue. If `None`, the queue will
block indefinitely. Useful to handle exceptions
in `.generate()`, when it is called in a separate thread.
decode_kwargs (`dict`, *optional*):
Additional keyword arguments to pass to the tokenizer's `decode` method.
        batch_size (`int`):
The size of the batches to process. This parameter must be specified and
determines how many texts are processed together as a single batch.
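
    Example (an illustrative sketch; the model name, prompts, and generation
    settings are placeholders, not part of this module):

    >>> from threading import Thread
    >>> from transformers import AutoModelForCausalLM, AutoTokenizer
    >>> from ipex_llm.transformers.streamer import BatchTextIteratorStreamer
    >>> tokenizer = AutoTokenizer.from_pretrained("gpt2", padding_side="left")
    >>> tokenizer.pad_token = tokenizer.eos_token
    >>> model = AutoModelForCausalLM.from_pretrained("gpt2")
    >>> prompts = ["An increasing sequence: one,", "A happy phrase:"]
    >>> inputs = tokenizer(prompts, return_tensors="pt", padding=True)
    >>> streamer = BatchTextIteratorStreamer(batch_size=len(prompts),
    ...                                      tokenizer=tokenizer, skip_prompt=True)
    >>> generation_kwargs = dict(**inputs, streamer=streamer, max_new_tokens=20)
    >>> thread = Thread(target=model.generate, kwargs=generation_kwargs)
    >>> thread.start()
    >>> for texts in streamer:  # each item is a list of `batch_size` strings
    ...     print(texts)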
"""
def __init__(
self,
batch_size: int,
tokenizer: "AutoTokenizer",
skip_prompt: bool = False,
timeout: Optional[float] = None,
**decode_kwargs
):
super().__init__(tokenizer, skip_prompt, timeout, **decode_kwargs)
self.batch_size = batch_size
self.token_cache = [[] for _ in range(batch_size)]
self.print_len = [0 for _ in range(batch_size)]
self.generate_exception = None

    def put(self, value):
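        # `generate()` streams each step's new tokens as a 1-D tensor with one
        # token per batch element; reshape it to (batch_size, tokens_per_element)
        # so every row can be decoded independently.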
if len(value.shape) != 2:
value = torch.reshape(
value, (self.batch_size, value.shape[0] // self.batch_size)
)
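        # `generate()` sends the prompt tokens through the streamer first;
        # drop them when `skip_prompt` is set.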
if self.skip_prompt and self.next_tokens_are_prompt:
self.next_tokens_are_prompt = False
return
printable_texts = list()
for idx in range(self.batch_size):
self.token_cache[idx].extend(value[idx].tolist())
text = self.tokenizer.decode(self.token_cache[idx], **self.decode_kwargs)
if text.endswith("\n"):
printable_text = text[self.print_len[idx]:]
self.token_cache[idx] = []
self.print_len[idx] = 0
# If the last token is a CJK character, we print the characters.
elif len(text) > 0 and self._is_chinese_char(ord(text[-1])):
printable_text = text[self.print_len[idx]:]
self.print_len[idx] += len(printable_text)
else:
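                # Otherwise, emit text only up to the last space, so a
                # partially decoded word is held back until it is complete.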
printable_text = text[self.print_len[idx]:text.rfind(" ") + 1]
self.print_len[idx] += len(printable_text)
printable_texts.append(printable_text)
self.on_finalized_text(printable_texts)

    def end(self):
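        # Flush whatever is left in each per-sequence token cache and mark the
        # end of the stream.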
printable_texts = list()
for idx in range(self.batch_size):
if len(self.token_cache[idx]) > 0:
text = self.tokenizer.decode(
self.token_cache[idx], **self.decode_kwargs
)
printable_text = text[self.print_len[idx]:]
self.token_cache[idx] = []
self.print_len[idx] = 0
else:
printable_text = ""
printable_texts.append(printable_text)
self.next_tokens_are_prompt = True
self.on_finalized_text(printable_texts, stream_end=True)

    def on_finalized_text(self, texts: List[str], stream_end: bool = False):
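        """Put the batch of newly finalized texts on the queue; if the stream
        is ending, also put the stop signal."""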
self.text_queue.put(texts, timeout=self.timeout)
if stream_end:
self.text_queue.put(self.stop_signal, timeout=self.timeout)