diff --git a/python/llm/src/bigdl/llm/ggml/model/generation/utils.py b/python/llm/src/bigdl/llm/ggml/model/generation/utils.py
index 18ffa271..e033c7f9 100644
--- a/python/llm/src/bigdl/llm/ggml/model/generation/utils.py
+++ b/python/llm/src/bigdl/llm/ggml/model/generation/utils.py
@@ -30,33 +30,78 @@ class GenerationMixin:
     Pass custom parameter values to 'generate' .
     """
 
-    def tokenize(self, text: str, add_bos: bool = True) -> List[int]:
+    def tokenize(self,
+                 text: Union[str, List[str]],
+                 add_bos: bool = True) -> List[int]:
         '''
         Decode the id to words
 
-        :param text: The text to be tokenized
+        :param text: The text or batch of text to be tokenized
         :param add_bos:
         :return: list of ids that indicates the tokens
         '''
-        if isinstance(text, str):
-            bstr = text.encode()
-        else:
-            bstr = text
-        return self._tokenize(bstr, add_bos)
+        is_batched = True if isinstance(text, (list, tuple)) else False
+        if not is_batched:
+            text = [text]
+        result = []
+        for t in text:
+            if isinstance(t, str):
+                bstr = t.encode()
+            else:
+                bstr = t
+            result.append(self._tokenize(bstr, add_bos))
+
+        if not is_batched:
+            result = result[0]
+        return result
 
     def decode(self, tokens: List[int]) -> str:
         '''
         Decode the id to words
 
+        Examples:
+            >>> llm = AutoModelForCausalLM.from_pretrained("gpt4all-model-q4_0.bin",
+                                                           model_family="llama")
+            >>> tokens = llm.tokenize("Q: Tell me something about Intel. A:")
+            >>> tokens_id = llm.generate(tokens, max_new_tokens=32)
+            >>> llm.decode(tokens_id[0])
+
         :param tokens: list of ids that indicates the tokens, mostly generated by generate
 
         :return: decoded string
         '''
         return self.detokenize(tokens).decode()
 
+    def batch_decode(self,
+                     tokens: Union[List[int], List[List[int]]]) -> str:
+        '''
+        Decode the id to words
+
+        :param tokens: list or a batch of list of ids that indicates the tokens,
+               mostly generated by generate
+
+        :return: decoded string
+        '''
+        is_batched = False
+        if tokens is not None and len(tokens) > 0:
+            if isinstance(tokens[0], Sequence):
+                is_batched = True
+            else:
+                tokens = [tokens]
+        else:
+            return None
+
+        results = []
+        for t in tokens:
+            results.append(self.decode(t))
+        if not is_batched:
+            results = results[0]
+        return results
+
     def generate(
         self,
-        inputs: Optional[Sequence[int]]=None,
+        inputs: Union[Optional[Sequence[int]],
+                      Sequence[Sequence[int]]]=None,
         max_new_tokens: int = 128,
         top_k: int = 40,
         top_p: float = 0.95,
@@ -71,7 +116,9 @@ class GenerationMixin:
         mirostat_eta: float = 0.1,
         stop: Optional[Union[str, List[str]]]=[],  # TODO: rebase to support stopping_criteria
         **kwargs,
-    ) -> Union[Optional[Sequence[int]], None]:
+    ) -> Union[Optional[Sequence[int]],
+               Sequence[Sequence[int]],
+               None]:
         # TODO: modify docs
         """Create a generator of tokens from a prompt.
 
@@ -80,7 +127,7 @@ class GenerationMixin:
                                                         model_family="llama")
             >>> tokens = llm.tokenize("Q: Tell me something about Intel. A:")
             >>> tokens_id = llm.generate(tokens, max_new_tokens=32)
-            >>> llm.decode(tokens_id)
+            >>> llm.batch_decode(tokens_id)
 
         Args:
             tokens: The prompt tokens.
@@ -93,24 +140,34 @@ class GenerationMixin:
         Yields:
             The generated tokens.
""" - tokens = self._generate(tokens=inputs, - top_k=top_k, - top_p=top_p, - temp=temperature, - repeat_penalty=repetition_penalty, - reset=reset, - frequency_penalty=frequency_penalty, - presence_penalty=presence_penalty, - tfs_z=tfs_z, - mirostat_mode=mirostat_mode, - mirostat_tau=mirostat_tau, - mirostat_eta=mirostat_eta, - **kwargs) - res_list = [] - word_count = 0 - for token in tokens: - if word_count > max_new_tokens: - break - res_list.append(token) - word_count += 1 - return res_list + if inputs and len(inputs) > 0: + if not isinstance(inputs[0], Sequence): + inputs = [inputs] + else: + return None + + results = [] + for input in inputs: + tokens = self._generate(tokens=input, + top_k=top_k, + top_p=top_p, + temp=temperature, + repeat_penalty=repetition_penalty, + reset=reset, + frequency_penalty=frequency_penalty, + presence_penalty=presence_penalty, + tfs_z=tfs_z, + mirostat_mode=mirostat_mode, + mirostat_tau=mirostat_tau, + mirostat_eta=mirostat_eta, + **kwargs) + res_list = [] + word_count = 0 + for token in tokens: + if word_count > max_new_tokens: + break + res_list.append(token) + word_count += 1 + results.append(res_list) + + return results