LLM: first add _tokenize, detokenize and _generate for bloom pybinding (#8316)
parent 5576679a92
commit f64e703083
4 changed files with 300 additions and 20 deletions
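For orientation, here is a minimal usage sketch of the pybinding methods this commit adds, adapted from the Examples section of the `_generate` docstring further down; the import path and model file path are assumptions, not part of the diff:

from bigdl.llm.ggml.model.bloom import Bloom  # import path assumed from the package layout shown below

# Hypothetical path to a quantized bloom.cpp model file
llm = Bloom("/path/to/bloomz-7b1-q4_0.bin")

# Tokenize a UTF-8 encoded prompt into token ids
tokens = llm._tokenize(b"Learning English is")

# Stream tokens from the generator and decode them one at a time
for i, token in enumerate(llm._generate(tokens)):
    print(llm.detokenize([token]).decode("utf-8", errors="ignore"), end="")
    if i >= 32:  # _generate yields tokens indefinitely, so stop after a fixed budget
        break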
@@ -46,35 +46,57 @@
# only search the first bigdl package and end up finding only one sub-package.
from .bloom_cpp import bloom_load, bloom_free, bloom_run
from .bloom_cpp import bloom_tokenize, bloom_detokenize, bloom_forward, bloom_eval
from bigdl.llm.utils.common import invalidInputError
from typing import List, Optional
from bigdl.llm.ggml.model.generation import GenerationMixin
from typing import List, Optional, Generator, Sequence, Union
import time
import uuid


class Bloom:
class Bloom(GenerationMixin):
    """High-level Python wrapper for a bloom.cpp model."""

    def __init__(self,
                 model_path: str,
                 n_ctx: int = 512,
                 seed: int = 1337,
                 logits_all: bool = False,
                 n_threads: int = 2,
                 n_batch: int = 8,
                 last_n_tokens_size: int = 64,
                 verbose: bool = True,
                 ):
    def __init__(
        self,
        model_path: str,
        n_ctx: int = 512,
        n_parts: int = -1,
        n_gpu_layers: int = 0,
        seed: int = -1,
        f16_kv: bool = True,
        logits_all: bool = False,
        vocab_only: bool = False,
        use_mmap: bool = True,
        use_mlock: bool = False,
        embedding: bool = False,
        n_threads: Optional[int] = 2,
        n_batch: int = 512,
        last_n_tokens_size: int = 64,
        lora_base: Optional[str] = None,
        lora_path: Optional[str] = None,
        verbose: bool = True,
    ):
        """Load a bloom.cpp model from `model_path`.

        Args:
            model_path: Path to the model.
            n_ctx: Maximum context size.
            seed: Random seed. 0 for random.
            n_parts: Number of parts to split the model into. If -1, the number of parts
                is automatically determined.
            seed: Random seed. For default value -1, current timestamp is used as seed.
            f16_kv: Use half-precision for key/value cache.
            logits_all: Return logits for all tokens, not just the last token.
            vocab_only: Only load the vocabulary, no weights.
            use_mmap: Use mmap if possible.
            use_mlock: Force the system to keep the model in RAM.
            embedding: Embedding mode only.
            n_threads: Number of threads to use. Default to be 2.
            n_batch: Maximum number of prompt tokens to batch together when calling llama_eval.
            n_batch: Maximum number of prompt tokens to batch together when calling bloom_eval.
            last_n_tokens_size: Maximum number of tokens to keep in the last_n_tokens deque.
            lora_base: Optional path to base model, useful if using a quantized base model and
                you want to apply LoRA to an f16 model.
            lora_path: Path to a LoRA file to apply to the model.
            verbose: Print verbose output to stderr.

        Raises:
@@ -87,15 +109,73 @@ class Bloom:
        self.ctx = bloom_load(bytes(model_path, encoding='utf-8'), n_ctx, n_threads)
        invalidInputError(self.ctx is not None, f"Failed to load model from {model_path}")
        self.n_ctx = n_ctx
        self.n_parts = n_parts
        self.n_gpu_layers = n_gpu_layers
        self.f16_kv = f16_kv
        self.seed = seed
        self.logits_all = logits_all
        self.vocab_only = vocab_only
        self.use_mmap = use_mmap
        self.use_mlock = use_mlock
        self.embedding = embedding
        self.n_threads = n_threads
        self.n_batch = n_batch
        self.last_n_tokens_size = last_n_tokens_size
        self.lora_base = lora_base
        self.lora_path = lora_path
        self.verbose = verbose
        # TODO: Some parameters are temporarily not supported
        unsupported_arg = {'n_parts': -1, 'n_gpu_layers': 0, 'f16_kv': True, 'logits_all': False,
                           'vocab_only': False, 'use_mmap': True, 'use_mlock': False,
                           'embedding': False, 'last_n_tokens_size': 64, 'lora_base': None,
                           'lora_path': None, 'verbose': True}
        for arg in unsupported_arg.keys():
            invalidInputError(getattr(self, arg) == unsupported_arg[arg], f"The parameter {arg}"
                              " is temporarily unsupported, please use the default value.")

    def __call__(
        self,
        prompt: str,
        suffix: Optional[str] = None,
        max_tokens: int = 128,
        temperature: float = 0.8,
        top_p: float = 0.95,
        logprobs: Optional[int] = None,
        echo: bool = False,
        stop: Optional[Union[str, List[str]]] = [],
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        repeat_penalty: float = 1.1,
        top_k: int = 40,
        stream: bool = False,
        tfs_z: float = 1.0,
        mirostat_mode: int = 0,
        mirostat_tau: float = 5.0,
        mirostat_eta: float = 0.1,
        model: Optional[str] = None,
    ):
        # TODO: Some parameters are temporarily not supported
        # Unsupported parameters are checked in `_supported_call`
        return self._supported_call(prompt, max_tokens, stream, stop,
                                    suffix, temperature, top_p, logprobs, echo, frequency_penalty,
                                    presence_penalty, repeat_penalty, top_k, tfs_z, mirostat_mode,
                                    mirostat_tau, mirostat_eta, model)

    def _supported_call(self, prompt: str, max_tokens: int, stream: bool = False,
                        stop: Optional[List[str]] = [], *args):
        # Check unsupported parameters
        unsupported_arg = ['suffix', 'temperature', 'top_p', 'logprobs', 'echo',
                           'frequency_penalty', 'presence_penalty', 'repeat_penalty', 'top_k',
                           'tfs_z', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'model']
        default_value = {'suffix': None, 'temperature': 0.8, 'top_p': 0.95, 'logprobs': None,
                         'echo': False, 'frequency_penalty': 0.0, 'presence_penalty': 0.0,
                         'repeat_penalty': 1.1, 'top_k': 40, 'tfs_z': 1.0, 'mirostat_mode': 0,
                         'mirostat_tau': 5.0, 'mirostat_eta': 0.1, 'model': None}
        for index in range(len(args)):
            invalidInputError(args[index] == default_value[unsupported_arg[index]],
                              f"The parameter {unsupported_arg[index]} is temporarily "
                              "unsupported, please use the default value.")

    def __call__(self, prompt: str, max_tokens: int = 128, stream: bool = False,
                 stop: Optional[List[str]] = []):
        if stream:
            return self.stream(prompt, max_tokens, stop)
        else:
@@ -221,3 +301,113 @@ class Bloom:

    def free(self):
        bloom_free(self.ctx)

    def _tokenize(self, text: bytes, add_bos: bool = False) -> List[int]:
        """Tokenize a string.

        Args:
            text: The utf-8 encoded string to tokenize.

        Raises:
            RuntimeError: If the tokenization failed.

        Returns:
            A list of tokens.
        """
        invalidInputError(self.ctx is not None, "The attribute `ctx` of `Bloom` object is None.")
        return bloom_tokenize(self.ctx, text, False)

    def detokenize(self, tokens: List[int]) -> bytes:
        """Detokenize a list of tokens.

        Args:
            tokens: The list of tokens to detokenize.

        Returns:
            The detokenized string.
        """
        invalidInputError(self.ctx is not None, "The attribute `ctx` of `Bloom` object is None.")
        output = ""
        for token in tokens:
            output += bloom_detokenize(self.ctx, token)
        return output.encode('utf-8')

    def forward(self, input_ids: List[int]) -> int:
        return bloom_forward(ctx=self.ctx,
                             input_ids=input_ids,
                             seed=self.seed,
                             n_threads=self.n_threads,
                             n_batch=self.n_batch)

    def eval(self, input_ids: List[int]) -> List[List[float]]:
        """Only used for testing accuracy"""
        return bloom_eval(ctx=self.ctx,
                          input_ids=input_ids,
                          seed=self.seed,
                          n_threads=self.n_threads,
                          n_batch=len(input_ids))

    def _generate(
        self,
        tokens: Sequence[int],
        top_k: int = 40,
        top_p: float = 0.95,
        temp: float = 0.80,
        repeat_penalty: float = 1.1,
        reset: bool = True,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        tfs_z: float = 1.0,
        mirostat_mode: int = 0,
        mirostat_tau: float = 5.0,
        mirostat_eta: float = 0.1,
    ) -> Generator[int, Optional[Sequence[int]], None]:
        """Create a generator of tokens from a prompt.

        Examples:
            >>> llm = Bloom(your_model_path)
            >>> tokens = llm._tokenize(b"Learning English is")
            >>> for token in llm._generate(tokens):
            >>>     print(llm.detokenize([token]).decode("utf-8", errors="ignore"))

        Args:
            tokens: The prompt tokens.

        Yields:
            The generated tokens.
        """
        # TODO: Some parameters are temporarily not supported
        # Unsupported parameters are checked in `_supported_generate`
        return self._supported_generate(tokens, top_k, top_p, temp, repeat_penalty, reset,
                                        frequency_penalty, presence_penalty, tfs_z, mirostat_mode,
                                        mirostat_tau, mirostat_eta)

    def _supported_generate(self, tokens: Sequence[int], *args):
        # Check unsupported parameters
        unsupported_arg = ['top_k', 'top_p', 'temp', 'repeat_penalty', 'reset',
                           'frequency_penalty', 'presence_penalty', 'tfs_z', 'mirostat_mode',
                           'mirostat_tau', 'mirostat_eta']
        default_value = {'top_k': 40, 'top_p': 0.95, 'temp': 0.80, 'repeat_penalty': 1.1,
                         'reset': True, 'frequency_penalty': 0.0, 'presence_penalty': 0.0,
                         'tfs_z': 1.0, 'mirostat_mode': 0, 'mirostat_tau': 5.0, 'mirostat_eta': 0.1}
        for index in range(len(args)):
            invalidInputError(args[index] == default_value[unsupported_arg[index]],
                              f"The parameter {unsupported_arg[index]} is temporarily "
                              "unsupported, please use the default value.")

        invalidInputError(self.ctx is not None, "The attribute `ctx` of `Bloom` object is None.")
        while True:
            token = self.forward(tokens)
            tokens_or_none = yield token
            tokens.append(token)
            if tokens_or_none is not None:
                tokens.extend(tokens_or_none)

    def embed(self, prompt: Union[str, bytes]) -> List[float]:
        """Only used for langchain"""
        input_ids = self.tokenize(prompt)
        return bloom_embed(ctx=self.ctx,
                           input_ids=input_ids,
                           seed=self.seed,
                           n_threads=self.n_threads,
                           n_batch=len(input_ids))
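As an aside, a small sketch of how the unsupported-parameter guard in `_supported_call` above behaves, reusing the hypothetical `llm` from the first example and assuming the full-signature `__call__` is the one this commit keeps; `invalidInputError` is assumed to raise when its condition is false, as in other bigdl.llm bindings:

# All sampling arguments are left at their defaults, so the guard in _supported_call passes
# and the call proceeds to generation.
text = llm("What is AI?", max_tokens=32)

# A non-default value for a temporarily unsupported argument (here temperature) fails the
# equality check in _supported_call and is rejected by invalidInputError.
try:
    llm("What is AI?", max_tokens=32, temperature=0.2)
except Exception as err:
    print(err)  # The parameter temperature is temporarily unsupported, please use the default value.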
@@ -48,13 +48,16 @@
import sys
import os
import ctypes
from typing import List
from ctypes import (
    c_int,
    c_long,
    c_float,
    c_char_p,
    c_void_p,
    c_bool,
    POINTER,
    pointer,
    Structure,
    Array,
    c_uint8,
@@ -116,6 +119,14 @@ _lib_base_name = "bloom"
_lib = _load_shared_library(_lib_base_name)


def c_free(p: c_void_p):
    _lib.c_free(p)


_lib.c_free.argtypes = [c_void_p]
_lib.c_free.restype = None


def bloom_load(fname: bytes, n_ctx: c_int, n_threads: c_int) -> c_void_p:
    return _lib.bloom_load(fname, n_ctx, n_threads)
@@ -146,4 +157,83 @@ def bloom_run(ctx: c_void_p,
_lib.bloom_run.argtypes = [c_void_p, c_int, c_int, c_int, c_int, c_bool, c_char_p, c_char_p]
_lib.bloom_run.restype = c_int


def bloom_tokenize(ctx: c_void_p,
                   prompt: bytes,
                   bos: bool = False) -> List[int]:
    n_tokens = c_int(0)
    c_tokens = _lib.tokenize_api(ctx, prompt, bos, pointer(n_tokens))
    tokens = [c_tokens[i] for i in range(0, n_tokens.value)]
    c_free(c_tokens)
    return tokens


_lib.tokenize_api.argtypes = [c_void_p, c_char_p, c_bool, c_void_p]
_lib.tokenize_api.restype = POINTER(c_int)


def bloom_detokenize(ctx: c_void_p,
                     token_id: c_int) -> str:
    c_chars = _lib.detokenize_api(ctx, token_id)
    s = c_chars.decode('utf-8')
    return s


_lib.detokenize_api.argtypes = [c_void_p, c_int]
_lib.detokenize_api.restype = c_char_p


def bloom_eval(ctx: c_void_p,
               input_ids: List[int],
               seed: c_int,
               n_threads: c_int,
               n_batch: c_int) -> List[List[float]]:
    length = len(input_ids)
    c_input_ids = (c_int * length)(*input_ids)
    n_logits = c_long(0)
    c_logits = _lib.eval_api(ctx, c_input_ids, length, seed, n_threads, n_batch, pointer(n_logits))
    n_vocab = n_logits.value // length
    assert(n_vocab * length == n_logits.value)
    logits = [[c_logits[i * n_vocab + j] for j in range(n_vocab)] for i in range(length)]
    # do not free c_logits
    return logits


_lib.eval_api.argtypes = [c_void_p, c_void_p, c_int, c_int, c_int, c_int, c_void_p]
_lib.eval_api.restype = POINTER(c_float)


def bloom_embed(ctx: c_void_p,
                input_ids: List[int],
                seed: c_int,
                n_threads: c_int,
                n_batch: c_int) -> List[float]:
    length = len(input_ids)
    c_input_ids = (c_int * length)(*input_ids)
    n_embd = c_long(0)
    c_embeddings = _lib.embed_api(ctx, c_input_ids, length, seed, n_threads,
                                  n_batch, pointer(n_embd))
    embeddings = [c_embeddings[i] for i in range(n_embd.value)]
    # do not free c_embeddings
    return embeddings


_lib.embed_api.argtypes = [c_void_p, c_void_p, c_int, c_int, c_int, c_int, c_void_p]
_lib.embed_api.restype = POINTER(c_float)


def bloom_forward(ctx: c_void_p,
                  input_ids: List[int],
                  seed: c_int,
                  n_threads: c_int,
                  n_batch: c_int) -> int:
    length = len(input_ids)
    c_input_ids = (c_int * length)(*input_ids)
    token_id = _lib.forward_api(ctx, c_input_ids, length, seed, n_threads, n_batch)
    return token_id


_lib.forward_api.argtypes = [c_void_p, c_void_p, c_int, c_int, c_int, c_int]
_lib.forward_api.restype = c_int

# ------------------------------------------------------------------- #
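For completeness, a rough sketch of driving these low-level bindings directly, mirroring what the `Bloom` wrapper above does internally; the module path, model path, and parameter values are placeholders rather than part of the diff:

# Module path assumed from the `from .bloom_cpp import ...` line in bloom.py above.
from bigdl.llm.ggml.model.bloom.bloom_cpp import (
    bloom_load, bloom_free, bloom_tokenize, bloom_detokenize, bloom_forward)

ctx = bloom_load(b"/path/to/bloomz-7b1-q4_0.bin", 512, 2)  # n_ctx=512, n_threads=2 as in the wrapper defaults

input_ids = bloom_tokenize(ctx, b"Learning English is", False)
for _ in range(16):
    # forward_api returns one new token id per call given the running context
    next_id = bloom_forward(ctx, input_ids, -1, 2, 512)  # seed=-1, n_threads=2, n_batch=512
    input_ids.append(next_id)
    print(bloom_detokenize(ctx, next_id), end="")

bloom_free(ctx)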
@@ -132,7 +132,7 @@ class Gptneox(GenerationMixin):
        n_ctx: int = 512,
        n_parts: int = -1,
        n_gpu_layers: int = 0,
        seed: int = 1337,
        seed: int = -1,
        f16_kv: bool = True,
        logits_all: bool = False,
        vocab_only: bool = False,
@@ -153,7 +153,7 @@ class Gptneox(GenerationMixin):
            n_ctx: Maximum context size.
            n_parts: Number of parts to split the model into. If -1,
                the number of parts is automatically determined.
            seed: Random seed. 0 for random.
            seed: Random seed. For default value -1, current timestamp is used as seed.
            f16_kv: Use half-precision for key/value cache.
            logits_all: Return logits for all tokens, not just the last token.
            vocab_only: Only load the vocabulary, no weights.
@@ -130,7 +130,7 @@ class Llama(GenerationMixin):
        n_ctx: int = 512,
        n_parts: int = -1,
        n_gpu_layers: int = 0,
        seed: int = 1337,
        seed: int = -1,
        f16_kv: bool = True,
        logits_all: bool = False,
        vocab_only: bool = False,
@@ -151,7 +151,7 @@ class Llama(GenerationMixin):
            n_ctx: Maximum context size.
            n_parts: Number of parts to split the model into. If -1, the number of parts
                is automatically determined.
            seed: Random seed. 0 for random.
            seed: Random seed. For default value -1, current timestamp is used as seed.
            f16_kv: Use half-precision for key/value cache.
            logits_all: Return logits for all tokens, not just the last token.
            vocab_only: Only load the vocabulary, no weights.