LLM: support tensor input of native int4 generate (#8620)

This commit is contained in:
binbin Deng 2023-07-27 17:59:49 +08:00 committed by GitHub
parent 5b484ab48d
commit fb32fefcbe

View file

@@ -22,6 +22,7 @@
from typing import Optional, Union, Sequence, List
from bigdl.llm.utils.common import invalidInputError
import torch
class GenerationMixin:
@@ -100,8 +101,9 @@ class GenerationMixin:
def generate(
self,
inputs: Union[Optional[Sequence[int]],
Sequence[Sequence[int]]]=None,
inputs: Optional[Union[Sequence[int],
Sequence[Sequence[int]],
torch.Tensor]]=None,
max_new_tokens: int = 128,
top_k: int = 40,
top_p: float = 0.95,
@@ -116,9 +118,9 @@ class GenerationMixin:
mirostat_eta: float = 0.1,
stop: Optional[Union[str, List[str]]]=[], # TODO: rebase to support stopping_criteria
**kwargs,
) -> Union[Optional[Sequence[int]],
Sequence[Sequence[int]],
None]:
) -> Optional[Union[Sequence[int],
Sequence[Sequence[int]],
None]]:
# TODO: modify docs
"""Create a generator of tokens from a prompt.
@@ -140,6 +142,8 @@ class GenerationMixin:
Yields:
The generated tokens.
"""
if isinstance(inputs, torch.Tensor):
inputs = inputs.tolist()
if inputs and len(inputs) > 0:
if not isinstance(inputs[0], Sequence):
inputs = [inputs]