This commit is contained in:
binbin Deng 2023-08-03 10:12:10 +08:00 committed by GitHub
parent 0714888705
commit a15a2516e6

View file

@@ -76,6 +76,8 @@ class TransformersLLM(LLM):
     """BigDL-LLM Transformer-INT4 model."""
     tokenizer: Any #: :meta private:
     """Huggingface tokenizer model."""
+    streaming: bool = True
+    """Whether to stream the results, token by token."""

     class Config:
         """Configuration for this pydantic object."""
@@ -196,11 +198,32 @@ class TransformersLLM(LLM):
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> str:
-        input_ids = self.tokenizer.encode(prompt, return_tensors="pt")
-        output = self.model.generate(input_ids, **kwargs)
-        text = self.tokenizer.decode(output[0], skip_special_tokens=True)[len(prompt) :]
-        if stop is not None:
-            # This is a bit hacky, but I can't figure out a better way to enforce
-            # stop tokens when making calls to huggingface_hub.
-            text = enforce_stop_tokens(text, stop)
-        return text
+        if self.streaming:
+            from transformers import TextStreamer
+            input_ids = self.tokenizer.encode(prompt, return_tensors="pt")
+            streamer = TextStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True)
+            if stop is not None:
+                from transformers.generation.stopping_criteria import StoppingCriteriaList
+                from transformers.tools.agents import StopSequenceCriteria
+                # stop generation when stop words are encountered
+                # TODO: stop generation when the following one is stop word
+                stopping_criteria = StoppingCriteriaList([StopSequenceCriteria(stop,
+                                                                               self.tokenizer)])
+            else:
+                stopping_criteria = None
+            output = self.model.generate(input_ids, streamer=streamer,
+                                         stopping_criteria=stopping_criteria, **kwargs)
+            text = self.tokenizer.decode(output[0], skip_special_tokens=True)
+            return text
+        else:
+            input_ids = self.tokenizer.encode(prompt, return_tensors="pt")
+            if stop is not None:
+                from transformers.generation.stopping_criteria import StoppingCriteriaList
+                from transformers.tools.agents import StopSequenceCriteria
+                stopping_criteria = StoppingCriteriaList([StopSequenceCriteria(stop,
+                                                                               self.tokenizer)])
+            else:
+                stopping_criteria = None
+            output = self.model.generate(input_ids, stopping_criteria=stopping_criteria, **kwargs)
+            text = self.tokenizer.decode(output[0], skip_special_tokens=True)[len(prompt) :]
+            return text