LLM: support tensor input of native int4 generate (#8620)
This commit is contained in:
parent
5b484ab48d
commit
fb32fefcbe
1 changed file with 9 additions and 5 deletions
|
|
@ -22,6 +22,7 @@
|
||||||
|
|
||||||
from typing import Optional, Union, Sequence, List
|
from typing import Optional, Union, Sequence, List
|
||||||
from bigdl.llm.utils.common import invalidInputError
|
from bigdl.llm.utils.common import invalidInputError
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
class GenerationMixin:
|
class GenerationMixin:
|
||||||
|
|
@ -100,8 +101,9 @@ class GenerationMixin:
|
||||||
|
|
||||||
def generate(
|
def generate(
|
||||||
self,
|
self,
|
||||||
inputs: Union[Optional[Sequence[int]],
|
inputs: Optional[Union[Sequence[int],
|
||||||
Sequence[Sequence[int]]]=None,
|
Sequence[Sequence[int]],
|
||||||
|
torch.Tensor]]=None,
|
||||||
max_new_tokens: int = 128,
|
max_new_tokens: int = 128,
|
||||||
top_k: int = 40,
|
top_k: int = 40,
|
||||||
top_p: float = 0.95,
|
top_p: float = 0.95,
|
||||||
|
|
@ -116,9 +118,9 @@ class GenerationMixin:
|
||||||
mirostat_eta: float = 0.1,
|
mirostat_eta: float = 0.1,
|
||||||
stop: Optional[Union[str, List[str]]]=[], # TODO: rebase to support stopping_criteria
|
stop: Optional[Union[str, List[str]]]=[], # TODO: rebase to support stopping_criteria
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Union[Optional[Sequence[int]],
|
) -> Optional[Union[Sequence[int],
|
||||||
Sequence[Sequence[int]],
|
Sequence[Sequence[int]],
|
||||||
None]:
|
None]]:
|
||||||
# TODO: modify docs
|
# TODO: modify docs
|
||||||
"""Create a generator of tokens from a prompt.
|
"""Create a generator of tokens from a prompt.
|
||||||
|
|
||||||
|
|
@ -140,6 +142,8 @@ class GenerationMixin:
|
||||||
Yields:
|
Yields:
|
||||||
The generated tokens.
|
The generated tokens.
|
||||||
"""
|
"""
|
||||||
|
if isinstance(inputs, torch.Tensor):
|
||||||
|
inputs = inputs.tolist()
|
||||||
if inputs and len(inputs) > 0:
|
if inputs and len(inputs) > 0:
|
||||||
if not isinstance(inputs[0], Sequence):
|
if not isinstance(inputs[0], Sequence):
|
||||||
inputs = [inputs]
|
inputs = [inputs]
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue