From fb32fefcbeb20c95229b9291d1f1152a74483aad Mon Sep 17 00:00:00 2001
From: binbin Deng <108676127+plusbang@users.noreply.github.com>
Date: Thu, 27 Jul 2023 17:59:49 +0800
Subject: [PATCH] LLM: support tensor input of native int4 `generate` (#8620)

---
 .../src/bigdl/llm/ggml/model/generation/utils.py   | 14 +++++++++-----
 1 file changed, 9 insertions(+), 5 deletions(-)

diff --git a/python/llm/src/bigdl/llm/ggml/model/generation/utils.py b/python/llm/src/bigdl/llm/ggml/model/generation/utils.py
index 75749443..a033d2e9 100644
--- a/python/llm/src/bigdl/llm/ggml/model/generation/utils.py
+++ b/python/llm/src/bigdl/llm/ggml/model/generation/utils.py
@@ -22,6 +22,7 @@
 
 from typing import Optional, Union, Sequence, List
 from bigdl.llm.utils.common import invalidInputError
+import torch
 
 
 class GenerationMixin:
@@ -100,8 +101,9 @@ class GenerationMixin:
 
     def generate(
         self,
-        inputs: Union[Optional[Sequence[int]],
-                      Sequence[Sequence[int]]]=None,
+        inputs: Optional[Union[Sequence[int],
+                               Sequence[Sequence[int]],
+                               torch.Tensor]]=None,
         max_new_tokens: int = 128,
         top_k: int = 40,
         top_p: float = 0.95,
@@ -116,9 +118,9 @@ class GenerationMixin:
         mirostat_eta: float = 0.1,
         stop: Optional[Union[str, List[str]]]=[],  # TODO: rebase to support stopping_criteria
         **kwargs,
-    ) -> Union[Optional[Sequence[int]],
-               Sequence[Sequence[int]],
-               None]:
+    ) -> Optional[Union[Sequence[int],
+                        Sequence[Sequence[int]],
+                        None]]:
         # TODO: modify docs
         """Create a generator of tokens from a prompt.
 
@@ -140,6 +142,8 @@ class GenerationMixin:
         Yields:
             The generated tokens.
         """
+        if isinstance(inputs, torch.Tensor):
+            inputs = inputs.tolist()
         if inputs and len(inputs) > 0:
             if not isinstance(inputs[0], Sequence):
                 inputs = [inputs]
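
Usage note (not part of the patch): after this change, `generate` accepts a 2-D `torch.Tensor` of token ids, such as the `input_ids` a Hugging Face tokenizer returns with `return_tensors="pt"`, in addition to plain Python sequences. Before the patch, such a tensor would fail the `Sequence` checks. Below is a minimal sketch of the new call path; the `Llama` import path, the checkpoint paths, and the assumption that the int4 model shares the tokenizer's vocabulary are illustrative, not taken from this patch.

    import torch
    from transformers import LlamaTokenizer      # assumed tokenizer; any HF tokenizer producing token ids works
    from bigdl.llm.models import Llama           # assumed native int4 wrapper that mixes in GenerationMixin

    tokenizer = LlamaTokenizer.from_pretrained("path/to/hf-llama")    # hypothetical HF checkpoint
    model = Llama("path/to/bigdl_llm_llama_q4_0.bin")                 # hypothetical converted int4 model file

    # return_tensors="pt" yields a torch.Tensor of shape (1, seq_len).
    # The new isinstance check converts it with .tolist() before the
    # existing sequence-handling logic, so no manual conversion is needed.
    input_ids = tokenizer("Once upon a time", return_tensors="pt").input_ids

    # Per the updated return annotation, generate returns token-id sequences.
    output = model.generate(input_ids, max_new_tokens=32)
    print(tokenizer.decode(output[0]))           # assumes the int4 model uses the same vocabulary

Design note: converting the tensor to nested lists at the top of `generate` keeps the rest of the method unchanged; the subsequent `isinstance(inputs[0], Sequence)` branch then wraps a 1-D input into a batch exactly as it does for plain lists.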