LLM: Support generate(max_new_tokens=...), tokenize and decode for transformers-like API (#8283)

* first push

* fix pep8
Junwei Deng authored 2023-06-07 11:50:35 +08:00 (committed by GitHub)
parent 11cd2a07e0 · commit 2d14e593f0
3 changed files with 53 additions and 22 deletions

@@ -31,6 +31,30 @@ class GenerationMixin:
     Pass custom parameter values to 'generate'.
     """
+
+    def tokenize(self, text: str, add_bos: bool = True) -> List[int]:
+        '''
+        Tokenize the text into a list of token ids.
+
+        :param text: The text to be tokenized.
+        :param add_bos: Whether to prepend the beginning-of-sequence (BOS) token.
+        :return: List of ids that indicate the tokens.
+        '''
+        if isinstance(text, str):
+            bstr = text.encode()
+        else:
+            bstr = text
+        return self._tokenize(bstr, add_bos)
+
+    def decode(self, tokens: List[int]) -> str:
+        '''
+        Decode token ids back to a string.
+
+        :param tokens: List of token ids, typically produced by `generate`.
+        :return: The decoded string.
+        '''
+        return self.detokenize(tokens).decode()
+
     def generate(
         self,
         inputs: Union[Optional[Sequence[int]], Sequence[gptneox_cpp.gptneox_token]]=None,
@@ -46,18 +70,18 @@ class GenerationMixin:
         mirostat_mode: int = 0,
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
-        stop: Optional[Union[str, List[str]]]=[],
+        stop: Optional[Union[str, List[str]]]=[],  # TODO: rebase to support stopping_criteria
         **kwargs,
     ) -> Union[Optional[Sequence[int]], Optional[Sequence[gptneox_cpp.gptneox_token]], None]:
+        # TODO: modify docs
         """Create a generator of tokens from a prompt.

         Examples:
-            >>> llama = Llama("models/ggml-7b.bin")
-            >>> tokens = llama.tokenize(b"Hello, world!")
-            >>> for token in llama.generate(tokens, top_k=40, top_p=0.95,
-            >>>             temp=1.0, repeat_penalty=1.1):
-            ...     print(llama.detokenize([token]))
+            >>> llm = AutoModelForCausalLM.from_pretrained("gpt4all-model-q4_0.bin",
+            ...                                            model_family="llama")
+            >>> tokens = llm.tokenize("Q: Tell me something about Intel. A:")
+            >>> tokens_id = llm.generate(tokens, max_new_tokens=32)
+            >>> llm.decode(tokens_id)

         Args:
             tokens: The prompt tokens.
@@ -70,8 +94,7 @@ class GenerationMixin:
         Yields:
             The generated tokens.
         """
-        # TODO: stop & max_token
-        self._generate(tokens=inputs,
+        tokens = self._generate(tokens=inputs,
                        top_k=top_k,
                        top_p=top_p,
                        temp=temperature,
@@ -84,3 +107,11 @@ class GenerationMixin:
                        mirostat_tau=mirostat_tau,
                        mirostat_eta=mirostat_eta,
                        **kwargs)
+        res_list = []
+        word_count = 0
+        for token in tokens:
+            if word_count >= max_new_tokens:  # >= so at most max_new_tokens are returned
+                break
+            res_list.append(token)
+            word_count += 1
+        return res_list
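
Taken together, the new mixin methods give a transformers-style tokenize -> generate -> decode flow. A minimal usage sketch follows; the import path and checkpoint filename are assumptions carried over from the docstring example above, not verified against the repo layout:

    # Sketch only: the module path and checkpoint name below are
    # placeholders taken from the docstring example, not confirmed API.
    from bigdl.llm.ggml.transformers import AutoModelForCausalLM

    llm = AutoModelForCausalLM.from_pretrained("gpt4all-model-q4_0.bin",
                                               model_family="llama")

    tokens = llm.tokenize("Q: Tell me something about Intel. A:")  # str -> List[int]
    token_ids = llm.generate(tokens, max_new_tokens=32)            # at most 32 new tokens
    print(llm.decode(token_ids))                                   # List[int] -> str

Note that max_new_tokens caps generation by truncating the underlying token stream, so requesting 32 tokens returns at most 32.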

@@ -225,7 +225,7 @@ class Gptneox:
         if self.verbose:
             print(gptneox_cpp.gptneox_print_system_info().decode("utf-8"), file=sys.stderr)

-    def tokenize(
+    def _tokenize(
         self, text: bytes, add_bos: bool = True
     ) -> List[gptneox_cpp.gptneox_token]:
         """Tokenize a string.

@@ -250,7 +250,7 @@ class Llama(GenerationMixin):
         self._token_nl = Llama.token_nl()
         self._token_eos = Llama.token_eos()

-    def tokenize(self, text: bytes, add_bos: bool = True) -> List[int]:
+    def _tokenize(self, text: bytes, add_bos: bool = True) -> List[int]:
         """Tokenize a string.

         Args:
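
The tokenize -> _tokenize renames in the last two hunks (on Gptneox and Llama) free the public name for the mixin's str-based tokenize, which encodes to bytes and delegates to each backend. A self-contained sketch of that delegation pattern, using a hypothetical toy backend in place of the real model classes:

    from typing import List, Union


    class GenerationMixin:
        # Condensed from the diff above: public str-based API lives on the mixin.
        def tokenize(self, text: Union[str, bytes], add_bos: bool = True) -> List[int]:
            bstr = text.encode() if isinstance(text, str) else text
            return self._tokenize(bstr, add_bos)

        def decode(self, tokens: List[int]) -> str:
            return self.detokenize(tokens).decode()


    class ToyBackend(GenerationMixin):
        # Hypothetical stand-in for Llama/Gptneox: tokens are raw byte values.
        def _tokenize(self, text: bytes, add_bos: bool = True) -> List[int]:
            return list(text)  # add_bos is ignored in this toy model

        def detokenize(self, tokens: List[int]) -> bytes:
            return bytes(tokens)


    model = ToyBackend()
    ids = model.tokenize("hello")
    assert model.decode(ids) == "hello"

Keeping _tokenize byte-oriented lets each backend reuse its existing C-binding tokenizer unchanged while the mixin presents a uniform string interface.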