LLM: first add _tokenize, detokenize and _generate for bloom pybinding (#8316)

binbin Deng 2023-06-14 17:29:57 +08:00 committed by GitHub
parent 5576679a92
commit f64e703083
4 changed files with 300 additions and 20 deletions

View file

@@ -46,35 +46,57 @@
 # only search the first bigdl package and end up finding only one sub-package.
 from .bloom_cpp import bloom_load, bloom_free, bloom_run
+from .bloom_cpp import bloom_tokenize, bloom_detokenize, bloom_forward, bloom_eval
 from bigdl.llm.utils.common import invalidInputError
-from typing import List, Optional
+from bigdl.llm.ggml.model.generation import GenerationMixin
+from typing import List, Optional, Generator, Sequence, Union
 import time
 import uuid


-class Bloom:
+class Bloom(GenerationMixin):
     """High-level Python wrapper for a bloom.cpp model."""

-    def __init__(self,
-                 model_path: str,
-                 n_ctx: int = 512,
-                 seed: int = 1337,
-                 logits_all: bool = False,
-                 n_threads: int = 2,
-                 n_batch: int = 8,
-                 last_n_tokens_size: int = 64,
-                 verbose: bool = True,
-                 ):
+    def __init__(
+        self,
+        model_path: str,
+        n_ctx: int = 512,
+        n_parts: int = -1,
+        n_gpu_layers: int = 0,
+        seed: int = -1,
+        f16_kv: bool = True,
+        logits_all: bool = False,
+        vocab_only: bool = False,
+        use_mmap: bool = True,
+        use_mlock: bool = False,
+        embedding: bool = False,
+        n_threads: Optional[int] = 2,
+        n_batch: int = 512,
+        last_n_tokens_size: int = 64,
+        lora_base: Optional[str] = None,
+        lora_path: Optional[str] = None,
+        verbose: bool = True,
+    ):
         """Load a bloom.cpp model from `model_path`.

         Args:
             model_path: Path to the model.
             n_ctx: Maximum context size.
-            seed: Random seed. 0 for random.
+            n_parts: Number of parts to split the model into. If -1, the number of parts
+                is automatically determined.
+            seed: Random seed. For default value -1, current timestamp is used as seed.
+            f16_kv: Use half-precision for key/value cache.
             logits_all: Return logits for all tokens, not just the last token.
+            vocab_only: Only load the vocabulary no weights.
+            use_mmap: Use mmap if possible.
+            use_mlock: Force the system to keep the model in RAM.
+            embedding: Embedding mode only.
             n_threads: Number of threads to use. Default to be 2.
-            n_batch: Maximum number of prompt tokens to batch together when calling llama_eval.
+            n_batch: Maximum number of prompt tokens to batch together when calling bloom_eval.
             last_n_tokens_size: Maximum number of tokens to keep in the last_n_tokens deque.
+            lora_base: Optional path to base model, useful if using a quantized base model and
+                you want to apply LoRA to an f16 model.
+            lora_path: Path to a LoRA file to apply to the model.
             verbose: Print verbose output to stderr.

         Raises:
@@ -87,15 +109,73 @@ class Bloom:
         self.ctx = bloom_load(bytes(model_path, encoding='utf-8'), n_ctx, n_threads)
         invalidInputError(self.ctx is not None, f"Failed to load model from {model_path}")
         self.n_ctx = n_ctx
+        self.n_parts = n_parts
+        self.n_gpu_layers = n_gpu_layers
+        self.f16_kv = f16_kv
         self.seed = seed
         self.logits_all = logits_all
+        self.vocab_only = vocab_only
+        self.use_mmap = use_mmap
+        self.use_mlock = use_mlock
+        self.embedding = embedding
         self.n_threads = n_threads
         self.n_batch = n_batch
         self.last_n_tokens_size = last_n_tokens_size
+        self.lora_base = lora_base
+        self.lora_path = lora_path
         self.verbose = verbose
+        # TODO: Some parameters are temporarily not supported
+        unsupported_arg = {'n_parts': -1, 'n_gpu_layers': 0, 'f16_kv': True, 'logits_all': False,
+                           'vocab_only': False, 'use_mmap': True, 'use_mlock': False,
+                           'embedding': False, 'last_n_tokens_size': 64, 'lora_base': None,
+                           'lora_path': None, 'verbose': True}
+        for arg in unsupported_arg.keys():
+            invalidInputError(getattr(self, arg) == unsupported_arg[arg], f"The parameter {arg}"
+                              " is temporarily unsupported, please use the default value.")
+
+    def __call__(
+        self,
+        prompt: str,
+        suffix: Optional[str] = None,
+        max_tokens: int = 128,
+        temperature: float = 0.8,
+        top_p: float = 0.95,
+        logprobs: Optional[int] = None,
+        echo: bool = False,
+        stop: Optional[Union[str, List[str]]] = [],
+        frequency_penalty: float = 0.0,
+        presence_penalty: float = 0.0,
+        repeat_penalty: float = 1.1,
+        top_k: int = 40,
+        stream: bool = False,
+        tfs_z: float = 1.0,
+        mirostat_mode: int = 0,
+        mirostat_tau: float = 5.0,
+        mirostat_eta: float = 0.1,
+        model: Optional[str] = None,
+    ):
+        # TODO: Some parameters are temporarily not supported
+        # Unsupported parameters are checked in `_supported_call`
+        return self._supported_call(prompt, max_tokens, stream, stop,
+                                    suffix, temperature, top_p, logprobs, echo, frequency_penalty,
+                                    presence_penalty, repeat_penalty, top_k, tfs_z, mirostat_mode,
+                                    mirostat_tau, mirostat_eta, model)
+
+    def _supported_call(self, prompt: str, max_tokens: int, stream: bool = False,
+                        stop: Optional[List[str]] = [], *args):
+        # Check unsupported parameters
+        unsupported_arg = ['suffix', 'temperature', 'top_p', 'logprobs', 'echo',
+                           'frequency_penalty', 'presence_penalty', 'repeat_penalty', 'top_k',
+                           'tfs_z', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'model']
+        default_value = {'suffix': None, 'temperature': 0.8, 'top_p': 0.95, 'logprobs': None,
+                         'echo': False, 'frequency_penalty': 0.0, 'presence_penalty': 0.0,
+                         'repeat_penalty': 1.1, 'top_k': 40, 'tfs_z': 1.0, 'mirostat_mode': 0,
+                         'mirostat_tau': 5.0, 'mirostat_eta': 0.1, 'model': None}
+        for index in range(len(args)):
+            invalidInputError(args[index] == default_value[unsupported_arg[index]],
+                              f"The parameter {unsupported_arg[index]} is temporarily "
+                              "unsupported, please use the default value.")

-    def __call__(self, prompt: str, max_tokens: int = 128, stream: bool = False,
-                 stop: Optional[List[str]] = []):
         if stream:
             return self.stream(prompt, max_tokens, stop)
         else:
@@ -221,3 +301,113 @@ class Bloom:
     def free(self):
         bloom_free(self.ctx)
+
+    def _tokenize(self, text: bytes, add_bos: bool = False) -> List[int]:
+        """Tokenize a string.
+
+        Args:
+            text: The utf-8 encoded string to tokenize.
+
+        Raises:
+            RuntimeError: If the tokenization failed.
+
+        Returns:
+            A list of tokens.
+        """
+        invalidInputError(self.ctx is not None, "The attribute `ctx` of `Bloom` object is None.")
+        return bloom_tokenize(self.ctx, text, False)
+
+    def detokenize(self, tokens: List[int]) -> bytes:
+        """Detokenize a list of tokens.
+
+        Args:
+            tokens: The list of tokens to detokenize.
+
+        Returns:
+            The detokenized string.
+        """
+        invalidInputError(self.ctx is not None, "The attribute `ctx` of `Bloom` object is None.")
+        output = ""
+        for token in tokens:
+            output += bloom_detokenize(self.ctx, token)
+        return output.encode('utf-8')
+
+    def forward(self, input_ids: List[int]) -> int:
+        return bloom_forward(ctx=self.ctx,
+                             input_ids=input_ids,
+                             seed=self.seed,
+                             n_threads=self.n_threads,
+                             n_batch=self.n_batch)
+
+    def eval(self, input_ids: List[int]) -> List[List[float]]:
+        """Only used for testing accuracy"""
+        return bloom_eval(ctx=self.ctx,
+                          input_ids=input_ids,
+                          seed=self.seed,
+                          n_threads=self.n_threads,
+                          n_batch=len(input_ids))
+
+    def _generate(
+        self,
+        tokens: Sequence[int],
+        top_k: int = 40,
+        top_p: float = 0.95,
+        temp: float = 0.80,
+        repeat_penalty: float = 1.1,
+        reset: bool = True,
+        frequency_penalty: float = 0.0,
+        presence_penalty: float = 0.0,
+        tfs_z: float = 1.0,
+        mirostat_mode: int = 0,
+        mirostat_tau: float = 5.0,
+        mirostat_eta: float = 0.1,
+    ) -> Generator[int, Optional[Sequence[int]], None]:
+        """Create a generator of tokens from a prompt.
+
+        Examples:
+            >>> llm = Bloom(your_model_path)
+            >>> tokens = llm._tokenize(b"Learning English is")
+            >>> for token in llm._generate(tokens):
+            >>>     print(llm.detokenize([token]).decode("utf-8", errors="ignore"))
+
+        Args:
+            tokens: The prompt tokens.
+
+        Yields:
+            The generated tokens.
+        """
+        # TODO: Some parameters are temporarily not supported
+        # Unsupported parameters are checked in `_supported_generate`
+        return self._supported_generate(tokens, top_k, top_p, temp, repeat_penalty, reset,
+                                        frequency_penalty, presence_penalty, tfs_z, mirostat_mode,
+                                        mirostat_tau, mirostat_eta)
+
+    def _supported_generate(self, tokens: Sequence[int], *args):
+        # Check unsupported parameters
+        unsupported_arg = ['top_k', 'top_p', 'temp', 'repeat_penalty', 'reset',
+                           'frequency_penalty', 'presence_penalty', 'tfs_z', 'mirostat_mode',
+                           'mirostat_tau', 'mirostat_eta']
+        default_value = {'top_k': 40, 'top_p': 0.95, 'temp': 0.80, 'repeat_penalty': 1.1,
+                         'reset': True, 'frequency_penalty': 0.0, 'presence_penalty': 0.0,
+                         'tfs_z': 1.0, 'mirostat_mode': 0, 'mirostat_tau': 5.0, 'mirostat_eta': 0.1}
+        for index in range(len(args)):
+            invalidInputError(args[index] == default_value[unsupported_arg[index]],
+                              f"The parameter {unsupported_arg[index]} is temporarily "
+                              "unsupported, please use the default value.")
+        invalidInputError(self.ctx is not None, "The attribute `ctx` of `Bloom` object is None.")
+        while True:
+            token = self.forward(tokens)
+            tokens_or_none = yield token
+            tokens.append(token)
+            if tokens_or_none is not None:
+                tokens.extend(tokens_or_none)
+
+    def embed(self, prompt: Union[str, bytes]) -> List[float]:
+        """Only used for langchain"""
+        input_ids = self.tokenize(prompt)
+        return bloom_embed(ctx=self.ctx,
+                           input_ids=input_ids,
+                           seed=self.seed,
+                           n_threads=self.n_threads,
+                           n_batch=len(input_ids))
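
Taken together, the new methods give a minimal tokenize -> generate -> detokenize loop. The sketch below is illustrative only: the import path and model file name are assumptions rather than part of this commit, and generation uses the default (currently the only supported) sampling parameters.

from bigdl.llm.ggml.model.bloom import Bloom   # assumed import path

# Placeholder path to a bloom.cpp-compatible GGML model file.
llm = Bloom("./bloom-model-q4_0.bin", n_threads=2)

# Round-trip the prompt through the new tokenizer bindings.
tokens = llm._tokenize(b"Learning English is")
print(llm.detokenize(tokens).decode("utf-8", errors="ignore"))

# _generate is an open-ended generator of token ids; stop it explicitly.
for i, token in enumerate(llm._generate(tokens)):
    print(llm.detokenize([token]).decode("utf-8", errors="ignore"), end="", flush=True)
    if i >= 16:
        break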

View file

@@ -48,13 +48,16 @@
 import sys
 import os
 import ctypes
+from typing import List
 from ctypes import (
     c_int,
+    c_long,
     c_float,
     c_char_p,
     c_void_p,
     c_bool,
     POINTER,
+    pointer,
     Structure,
     Array,
     c_uint8,
@ -116,6 +119,14 @@ _lib_base_name = "bloom"
_lib = _load_shared_library(_lib_base_name) _lib = _load_shared_library(_lib_base_name)
def c_free(p: c_void_p):
_lib.c_free(p)
_lib.c_free.argtypes = [c_void_p]
_lib.c_free.restype = None
def bloom_load(fname: bytes, n_ctx: c_int, n_threads: c_int) -> c_void_p: def bloom_load(fname: bytes, n_ctx: c_int, n_threads: c_int) -> c_void_p:
return _lib.bloom_load(fname, n_ctx, n_threads) return _lib.bloom_load(fname, n_ctx, n_threads)
@@ -146,4 +157,83 @@ def bloom_run(ctx: c_void_p,
 _lib.bloom_run.argtypes = [c_void_p, c_int, c_int, c_int, c_int, c_bool, c_char_p, c_char_p]
 _lib.bloom_run.restype = c_int
+
+
+def bloom_tokenize(ctx: c_void_p,
+                   prompt: bytes,
+                   bos: bool = False) -> List[int]:
+    n_tokens = c_int(0)
+    c_tokens = _lib.tokenize_api(ctx, prompt, bos, pointer(n_tokens))
+    tokens = [c_tokens[i] for i in range(0, n_tokens.value)]
+    c_free(c_tokens)
+    return tokens
+
+
+_lib.tokenize_api.argtypes = [c_void_p, c_char_p, c_bool, c_void_p]
+_lib.tokenize_api.restype = POINTER(c_int)
+
+
+def bloom_detokenize(ctx: c_void_p,
+                     token_id: c_int) -> str:
+    c_chars = _lib.detokenize_api(ctx, token_id)
+    s = c_chars.decode('utf-8')
+    return s
+
+
+_lib.detokenize_api.argtypes = [c_void_p, c_int]
+_lib.detokenize_api.restype = c_char_p
+
+
+def bloom_eval(ctx: c_void_p,
+               input_ids: List[int],
+               seed: c_int,
+               n_threads: c_int,
+               n_batch: c_int) -> List[List[float]]:
+    length = len(input_ids)
+    c_input_ids = (c_int * length)(*input_ids)
+    n_logits = c_long(0)
+    c_logits = _lib.eval_api(ctx, c_input_ids, length, seed, n_threads, n_batch, pointer(n_logits))
+    n_vocab = n_logits.value // length
+    assert(n_vocab * length == n_logits.value)
+    logits = [[c_logits[i * n_vocab + j] for j in range(n_vocab)] for i in range(length)]
+    # do not free c_logits
+    return logits
+
+
+_lib.eval_api.argtypes = [c_void_p, c_void_p, c_int, c_int, c_int, c_int, c_void_p]
+_lib.eval_api.restype = POINTER(c_float)
+
+
+def bloom_embed(ctx: c_void_p,
+                input_ids: List[int],
+                seed: c_int,
+                n_threads: c_int,
+                n_batch: c_int) -> List[float]:
+    length = len(input_ids)
+    c_input_ids = (c_int * length)(*input_ids)
+    n_embd = c_long(0)
+    c_embeddings = _lib.embed_api(ctx, c_input_ids, length, seed, n_threads,
+                                  n_batch, pointer(n_embd))
+    embeddings = [c_embeddings[i] for i in range(n_embd.value)]
+    # do not free c_embeddings
+    return embeddings
+
+
+_lib.embed_api.argtypes = [c_void_p, c_void_p, c_int, c_int, c_int, c_int, c_void_p]
+_lib.embed_api.restype = POINTER(c_float)
+
+
+def bloom_forward(ctx: c_void_p,
+                  input_ids: List[int],
+                  seed: c_int,
+                  n_threads: c_int,
+                  n_batch: c_int) -> int:
+    length = len(input_ids)
+    c_input_ids = (c_int * length)(*input_ids)
+    token_id = _lib.forward_api(ctx, c_input_ids, length, seed, n_threads, n_batch)
+    return token_id
+
+
+_lib.forward_api.argtypes = [c_void_p, c_void_p, c_int, c_int, c_int, c_int]
+_lib.forward_api.restype = c_int
+
 # ------------------------------------------------------------------- #
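
For reference, a round-trip through the new low-level bindings might look like the sketch below. The import path and model file are assumptions, not part of this commit; note that bloom_tokenize copies the token ids out of the buffer allocated on the C side and then releases it with c_free, while the eval/embed buffers are intentionally not freed here.

# Assumed import path for the bindings defined above.
from bigdl.llm.ggml.model.bloom.bloom_cpp import (
    bloom_load, bloom_free, bloom_tokenize, bloom_detokenize, bloom_forward)

ctx = bloom_load(b"./bloom-model-q4_0.bin", 512, 2)   # placeholder model path
ids = bloom_tokenize(ctx, b"Hello world", False)
text = "".join(bloom_detokenize(ctx, i) for i in ids)
next_id = bloom_forward(ctx, ids, seed=-1, n_threads=2, n_batch=len(ids))
bloom_free(ctx)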

View file

@@ -132,7 +132,7 @@ class Gptneox(GenerationMixin):
         n_ctx: int = 512,
         n_parts: int = -1,
         n_gpu_layers: int = 0,
-        seed: int = 1337,
+        seed: int = -1,
         f16_kv: bool = True,
         logits_all: bool = False,
         vocab_only: bool = False,
@@ -153,7 +153,7 @@ class Gptneox(GenerationMixin):
             n_ctx: Maximum context size.
             n_parts: Number of parts to split the model into. If -1,
                 the number of parts is automatically determined.
-            seed: Random seed. 0 for random.
+            seed: Random seed. For default value -1, current timestamp is used as seed.
             f16_kv: Use half-precision for key/value cache.
             logits_all: Return logits for all tokens, not just the last token.
             vocab_only: Only load the vocabulary no weights.

View file

@@ -130,7 +130,7 @@ class Llama(GenerationMixin):
         n_ctx: int = 512,
         n_parts: int = -1,
         n_gpu_layers: int = 0,
-        seed: int = 1337,
+        seed: int = -1,
         f16_kv: bool = True,
         logits_all: bool = False,
         vocab_only: bool = False,
@@ -151,7 +151,7 @@ class Llama(GenerationMixin):
             n_ctx: Maximum context size.
             n_parts: Number of parts to split the model into. If -1, the number of parts
                 is automatically determined.
-            seed: Random seed. 0 for random.
+            seed: Random seed. For default value -1, current timestamp is used as seed.
            f16_kv: Use half-precision for key/value cache.
             logits_all: Return logits for all tokens, not just the last token.
             vocab_only: Only load the vocabulary no weights.
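
The seed documentation change in these last two files (and in the Bloom wrapper above) describes a simple convention; the helper below is only a sketch of that convention, not code from this commit.

import time

def resolve_seed(seed: int = -1) -> int:
    # -1 means "no fixed seed": fall back to the current timestamp, so repeated
    # runs sample differently unless an explicit seed is supplied.
    return seed if seed != -1 else int(time.time())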