[LLM] llm transformers api support batch actions (#8288)
* llm transformers api support batch actions
* align with transformer
* meet comment
parent ea3cf6783e
commit 637b72f2ad
1 changed file with 88 additions and 31 deletions
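For orientation, here is a minimal usage sketch of the batched flow this change enables, modeled on the docstring examples in the diff below. The model file name, the model_family value, and the availability of AutoModelForCausalLM in the caller's namespace are taken from those docstring examples and are assumptions, not a tested setup.

    # Hedged sketch of the batched tokenize / generate / batch_decode flow added by this commit.
    # The import of AutoModelForCausalLM is omitted (as in the docstrings); model path,
    # model_family and the prompts are placeholders.
    llm = AutoModelForCausalLM.from_pretrained("gpt4all-model-q4_0.bin",
                                                model_family="llama")

    prompts = ["Q: Tell me something about Intel. A:",
               "Q: What is deep learning? A:"]

    batch_tokens = llm.tokenize(prompts)                        # list of token-id lists
    tokens_id = llm.generate(batch_tokens, max_new_tokens=32)   # one id sequence per prompt
    answers = llm.batch_decode(tokens_id)                       # list of decoded strings
    for answer in answers:
        print(answer)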
@@ -30,33 +30,78 @@ class GenerationMixin:
         Pass custom parameter values to 'generate' .
         """
 
-    def tokenize(self, text: str, add_bos: bool = True) -> List[int]:
+    def tokenize(self,
+                 text: Union[str, List[str]],
+                 add_bos: bool = True) -> List[int]:
         '''
         Decode the id to words
 
-        :param text: The text to be tokenized
+        :param text: The text or batch of text to be tokenized
         :param add_bos:
 
         :return: list of ids that indicates the tokens
         '''
-        if isinstance(text, str):
-            bstr = text.encode()
-        else:
-            bstr = text
-        return self._tokenize(bstr, add_bos)
+        is_batched = True if isinstance(text, (list, tuple)) else False
+        if not is_batched:
+            text = [text]
+
+        result = []
+        for t in text:
+            if isinstance(t, str):
+                bstr = t.encode()
+            else:
+                bstr = t
+            result.append(self._tokenize(bstr, add_bos))
+
+        if not is_batched:
+            result = result[0]
+        return result
 
     def decode(self, tokens: List[int]) -> str:
         '''
         Decode the id to words
 
+        Examples:
+            >>> llm = AutoModelForCausalLM.from_pretrained("gpt4all-model-q4_0.bin",
+                                                            model_family="llama")
+            >>> tokens = llm.tokenize("Q: Tell me something about Intel. A:")
+            >>> tokens_id = llm.generate(tokens, max_new_tokens=32)
+            >>> llm.decode(tokens_id[0])
+
         :param tokens: list of ids that indicates the tokens, mostly generated by generate
         :return: decoded string
         '''
         return self.detokenize(tokens).decode()
 
+    def batch_decode(self,
+                     tokens: Union[List[int], List[List[int]]]) -> str:
+        '''
+        Decode the id to words
+
+        :param tokens: list or a batch of list of ids that indicates the tokens,
+                       mostly generated by generate
+        :return: decoded string
+        '''
+        is_batched = False
+        if tokens is not None and len(tokens) > 0:
+            if isinstance(tokens[0], Sequence):
+                is_batched = True
+            else:
+                tokens = [tokens]
+        else:
+            return None
+
+        results = []
+        for t in tokens:
+            results.append(self.decode(t))
+        if not is_batched:
+            results = results[0]
+        return results
+
     def generate(
         self,
-        inputs: Optional[Sequence[int]]=None,
+        inputs: Union[Optional[Sequence[int]],
+                      Sequence[Sequence[int]]]=None,
         max_new_tokens: int = 128,
         top_k: int = 40,
         top_p: float = 0.95,
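The tokenize and batch_decode changes in the hunk above share one dispatch pattern: detect whether the argument is already a batch, wrap a single input into a one-element list, process every element, then unwrap again so single-input calls keep their old return type. A sketch of the resulting shapes, assuming llm is loaded as in the docstring examples (the token ids shown are illustrative, not real output):

    single = llm.tokenize("Q: Tell me something about Intel. A:")
    # single is one flat list of ints, e.g. [1, 660, 29901, ...]

    batch = llm.tokenize(["Q: Tell me something about Intel. A:",
                          "Q: What is deep learning? A:"])
    # batch is a list of such lists, one per input string

    # batch_decode mirrors the dispatch: a flat id list yields one string,
    # a list of id lists yields a list of strings.
    text = llm.batch_decode(single)
    texts = llm.batch_decode(batch)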
@@ -71,7 +116,9 @@ class GenerationMixin:
         mirostat_eta: float = 0.1,
         stop: Optional[Union[str, List[str]]]=[], # TODO: rebase to support stopping_criteria
         **kwargs,
-    ) -> Union[Optional[Sequence[int]], None]:
+    ) -> Union[Optional[Sequence[int]],
+               Sequence[Sequence[int]],
+               None]:
         # TODO: modify docs
         """Create a generator of tokens from a prompt.
 
@@ -80,7 +127,7 @@ class GenerationMixin:
                                                             model_family="llama")
             >>> tokens = llm.tokenize("Q: Tell me something about Intel. A:")
             >>> tokens_id = llm.generate(tokens, max_new_tokens=32)
-            >>> llm.decode(tokens_id)
+            >>> llm.batch_decode(tokens_id)
 
         Args:
             tokens: The prompt tokens.
@@ -93,7 +140,15 @@ class GenerationMixin:
         Yields:
             The generated tokens.
         """
-        tokens = self._generate(tokens=inputs,
-                                top_k=top_k,
-                                top_p=top_p,
-                                temp=temperature,
+        if inputs and len(inputs) > 0:
+            if not isinstance(inputs[0], Sequence):
+                inputs = [inputs]
+        else:
+            return None
+
+        results = []
+        for input in inputs:
+            tokens = self._generate(tokens=input,
+                                    top_k=top_k,
+                                    top_p=top_p,
+                                    temp=temperature,
@@ -113,4 +168,6 @@ class GenerationMixin:
                     break
                 res_list.append(token)
                 word_count += 1
-            return res_list
+            results.append(res_list)
+
+        return results
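With the last two hunks applied, generate runs _generate once per prompt and appends each res_list to results, so the return value is always a batch: a list of token-id sequences, one per input (or None when inputs is empty). Consuming that shape looks like this, as a sketch under the same assumptions as the example above:

    tokens = llm.tokenize("Q: Tell me something about Intel. A:")
    tokens_id = llm.generate(tokens, max_new_tokens=32)   # a list holding one id sequence
    print(llm.decode(tokens_id[0]))                       # decode that single sequence
    # For a real batch of prompts, llm.batch_decode(tokens_id) returns one string per prompt.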