[LLM] llm transformers api support batch actions (#8288)

* llm transformers api support batch actions

* align with transformer

* meet comment
Yina Chen 2023-06-08 15:10:08 +08:00 committed by GitHub
parent ea3cf6783e
commit 637b72f2ad


@@ -30,33 +30,78 @@ class GenerationMixin:
     Pass custom parameter values to 'generate' .
     """
 
-    def tokenize(self, text: str, add_bos: bool = True) -> List[int]:
+    def tokenize(self,
+                 text: Union[str, List[str]],
+                 add_bos: bool = True) -> List[int]:
         '''
         Decode the id to words
 
-        :param text: The text to be tokenized
+        :param text: The text or batch of text to be tokenized
         :param add_bos:
 
         :return: list of ids that indicates the tokens
         '''
-        if isinstance(text, str):
-            bstr = text.encode()
-        else:
-            bstr = text
-        return self._tokenize(bstr, add_bos)
+        is_batched = True if isinstance(text, (list, tuple)) else False
+        if not is_batched:
+            text = [text]
+
+        result = []
+        for t in text:
+            if isinstance(t, str):
+                bstr = t.encode()
+            else:
+                bstr = t
+            result.append(self._tokenize(bstr, add_bos))
+        if not is_batched:
+            result = result[0]
+        return result
 
     def decode(self, tokens: List[int]) -> str:
         '''
         Decode the id to words
 
+        Examples:
+            >>> llm = AutoModelForCausalLM.from_pretrained("gpt4all-model-q4_0.bin",
+                                                            model_family="llama")
+            >>> tokens = llm.tokenize("Q: Tell me something about Intel. A:")
+            >>> tokens_id = llm.generate(tokens, max_new_tokens=32)
+            >>> llm.decode(tokens_id[0])
+
         :param tokens: list of ids that indicates the tokens, mostly generated by generate
         :return: decoded string
         '''
         return self.detokenize(tokens).decode()
 
+    def batch_decode(self,
+                     tokens: Union[List[int], List[List[int]]]) -> str:
+        '''
+        Decode the id to words
+
+        :param tokens: list or a batch of list of ids that indicates the tokens,
+            mostly generated by generate
+        :return: decoded string
+        '''
+        is_batched = False
+        if tokens is not None and len(tokens) > 0:
+            if isinstance(tokens[0], Sequence):
+                is_batched = True
+            else:
+                tokens = [tokens]
+        else:
+            return None
+
+        results = []
+        for t in tokens:
+            results.append(self.decode(t))
+        if not is_batched:
+            results = results[0]
+        return results
+
     def generate(
         self,
-        inputs: Optional[Sequence[int]]=None,
+        inputs: Union[Optional[Sequence[int]],
+                      Sequence[Sequence[int]]]=None,
         max_new_tokens: int = 128,
         top_k: int = 40,
         top_p: float = 0.95,
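
With this change, tokenize accepts either a single string or a list of strings, and batch_decode mirrors it on the output side. A minimal sketch of the batched round trip, reusing the model file and first prompt from the docstring example (the second prompt is only illustrative):

    >>> llm = AutoModelForCausalLM.from_pretrained("gpt4all-model-q4_0.bin",
    ...                                             model_family="llama")
    >>> # a single string still yields one flat list of token ids
    >>> single = llm.tokenize("Q: Tell me something about Intel. A:")
    >>> # a list of strings yields one token list per prompt
    >>> batch = llm.tokenize(["Q: Tell me something about Intel. A:",
    ...                       "Q: What is deep learning? A:"])
    >>> llm.batch_decode(batch)   # decodes each token list back to its prompt
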
@@ -71,7 +116,9 @@ class GenerationMixin:
         mirostat_eta: float = 0.1,
         stop: Optional[Union[str, List[str]]]=[],  # TODO: rebase to support stopping_criteria
         **kwargs,
-    ) -> Union[Optional[Sequence[int]], None]:
+    ) -> Union[Optional[Sequence[int]],
+               Sequence[Sequence[int]],
+               None]:
         # TODO: modify docs
         """Create a generator of tokens from a prompt.
 
@@ -80,7 +127,7 @@ class GenerationMixin:
                                                             model_family="llama")
             >>> tokens = llm.tokenize("Q: Tell me something about Intel. A:")
             >>> tokens_id = llm.generate(tokens, max_new_tokens=32)
-            >>> llm.decode(tokens_id)
+            >>> llm.batch_decode(tokens_id)
 
         Args:
             tokens: The prompt tokens.
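
Because generate now returns one token list per prompt even for a single, unbatched input, the docstring example switches from decode to batch_decode. Under that return shape, the output of a single-prompt call can be read either way (a sketch, not part of the diff):

    >>> tokens_id = llm.generate(tokens, max_new_tokens=32)
    >>> llm.decode(tokens_id[0])     # decoded string of the first (and only) sequence
    >>> llm.batch_decode(tokens_id)  # list containing that same decoded string
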
@@ -93,24 +140,34 @@ class GenerationMixin:
         Yields:
             The generated tokens.
         """
-        tokens = self._generate(tokens=inputs,
-                                top_k=top_k,
-                                top_p=top_p,
-                                temp=temperature,
-                                repeat_penalty=repetition_penalty,
-                                reset=reset,
-                                frequency_penalty=frequency_penalty,
-                                presence_penalty=presence_penalty,
-                                tfs_z=tfs_z,
-                                mirostat_mode=mirostat_mode,
-                                mirostat_tau=mirostat_tau,
-                                mirostat_eta=mirostat_eta,
-                                **kwargs)
-        res_list = []
-        word_count = 0
-        for token in tokens:
-            if word_count > max_new_tokens:
-                break
-            res_list.append(token)
-            word_count += 1
-        return res_list
+        if inputs and len(inputs) > 0:
+            if not isinstance(inputs[0], Sequence):
+                inputs = [inputs]
+        else:
+            return None
+
+        results = []
+        for input in inputs:
+            tokens = self._generate(tokens=input,
+                                    top_k=top_k,
+                                    top_p=top_p,
+                                    temp=temperature,
+                                    repeat_penalty=repetition_penalty,
+                                    reset=reset,
+                                    frequency_penalty=frequency_penalty,
+                                    presence_penalty=presence_penalty,
+                                    tfs_z=tfs_z,
+                                    mirostat_mode=mirostat_mode,
+                                    mirostat_tau=mirostat_tau,
+                                    mirostat_eta=mirostat_eta,
+                                    **kwargs)
+            res_list = []
+            word_count = 0
+            for token in tokens:
+                if word_count > max_new_tokens:
+                    break
+                res_list.append(token)
+                word_count += 1
+            results.append(res_list)
+        return results
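
Putting the batched pieces together, an end-to-end call under this change might look like the sketch below, reusing the llm loaded as in the docstring example; the second prompt is illustrative:

    >>> prompts = ["Q: Tell me something about Intel. A:",
    ...            "Q: What is deep learning? A:"]
    >>> tokens = llm.tokenize(prompts)                       # one token list per prompt
    >>> tokens_id = llm.generate(tokens, max_new_tokens=32)  # one output list per prompt
    >>> answers = llm.batch_decode(tokens_id)                # list of decoded strings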