add (#8659)
This commit is contained in:
parent
0714888705
commit
a15a2516e6
1 changed files with 31 additions and 8 deletions
|
|
@ -76,6 +76,8 @@ class TransformersLLM(LLM):
|
|||
"""BigDL-LLM Transformer-INT4 model."""
|
||||
tokenizer: Any #: :meta private:
|
||||
"""Huggingface tokenizer model."""
|
||||
streaming: bool = True
|
||||
"""Whether to stream the results, token by token."""
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
|
@ -196,11 +198,32 @@ class TransformersLLM(LLM):
|
|||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
input_ids = self.tokenizer.encode(prompt, return_tensors="pt")
|
||||
output = self.model.generate(input_ids, **kwargs)
|
||||
text = self.tokenizer.decode(output[0], skip_special_tokens=True)[len(prompt) :]
|
||||
if stop is not None:
|
||||
# This is a bit hacky, but I can't figure out a better way to enforce
|
||||
# stop tokens when making calls to huggingface_hub.
|
||||
text = enforce_stop_tokens(text, stop)
|
||||
return text
|
||||
if self.streaming:
|
||||
from transformers import TextStreamer
|
||||
input_ids = self.tokenizer.encode(prompt, return_tensors="pt")
|
||||
streamer = TextStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True)
|
||||
if stop is not None:
|
||||
from transformers.generation.stopping_criteria import StoppingCriteriaList
|
||||
from transformers.tools.agents import StopSequenceCriteria
|
||||
# stop generation when stop words are encountered
|
||||
# TODO: stop generation when the following one is stop word
|
||||
stopping_criteria = StoppingCriteriaList([StopSequenceCriteria(stop,
|
||||
self.tokenizer)])
|
||||
else:
|
||||
stopping_criteria = None
|
||||
output = self.model.generate(input_ids, streamer=streamer,
|
||||
stopping_criteria=stopping_criteria, **kwargs)
|
||||
text = self.tokenizer.decode(output[0], skip_special_tokens=True)
|
||||
return text
|
||||
else:
|
||||
input_ids = self.tokenizer.encode(prompt, return_tensors="pt")
|
||||
if stop is not None:
|
||||
from transformers.generation.stopping_criteria import StoppingCriteriaList
|
||||
from transformers.tools.agents import StopSequenceCriteria
|
||||
stopping_criteria = StoppingCriteriaList([StopSequenceCriteria(stop,
|
||||
self.tokenizer)])
|
||||
else:
|
||||
stopping_criteria = None
|
||||
output = self.model.generate(input_ids, stopping_criteria=stopping_criteria, **kwargs)
|
||||
text = self.tokenizer.decode(output[0], skip_special_tokens=True)[len(prompt) :]
|
||||
return text
|
||||
|
|
|
|||
Loading…
Reference in a new issue