add (#8659)
This commit is contained in:
parent 0714888705
commit a15a2516e6
1 changed file with 31 additions and 8 deletions
@@ -76,6 +76,8 @@ class TransformersLLM(LLM):
     """BigDL-LLM Transformer-INT4 model."""
     tokenizer: Any  #: :meta private:
     """Huggingface tokenizer model."""
+    streaming: bool = True
+    """Whether to stream the results, token by token."""
 
     class Config:
         """Configuration for this pydantic object."""
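The new streaming field above defaults to True, so token-by-token output becomes the default behavior. A minimal usage sketch, assuming BigDL-LLM's LangChain import path and from_model_id constructor; the model id and kwargs here are placeholders, not part of this commit:

# Hypothetical usage of the new streaming flag; the checkpoint is a placeholder.
from bigdl.llm.langchain.llms import TransformersLLM

llm = TransformersLLM.from_model_id(
    model_id="lmsys/vicuna-7b-v1.5",           # assumption: any HF causal-LM checkpoint
    model_kwargs={"trust_remote_code": True},  # assumption: forwarded to from_pretrained
)
llm.streaming = False  # opt out of the new token-by-token default
print(llm("What is BigDL?", stop=["\n"]))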
@@ -196,11 +198,32 @@ class TransformersLLM(LLM):
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> str:
-        input_ids = self.tokenizer.encode(prompt, return_tensors="pt")
-        output = self.model.generate(input_ids, **kwargs)
-        text = self.tokenizer.decode(output[0], skip_special_tokens=True)[len(prompt) :]
-        if stop is not None:
-            # This is a bit hacky, but I can't figure out a better way to enforce
-            # stop tokens when making calls to huggingface_hub.
-            text = enforce_stop_tokens(text, stop)
-        return text
+        if self.streaming:
+            from transformers import TextStreamer
+            input_ids = self.tokenizer.encode(prompt, return_tensors="pt")
+            streamer = TextStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True)
+            if stop is not None:
+                from transformers.generation.stopping_criteria import StoppingCriteriaList
+                from transformers.tools.agents import StopSequenceCriteria
+                # stop generation when stop words are encountered
+                # TODO: stop generation when the following one is stop word
+                stopping_criteria = StoppingCriteriaList([StopSequenceCriteria(stop,
+                                                                               self.tokenizer)])
+            else:
+                stopping_criteria = None
+            output = self.model.generate(input_ids, streamer=streamer,
+                                         stopping_criteria=stopping_criteria, **kwargs)
+            text = self.tokenizer.decode(output[0], skip_special_tokens=True)
+            return text
+        else:
+            input_ids = self.tokenizer.encode(prompt, return_tensors="pt")
+            if stop is not None:
+                from transformers.generation.stopping_criteria import StoppingCriteriaList
+                from transformers.tools.agents import StopSequenceCriteria
+                stopping_criteria = StoppingCriteriaList([StopSequenceCriteria(stop,
+                                                                               self.tokenizer)])
+            else:
+                stopping_criteria = None
+            output = self.model.generate(input_ids, stopping_criteria=stopping_criteria, **kwargs)
+            text = self.tokenizer.decode(output[0], skip_special_tokens=True)[len(prompt) :]
+            return text
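For reference, here is the generate-with-streamer pattern from the streaming branch as a standalone sketch. The imports mirror the diff; the checkpoint name is a placeholder, and transformers.tools.agents.StopSequenceCriteria is assumed available as in the transformers versions this commit targets:

# Standalone sketch of the streaming + stop-sequence pattern above.
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
from transformers.generation.stopping_criteria import StoppingCriteriaList
from transformers.tools.agents import StopSequenceCriteria

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2")

prompt = "Q: What is BigDL?\nA:"
input_ids = tokenizer.encode(prompt, return_tensors="pt")

# Prints tokens to stdout as they are generated, skipping the prompt.
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
# Halts generation once any stop string appears in the output.
stopping_criteria = StoppingCriteriaList([StopSequenceCriteria(["\nQ:"], tokenizer)])

output = model.generate(input_ids, max_new_tokens=64, streamer=streamer,
                        stopping_criteria=stopping_criteria)
print(tokenizer.decode(output[0], skip_special_tokens=True))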