LLM: Support generate(max_new_tokens=...), tokenize and decode for transformers-like API (#8283)
* first push
* fix pep8
parent 11cd2a07e0
commit 2d14e593f0
3 changed files with 53 additions and 22 deletions

@@ -31,6 +31,30 @@ class GenerationMixin:
         Pass custom parameter values to 'generate'.
         """
 
+    def tokenize(self, text: str, add_bos: bool = True) -> List[int]:
+        '''
+        Tokenize the given text into a list of token ids.
+
+        :param text: The text to be tokenized.
+        :param add_bos: Whether to prepend the beginning-of-sequence (bos) token.
+
+        :return: A list of token ids.
+        '''
+        if isinstance(text, str):
+            bstr = text.encode()
+        else:
+            bstr = text
+        return self._tokenize(bstr, add_bos)
+
+    def decode(self, tokens: List[int]) -> str:
+        '''
+        Decode token ids back into a string.
+
+        :param tokens: A list of token ids, typically produced by `generate`.
+        :return: The decoded string.
+        '''
+        return self.detokenize(tokens).decode()
+
     def generate(
         self,
         inputs: Union[Optional[Sequence[int]], Sequence[gptneox_cpp.gptneox_token]]=None,
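A note on the two helpers added above: tokenize accepts either str or bytes (a str is UTF-8 encoded before being handed to the backend's _tokenize), and decode is its inverse. Below is a minimal round-trip sketch, assuming llm is any model instance that mixes in GenerationMixin; the round_trip helper and the Protocol are illustrative, not part of the commit.

    from typing import List, Protocol, Union

    class _TokenizerLike(Protocol):
        # Structural type for anything exposing the new mixin helpers.
        def tokenize(self, text: Union[str, bytes], add_bos: bool = True) -> List[int]: ...
        def decode(self, tokens: List[int]) -> str: ...

    def round_trip(llm: _TokenizerLike, prompt: str) -> str:
        # str and bytes inputs produce identical ids, since tokenize()
        # encodes a str to UTF-8 bytes before delegating to _tokenize().
        ids = llm.tokenize(prompt)
        assert ids == llm.tokenize(prompt.encode())
        # decode() maps the ids back to text.
        return llm.decode(ids)
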
@@ -46,18 +70,18 @@ class GenerationMixin:
         mirostat_mode: int = 0,
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
-        stop: Optional[Union[str, List[str]]]=[],
+        stop: Optional[Union[str, List[str]]]=[],  # TODO: rebase to support stopping_criteria
         **kwargs,
     ) -> Union[Optional[Sequence[int]], Optional[Sequence[gptneox_cpp.gptneox_token]], None]:
         # TODO: modify docs
         """Create a generator of tokens from a prompt.
 
         Examples:
-            >>> llama = Llama("models/ggml-7b.bin")
-            >>> tokens = llama.tokenize(b"Hello, world!")
-            >>> for token in llama.generate(tokens, top_k=40, top_p=0.95,
-            >>>                             temp=1.0, repeat_penalty=1.1):
-            ...     print(llama.detokenize([token]))
+            >>> llm = AutoModelForCausalLM.from_pretrained("gpt4all-model-q4_0.bin",
+            ...                                            model_family="llama")
+            >>> tokens = llm.tokenize("Q: Tell me something about Intel. A:")
+            >>> tokens_id = llm.generate(tokens, max_new_tokens=32)
+            >>> llm.decode(tokens_id)
 
         Args:
             tokens: The prompt tokens.
@@ -70,17 +94,24 @@ class GenerationMixin:
         Yields:
             The generated tokens.
         """
-        # TODO: stop & max_token
-        self._generate(tokens=inputs,
-                       top_k=top_k,
-                       top_p=top_p,
-                       temp=temperature,
-                       repeat_penalty=repetition_penalty,
-                       reset=reset,
-                       frequency_penalty=frequency_penalty,
-                       presence_penalty=presence_penalty,
-                       tfs_z=tfs_z,
-                       mirostat_mode=mirostat_mode,
-                       mirostat_tau=mirostat_tau,
-                       mirostat_eta=mirostat_eta,
-                       **kwargs)
+        tokens = self._generate(tokens=inputs,
+                                top_k=top_k,
+                                top_p=top_p,
+                                temp=temperature,
+                                repeat_penalty=repetition_penalty,
+                                reset=reset,
+                                frequency_penalty=frequency_penalty,
+                                presence_penalty=presence_penalty,
+                                tfs_z=tfs_z,
+                                mirostat_mode=mirostat_mode,
+                                mirostat_tau=mirostat_tau,
+                                mirostat_eta=mirostat_eta,
+                                **kwargs)
+        res_list = []
+        word_count = 0
+        for token in tokens:
+            if word_count >= max_new_tokens:
+                break
+            res_list.append(token)
+            word_count += 1
+        return res_list
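The word_count loop above materializes the token generator into a list and caps it at max_new_tokens, which is what gives generate its transformers-like "return all ids at once" behavior instead of yielding a stream. The same truncation can be expressed with itertools.islice; a standalone sketch, with an illustrative stand-in for the _generate() stream:

    import itertools
    from typing import Iterable, List

    def take_tokens(token_stream: Iterable[int], max_new_tokens: int) -> List[int]:
        # Equivalent to the word_count loop above: consume at most
        # max_new_tokens items, then stop pulling from the generator.
        return list(itertools.islice(token_stream, max_new_tokens))

    # Illustrative stand-in for the ids yielded by _generate().
    fake_stream = iter([101, 7592, 1010, 2088, 999])
    print(take_tokens(fake_stream, 3))  # [101, 7592, 1010]
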
@@ -225,7 +225,7 @@ class Gptneox:
         if self.verbose:
             print(gptneox_cpp.gptneox_print_system_info().decode("utf-8"), file=sys.stderr)
 
-    def tokenize(
+    def _tokenize(
         self, text: bytes, add_bos: bool = True
     ) -> List[gptneox_cpp.gptneox_token]:
         """Tokenize a string.
@@ -250,7 +250,7 @@ class Llama(GenerationMixin):
         self._token_nl = Llama.token_nl()
         self._token_eos = Llama.token_eos()
 
-    def tokenize(self, text: bytes, add_bos: bool = True) -> List[int]:
+    def _tokenize(self, text: bytes, add_bos: bool = True) -> List[int]:
         """Tokenize a string.
 
         Args:
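With tokenize renamed to _tokenize in both the Gptneox and Llama backends, the str-accepting GenerationMixin.tokenize becomes the single public entry point while each backend keeps only its bytes-level implementation. A hypothetical sketch of the resulting delegation pattern; the class names and the byte-per-token tokenizer are stand-ins, not the real ggml backends:

    from typing import List, Union

    class _FakeBackend:
        # Stand-in for Llama/Gptneox: a private, bytes-accepting tokenizer.
        def _tokenize(self, text: bytes, add_bos: bool = True) -> List[int]:
            return list(text)  # one "token" per byte, for illustration only

    class _Mixin:
        # Mirrors GenerationMixin.tokenize: normalize str to bytes, then delegate.
        def tokenize(self, text: Union[str, bytes], add_bos: bool = True) -> List[int]:
            bstr = text.encode() if isinstance(text, str) else text
            return self._tokenize(bstr, add_bos)

    class DemoModel(_Mixin, _FakeBackend):
        pass

    print(DemoModel().tokenize("hi"))  # [104, 105]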