diff --git a/python/llm/src/bigdl/llm/langchain/llms/transformersllm.py b/python/llm/src/bigdl/llm/langchain/llms/transformersllm.py
index ade3cd01..6b5385a8 100644
--- a/python/llm/src/bigdl/llm/langchain/llms/transformersllm.py
+++ b/python/llm/src/bigdl/llm/langchain/llms/transformersllm.py
@@ -76,6 +76,8 @@ class TransformersLLM(LLM):
     """BigDL-LLM Transformer-INT4 model."""
     tokenizer: Any  #: :meta private:
     """Huggingface tokenizer model."""
+    streaming: bool = True
+    """Whether to stream the results, token by token."""

     class Config:
         """Configuration for this pydantic object."""
@@ -196,11 +198,32 @@ class TransformersLLM(LLM):
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> str:
-        input_ids = self.tokenizer.encode(prompt, return_tensors="pt")
-        output = self.model.generate(input_ids, **kwargs)
-        text = self.tokenizer.decode(output[0], skip_special_tokens=True)[len(prompt) :]
-        if stop is not None:
-            # This is a bit hacky, but I can't figure out a better way to enforce
-            # stop tokens when making calls to huggingface_hub.
-            text = enforce_stop_tokens(text, stop)
-        return text
+        if self.streaming:
+            from transformers import TextStreamer
+            input_ids = self.tokenizer.encode(prompt, return_tensors="pt")
+            streamer = TextStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True)
+            if stop is not None:
+                from transformers.generation.stopping_criteria import StoppingCriteriaList
+                from transformers.tools.agents import StopSequenceCriteria
+                # stop generation when stop words are encountered
+                # TODO: stop generation when the following one is stop word
+                stopping_criteria = StoppingCriteriaList([StopSequenceCriteria(stop,
+                                                                               self.tokenizer)])
+            else:
+                stopping_criteria = None
+            output = self.model.generate(input_ids, streamer=streamer,
+                                         stopping_criteria=stopping_criteria, **kwargs)
+            text = self.tokenizer.decode(output[0], skip_special_tokens=True)
+            return text
+        else:
+            input_ids = self.tokenizer.encode(prompt, return_tensors="pt")
+            if stop is not None:
+                from transformers.generation.stopping_criteria import StoppingCriteriaList
+                from transformers.tools.agents import StopSequenceCriteria
+                stopping_criteria = StoppingCriteriaList([StopSequenceCriteria(stop,
+                                                                               self.tokenizer)])
+            else:
+                stopping_criteria = None
+            output = self.model.generate(input_ids, stopping_criteria=stopping_criteria, **kwargs)
+            text = self.tokenizer.decode(output[0], skip_special_tokens=True)[len(prompt) :]
+            return text
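
Usage sketch (not part of the diff): a minimal example of how the new streaming flag and stop-word handling would be exercised through the LangChain interface. It assumes the existing TransformersLLM.from_model_id constructor from bigdl.llm.langchain.llms; the model id, model_kwargs, prompt, and stop sequence below are illustrative placeholders, not values taken from this change.

# Minimal sketch under the assumptions stated above.
from bigdl.llm.langchain.llms import TransformersLLM

llm = TransformersLLM.from_model_id(
    model_id="lmsys/vicuna-7b-v1.5",          # illustrative model id
    model_kwargs={"temperature": 0, "max_length": 64},
)

# streaming defaults to True: tokens are printed as they are generated via
# transformers.TextStreamer, and stop sequences are enforced during generation
# with StopSequenceCriteria instead of being trimmed from the text afterwards.
text = llm("What is AI?", stop=["\n\n"])

# Setting streaming=False falls back to plain generation; note that only this
# branch strips the prompt from the decoded output before returning it.
llm.streaming = False
text = llm("What is AI?", stop=["\n\n"])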