diff --git a/python/llm/src/bigdl/llm/ggml/model/generation/utils.py b/python/llm/src/bigdl/llm/ggml/model/generation/utils.py
index 75749443..a033d2e9 100644
--- a/python/llm/src/bigdl/llm/ggml/model/generation/utils.py
+++ b/python/llm/src/bigdl/llm/ggml/model/generation/utils.py
@@ -22,6 +22,7 @@
 
 from typing import Optional, Union, Sequence, List
 from bigdl.llm.utils.common import invalidInputError
+import torch
 
 
 class GenerationMixin:
@@ -100,8 +101,9 @@ class GenerationMixin:
 
     def generate(
         self,
-        inputs: Union[Optional[Sequence[int]],
-                      Sequence[Sequence[int]]]=None,
+        inputs: Optional[Union[Sequence[int],
+                               Sequence[Sequence[int]],
+                               torch.Tensor]]=None,
         max_new_tokens: int = 128,
         top_k: int = 40,
         top_p: float = 0.95,
@@ -116,9 +118,9 @@
         mirostat_eta: float = 0.1,
         stop: Optional[Union[str, List[str]]]=[],  # TODO: rebase to support stopping_criteria
         **kwargs,
-    ) -> Union[Optional[Sequence[int]],
-               Sequence[Sequence[int]],
-               None]:
+    ) -> Optional[Union[Sequence[int],
+                        Sequence[Sequence[int]],
+                        None]]:
         # TODO: modify docs
         """Create a generator of tokens from a prompt.
 
@@ -140,6 +142,8 @@
         Yields:
             The generated tokens.
         """
+        if isinstance(inputs, torch.Tensor):
+            inputs = inputs.tolist()
         if inputs and len(inputs) > 0:
             if not isinstance(inputs[0], Sequence):
                 inputs = [inputs]