LLM: Support generate(max_new_tokens=...), tokenize and decode for transformers-like API (#8283)

* first push

* fix pep8
This commit is contained in:
Junwei Deng 2023-06-07 11:50:35 +08:00 committed by GitHub
parent 11cd2a07e0
commit 2d14e593f0
3 changed files with 53 additions and 22 deletions

View file

@@ -31,6 +31,30 @@ class GenerationMixin:
    Pass custom parameter values to 'generate'.
    """
def tokenize(self, text: Union[str, bytes], add_bos: bool = True) -> List[int]:
    '''
    Tokenize text into a list of token ids.

    :param text: The text to be tokenized; a str is encoded to bytes
        (default encoding) first, bytes input is passed through unchanged
    :param add_bos: forwarded to the backend tokenizer; presumably controls
        prepending a beginning-of-sequence token — TODO confirm in _tokenize
    :return: list of ids that indicates the tokens
    '''
    if isinstance(text, str):
        bstr = text.encode()
    else:
        bstr = text
    return self._tokenize(bstr, add_bos)
def decode(self, tokens: List[int]) -> str:
    '''
    Decode a sequence of token ids back into a string.

    :param tokens: list of ids that indicates the tokens, mostly generated by generate
    :return: decoded string
    '''
    raw_bytes = self.detokenize(tokens)
    return raw_bytes.decode()
    def generate(
        self,
        inputs: Union[Optional[Sequence[int]], Sequence[gptneox_cpp.gptneox_token]]=None,
@@ -46,18 +70,18 @@ class GenerationMixin:
         mirostat_mode: int = 0,
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
-        stop: Optional[Union[str, List[str]]]=[],
+        stop: Optional[Union[str, List[str]]]=[],  # TODO: rebase to support stopping_criteria
         **kwargs,
     ) -> Union[Optional[Sequence[int]], Optional[Sequence[gptneox_cpp.gptneox_token]], None]:
         # TODO: modify docs
         """Create a generator of tokens from a prompt.

         Examples:
-            >>> llama = Llama("models/ggml-7b.bin")
-            >>> tokens = llama.tokenize(b"Hello, world!")
-            >>> for token in llama.generate(tokens, top_k=40, top_p=0.95,
-            >>>             temp=1.0, repeat_penalty=1.1):
-            ...     print(llama.detokenize([token]))
+            >>> llm = AutoModelForCausalLM.from_pretrained("gpt4all-model-q4_0.bin",
+                                                           model_family="llama")
+            >>> tokens = llm.tokenize("Q: Tell me something about Intel. A:")
+            >>> tokens_id = llm.generate(tokens, max_new_tokens=32)
+            >>> llm.decode(tokens_id)

         Args:
             tokens: The prompt tokens.
@@ -70,17 +94,24 @@ class GenerationMixin:
         Yields:
             The generated tokens.
         """
-        # TODO: stop & max_token
-        self._generate(tokens=inputs,
-                       top_k=top_k,
-                       top_p=top_p,
-                       temp=temperature,
-                       repeat_penalty=repetition_penalty,
-                       reset=reset,
-                       frequency_penalty=frequency_penalty,
-                       presence_penalty=presence_penalty,
-                       tfs_z=tfs_z,
-                       mirostat_mode=mirostat_mode,
-                       mirostat_tau=mirostat_tau,
-                       mirostat_eta=mirostat_eta,
-                       **kwargs)
+        tokens = self._generate(tokens=inputs,
+                                top_k=top_k,
+                                top_p=top_p,
+                                temp=temperature,
+                                repeat_penalty=repetition_penalty,
+                                reset=reset,
+                                frequency_penalty=frequency_penalty,
+                                presence_penalty=presence_penalty,
+                                tfs_z=tfs_z,
+                                mirostat_mode=mirostat_mode,
+                                mirostat_tau=mirostat_tau,
+                                mirostat_eta=mirostat_eta,
+                                **kwargs)
+        res_list = []
+        word_count = 0
+        for token in tokens:
+            if word_count > max_new_tokens:
+                break
+            res_list.append(token)
+            word_count += 1
+        return res_list

View file

@@ -225,7 +225,7 @@ class Gptneox:
         if self.verbose:
             print(gptneox_cpp.gptneox_print_system_info().decode("utf-8"), file=sys.stderr)

-    def tokenize(
+    def _tokenize(
         self, text: bytes, add_bos: bool = True
     ) -> List[gptneox_cpp.gptneox_token]:
         """Tokenize a string.

View file

@@ -250,7 +250,7 @@ class Llama(GenerationMixin):
         self._token_nl = Llama.token_nl()
         self._token_eos = Llama.token_eos()

-    def tokenize(self, text: bytes, add_bos: bool = True) -> List[int]:
+    def _tokenize(self, text: bytes, add_bos: bool = True) -> List[int]:
         """Tokenize a string.

         Args: