LLM: fix and update related license in llama pybinding (#8250)

This commit is contained in:
binbin Deng 2023-06-01 17:09:15 +08:00 committed by GitHub
parent 141febec1f
commit 3a9aa23835
3 changed files with 212 additions and 96 deletions

View file

@ -14,6 +14,33 @@
# limitations under the License.
#
# ===========================================================================
#
# This file is adapted from
# https://github.com/abetlen/llama-cpp-python/blob/main/llama_cpp/llama.py
#
# MIT License
#
# Copyright (c) 2023 Andrei Betlen
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# This would makes sure Python is aware there is more than one sub-package within bigdl,
# physically located elsewhere.
# Otherwise there would be module not found error in non-pip's setting as Python would
@ -27,7 +54,7 @@ import math
import multiprocessing
from typing import List, Optional, Union, Generator, Sequence, Iterator, Deque, Tuple
from collections import deque, OrderedDict
from bigdl.llm.utils.common import invalidInputError
from . import llama_cpp
from .llama_types import *
@ -61,8 +88,7 @@ class LlamaCache:
def __getitem__(self, key: Sequence[int]) -> "LlamaState":
key = tuple(key)
_key = self._find_longest_prefix_key(key)
if _key is None:
raise KeyError(f"Key not found")
invalidInputError(_key is not None, "Key not found.")
value = self.cache_state[_key]
self.cache_state.move_to_end(_key)
return value
@ -122,7 +148,8 @@ class Llama:
Args:
model_path: Path to the model.
n_ctx: Maximum context size.
n_parts: Number of parts to split the model into. If -1, the number of parts is automatically determined.
n_parts: Number of parts to split the model into. If -1, the number of parts
is automatically determined.
seed: Random seed. 0 for random.
f16_kv: Use half-precision for key/value cache.
logits_all: Return logits for all tokens, not just the last token.
@ -130,10 +157,12 @@ class Llama:
use_mmap: Use mmap if possible.
use_mlock: Force the system to keep the model in RAM.
embedding: Embedding mode only.
n_threads: Number of threads to use. If None, the number of threads is automatically determined.
n_threads: Number of threads to use. If None, the number of threads is
automatically determined.
n_batch: Maximum number of prompt tokens to batch together when calling llama_eval.
last_n_tokens_size: Maximum number of tokens to keep in the last_n_tokens deque.
lora_base: Optional path to base model, useful if using a quantized base model and you want to apply LoRA to an f16 model.
lora_base: Optional path to base model, useful if using a quantized base model and
you want to apply LoRA to an f16 model.
lora_path: Path to a LoRA file to apply to the model.
verbose: Print verbose output to stderr.
@ -169,18 +198,17 @@ class Llama:
self.lora_base = lora_base
self.lora_path = lora_path
### DEPRECATED ###
# DEPRECATED
self.n_parts = n_parts
### DEPRECATED ###
# DEPRECATED
if not os.path.exists(model_path):
raise ValueError(f"Model path does not exist: {model_path}")
invalidInputError(os.path.exists(model_path), f"Model path does not exist: {model_path}.")
self.ctx = llama_cpp.llama_init_from_file(
self.model_path.encode("utf-8"), self.params
)
assert self.ctx is not None
invalidInputError(self.ctx is not None, "The attribute `ctx` of `Llama` object is None.")
if self.lora_path:
if llama_cpp.llama_apply_lora_from_file(
@ -191,9 +219,9 @@ class Llama:
else llama_cpp.c_char_p(0),
llama_cpp.c_int(self.n_threads),
):
raise RuntimeError(
f"Failed to apply LoRA from lora path: {self.lora_path} to base path: {self.lora_base}"
)
invalidInputError(False,
"Failed to apply LoRA from lora path: "
f"{self.lora_path} to base path: {self.lora_base}")
if self.verbose:
print(llama_cpp.llama_print_system_info().decode("utf-8"), file=sys.stderr)
@ -233,7 +261,7 @@ class Llama:
Returns:
A list of tokens.
"""
assert self.ctx is not None
invalidInputError(self.ctx is not None, "The attribute `ctx` of `Llama` object is None.")
n_ctx = llama_cpp.llama_n_ctx(self.ctx)
tokens = (llama_cpp.llama_token * int(n_ctx))()
n_tokens = llama_cpp.llama_tokenize(
@ -253,10 +281,8 @@ class Llama:
llama_cpp.c_int(n_tokens),
llama_cpp.c_bool(add_bos),
)
if n_tokens < 0:
raise RuntimeError(
f'Failed to tokenize: text="{text}" n_tokens={n_tokens}'
)
invalidInputError(n_tokens >= 0,
f'Failed to tokenize: text="{text}" n_tokens={n_tokens}')
return list(tokens[:n_tokens])
def detokenize(self, tokens: List[int]) -> bytes:
@ -268,7 +294,7 @@ class Llama:
Returns:
The detokenized string.
"""
assert self.ctx is not None
invalidInputError(self.ctx is not None, "The attribute `ctx` of `Llama` object is None.")
output = b""
for token in tokens:
output += llama_cpp.llama_token_to_str(
@ -295,7 +321,7 @@ class Llama:
Args:
tokens: The list of tokens to evaluate.
"""
assert self.ctx is not None
invalidInputError(self.ctx is not None, "The attribute `ctx` of `Llama` object is None.")
n_ctx = int(llama_cpp.llama_n_ctx(self.ctx))
for i in range(0, len(tokens), self.n_batch):
batch = tokens[i: min(len(tokens), i + self.n_batch)]
@ -308,8 +334,7 @@ class Llama:
n_past=llama_cpp.c_int(n_past),
n_threads=llama_cpp.c_int(self.n_threads),
)
if int(return_code) != 0:
raise RuntimeError(f"llama_eval returned {return_code}")
invalidInputError(int(return_code) == 0, f"llama_eval returned {return_code}.")
# Save tokens
self.eval_tokens.extend(batch)
# Save logits
@ -338,8 +363,9 @@ class Llama:
mirostat_eta: llama_cpp.c_float,
penalize_nl: bool = True,
):
assert self.ctx is not None
assert len(self.eval_logits) > 0
invalidInputError(self.ctx is not None, "The attribute `ctx` of `Llama` object is None.")
invalidInputError(len(self.eval_logits) > 0,
"The attribute `eval_logits` of `Llama` object is None.")
n_vocab = self.n_vocab()
n_ctx = self.n_ctx()
top_k = llama_cpp.c_int(n_vocab) if top_k.value <= 0 else top_k
@ -467,7 +493,7 @@ class Llama:
Returns:
The sampled token.
"""
assert self.ctx is not None
invalidInputError(self.ctx is not None, "The attribute `ctx` of `Llama` object is None.")
last_n_tokens_data = [llama_cpp.llama_token(0)] * max(
0, self.last_n_tokens_size - len(self.eval_tokens)
) + list(self.eval_tokens)[-self.last_n_tokens_size:]
@ -509,7 +535,8 @@ class Llama:
Examples:
>>> llama = Llama("models/ggml-7b.bin")
>>> tokens = llama.tokenize(b"Hello, world!")
>>> for token in llama.generate(tokens, top_k=40, top_p=0.95, temp=1.0, repeat_penalty=1.1):
>>> for token in llama.generate(tokens, top_k=40, top_p=0.95,
>>> temp=1.0, repeat_penalty=1.1):
... print(llama.detokenize([token]))
Args:
@ -523,7 +550,7 @@ class Llama:
Yields:
The generated tokens.
"""
assert self.ctx is not None
invalidInputError(self.ctx is not None, "The attribute `ctx` of `Llama` object is None.")
if reset and len(self.eval_tokens) > 0:
longest_prefix = 0
@ -577,13 +604,11 @@ class Llama:
Returns:
An embedding object.
"""
assert self.ctx is not None
invalidInputError(self.ctx is not None, "The attribute `ctx` of `Llama` object is None.")
model_name: str = model if model is not None else self.model_path
if self.params.embedding == False:
raise RuntimeError(
"Llama model must be created with embedding=True to call this method"
)
invalidInputError(self.params.embedding,
"Llama model must be created with embedding=True to call this method.")
if self.verbose:
llama_cpp.llama_reset_timings(self.ctx)
@ -657,7 +682,7 @@ class Llama:
mirostat_eta: float = 0.1,
model: Optional[str] = None,
) -> Union[Iterator[Completion], Iterator[CompletionChunk]]:
assert self.ctx is not None
invalidInputError(self.ctx is not None, "The attribute `ctx` of `Llama` object is None.")
completion_id: str = f"cmpl-{str(uuid.uuid4())}"
created: int = int(time.time())
completion_tokens: List[int] = []
@ -673,10 +698,9 @@ class Llama:
if self.verbose:
llama_cpp.llama_reset_timings(self.ctx)
if len(prompt_tokens) + max_tokens > int(llama_cpp.llama_n_ctx(self.ctx)):
raise ValueError(
f"Requested tokens exceed context window of {llama_cpp.llama_n_ctx(self.ctx)}"
)
invalidInputError(len(prompt_tokens) + max_tokens <= int(llama_cpp.llama_n_ctx(self.ctx)),
"Requested tokens exceed context window of "
f"{llama_cpp.llama_n_ctx(self.ctx)}.")
if stop != []:
stop_sequences = [s.encode("utf-8") for s in stop]
@ -684,9 +708,8 @@ class Llama:
stop_sequences = []
if logprobs is not None and self.params.logits_all is False:
raise ValueError(
"logprobs is not supported for models created with logits_all=False"
)
invalidInputError(False,
"logprobs is not supported for models created with logits_all=False")
if self.cache:
try:
@ -1294,9 +1317,9 @@ class Llama:
n_threads=self.n_threads,
lora_base=self.lora_base,
lora_path=self.lora_path,
### DEPRECATED ###
# DEPRECATED
n_parts=self.n_parts,
### DEPRECATED ###
# DEPRECATED
)
def __setstate__(self, state):
@ -1321,12 +1344,11 @@ class Llama:
)
def save_state(self) -> LlamaState:
assert self.ctx is not None
invalidInputError(self.ctx is not None, "The attribute `ctx` of `Llama` object is None.")
state_size = llama_cpp.llama_get_state_size(self.ctx)
llama_state = (llama_cpp.c_uint8 * int(state_size))()
n_bytes = llama_cpp.llama_copy_state_data(self.ctx, llama_state)
if int(n_bytes) > int(state_size):
raise RuntimeError("Failed to copy llama state data")
invalidInputError(int(n_bytes) <= int(state_size), "Failed to copy llama state data.")
llama_state_compact = (llama_cpp.c_uint8 * int(n_bytes))()
llama_cpp.ctypes.memmove(llama_state_compact, llama_state, int(n_bytes))
if self.verbose:
@ -1342,26 +1364,27 @@ class Llama:
)
def load_state(self, state: LlamaState) -> None:
assert self.ctx is not None
invalidInputError(self.ctx is not None, "The attribute `ctx` of `Llama` object is None.")
self.eval_tokens = state.eval_tokens.copy()
self.eval_logits = state.eval_logits.copy()
state_size = state.llama_state_size
if llama_cpp.llama_set_state_data(self.ctx, state.llama_state) != state_size:
raise RuntimeError("Failed to set llama state data")
invalidInputError(llama_cpp.llama_set_state_data(self.ctx,
state.llama_state) == state_size,
"Failed to set llama state data.")
def n_ctx(self) -> int:
"""Return the context window size."""
assert self.ctx is not None
invalidInputError(self.ctx is not None, "The attribute `ctx` of `Llama` object is None.")
return llama_cpp.llama_n_ctx(self.ctx)
def n_embd(self) -> int:
"""Return the embedding size."""
assert self.ctx is not None
invalidInputError(self.ctx is not None, "The attribute `ctx` of `Llama` object is None.")
return llama_cpp.llama_n_embd(self.ctx)
def n_vocab(self) -> int:
"""Return the vocabulary size."""
assert self.ctx is not None
invalidInputError(self.ctx is not None, "The attribute `ctx` of `Llama` object is None.")
return llama_cpp.llama_n_vocab(self.ctx)
@staticmethod

View file

@ -13,6 +13,32 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ===========================================================================
#
# This file is adapted from
# https://github.com/abetlen/llama-cpp-python/blob/main/llama_cpp/llama_cpp.py
#
# MIT License
#
# Copyright (c) 2023 Andrei Betlen
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# This would makes sure Python is aware there is more than one sub-package within bigdl,
# physically located elsewhere.
@ -36,6 +62,7 @@ from ctypes import (
c_size_t,
)
import pathlib
from bigdl.llm.utils.common import invalidInputError
# Load the library
@ -48,7 +75,7 @@ def _load_shared_library(lib_base_name: str):
elif sys.platform == "win32":
lib_ext = ".dll"
else:
raise RuntimeError("Unsupported platform")
invalidInputError(False, "Unsupported platform.")
# Construct the paths to the possible shared library names
_base_path = pathlib.Path(__file__).parent.parent.parent.parent.resolve()
@ -81,11 +108,9 @@ def _load_shared_library(lib_base_name: str):
try:
return ctypes.CDLL(str(_lib_path), **cdll_args)
except Exception as e:
raise RuntimeError(f"Failed to load shared library '{_lib_path}': {e}")
invalidInputError(False, f"Failed to load shared library '{_lib_path}': {e}.")
raise FileNotFoundError(
f"Shared library with base name '{lib_base_name}' not found"
)
invalidInputError(False, f"Shared library with base name '{lib_base_name}' not found.")
# Specify the base name of the shared library to load
@ -300,7 +325,8 @@ _lib.llama_free.restype = None
# TODO: not great API - very likely to change
# Returns 0 on success
# nthread - how many threads to use. If <=0, will use std::thread::hardware_concurrency(), else the number given
# nthread - how many threads to use. If <=0, will use std::thread::hardware_concurrency(),
# else the number given
# LLAMA_API int llama_model_quantize(
# const char * fname_inp,
# const char * fname_out,
@ -399,7 +425,8 @@ _lib.llama_set_state_data.restype = c_size_t
# Save/load session file
# LLAMA_API bool llama_load_session_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out);
# LLAMA_API bool llama_load_session_file(struct llama_context * ctx, const char * path_session,
# llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out);
def llama_load_session_file(
ctx: llama_context_p,
path_session: bytes,
@ -422,7 +449,8 @@ _lib.llama_load_session_file.argtypes = [
_lib.llama_load_session_file.restype = c_size_t
# LLAMA_API bool llama_save_session_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count);
# LLAMA_API bool llama_save_session_file(struct llama_context * ctx, const char * path_session,
# const llama_token * tokens, size_t n_token_count);
def llama_save_session_file(
ctx: llama_context_p,
path_session: bytes,
@ -601,8 +629,10 @@ _lib.llama_init_candidates.argtypes = [
_lib.llama_init_candidates.restype = None
# @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
# LLAMA_API void llama_sample_repetition_penalty(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float penalty);
# @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858,
# with negative logit fix.
# LLAMA_API void llama_sample_repetition_penalty(struct llama_context * ctx, llama_token_data_array
# * candidates, const llama_token * last_tokens, size_t last_tokens_size, float penalty);
def llama_sample_repetition_penalty(
ctx: llama_context_p,
candidates, # type: _Pointer[llama_token_data_array]
@ -625,8 +655,11 @@ _lib.llama_sample_repetition_penalty.argtypes = [
_lib.llama_sample_repetition_penalty.restype = None
# @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
# LLAMA_API void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float alpha_frequency, float alpha_presence);
# @details Frequency and presence penalties described in OpenAI API
# https://platform.openai.com/docs/api-reference/parameter-details.
# LLAMA_API void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx,
# llama_token_data_array * candidates, const llama_token * last_tokens,
# size_t last_tokens_size, float alpha_frequency, float alpha_presence);
def llama_sample_frequency_and_presence_penalties(
ctx: llama_context_p,
candidates, # type: _Pointer[llama_token_data_array]
@ -656,8 +689,10 @@ _lib.llama_sample_frequency_and_presence_penalties.argtypes = [
_lib.llama_sample_frequency_and_presence_penalties.restype = None
# @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
# LLAMA_API void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates);
# @details Sorts candidate tokens by their logits in descending order and calculate probabilities
# based on logits.
# LLAMA_API void llama_sample_softmax(struct llama_context * ctx,
# llama_token_data_array * candidates);
def llama_sample_softmax(
ctx: llama_context_p, candidates # type: _Pointer[llama_token_data]
):
@ -671,8 +706,10 @@ _lib.llama_sample_softmax.argtypes = [
_lib.llama_sample_softmax.restype = None
# @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
# LLAMA_API void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * candidates, int k, size_t min_keep);
# @details Top-K sampling described in academic paper
# "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
# LLAMA_API void llama_sample_top_k(struct llama_context * ctx,
# llama_token_data_array * candidates, int k, size_t min_keep);
def llama_sample_top_k(
ctx: llama_context_p,
candidates, # type: _Pointer[llama_token_data_array]
@ -691,8 +728,10 @@ _lib.llama_sample_top_k.argtypes = [
_lib.llama_sample_top_k.restype = None
# @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
# LLAMA_API void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep);
# @details Nucleus sampling described in academic paper
# "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
# LLAMA_API void llama_sample_top_p(struct llama_context * ctx,
# llama_token_data_array * candidates, float p, size_t min_keep);
def llama_sample_top_p(
ctx: llama_context_p,
candidates, # type: _Pointer[llama_token_data_array]
@ -711,8 +750,10 @@ _lib.llama_sample_top_p.argtypes = [
_lib.llama_sample_top_p.restype = None
# @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
# LLAMA_API void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep);
# @details Tail Free Sampling described in
# https://www.trentonbricken.com/Tail-Free-Sampling/.
# LLAMA_API void llama_sample_tail_free(struct llama_context * ctx,
# llama_token_data_array * candidates, float z, size_t min_keep);
def llama_sample_tail_free(
ctx: llama_context_p,
candidates, # type: _Pointer[llama_token_data_array]
@ -731,8 +772,10 @@ _lib.llama_sample_tail_free.argtypes = [
_lib.llama_sample_tail_free.restype = None
# @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
# LLAMA_API void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep);
# @details Locally Typical Sampling implementation described in the paper
# https://arxiv.org/abs/2202.00666.
# LLAMA_API void llama_sample_typical(struct llama_context * ctx,
# llama_token_data_array * candidates, float p, size_t min_keep);
def llama_sample_typical(
ctx: llama_context_p,
candidates, # type: _Pointer[llama_token_data_array]
@ -751,7 +794,8 @@ _lib.llama_sample_typical.argtypes = [
_lib.llama_sample_typical.restype = None
# LLAMA_API void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array * candidates, float temp);
# LLAMA_API void llama_sample_temperature(struct llama_context * ctx,
# llama_token_data_array * candidates, float temp);
def llama_sample_temperature(
ctx: llama_context_p,
candidates, # type: _Pointer[llama_token_data_array]
@ -768,13 +812,25 @@ _lib.llama_sample_temperature.argtypes = [
_lib.llama_sample_temperature.restype = None
# @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
# @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
# @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
# @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
# @param m The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm.
# @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
# LLAMA_API llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int m, float * mu);
# @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966.
# Uses tokens instead of words.
# @param candidates A vector of `llama_token_data` containing the candidate tokens,
# their probabilities (p), and log-odds (logit) for the current position in the generated text.
# @param tau The target cross-entropy (or surprise) value you want to achieve for the generated
# text. A higher value corresponds to more surprising or less predictable text, while a lower
# value corresponds to less surprising or more predictable text.
# @param eta The learning rate used to update `mu` based on the error between the target and
# observed surprisal of the sampled word. A larger learning rate will cause `mu` to be
# updated more quickly, while a smaller learning rate will result in slower updates.
# @param m The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value
# that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`.
# In the paper, they use `m = 100`, but you can experiment with different values to see
# how it affects the performance of the algorithm.
# @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy
# (`2 * tau`) and is updated in the algorithm based on the error between the target and
# observed surprisal.
# LLAMA_API llama_token llama_sample_token_mirostat(struct llama_context * ctx,
# llama_token_data_array * candidates, float tau, float eta, int m, float * mu);
def llama_sample_token_mirostat(
ctx: llama_context_p,
candidates, # type: _Pointer[llama_token_data_array]
@ -797,12 +853,21 @@ _lib.llama_sample_token_mirostat.argtypes = [
_lib.llama_sample_token_mirostat.restype = llama_token
# @details Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
# @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
# @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
# @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
# @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
# LLAMA_API llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, float * mu);
# @details Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966.
# Uses tokens instead of words.
# @param candidates A vector of `llama_token_data` containing the candidate tokens,
# their probabilities (p), and log-odds (logit) for the current position in the generated text.
# @param tau The target cross-entropy (or surprise) value you want to achieve for the generated
# text. A higher value corresponds to more surprising or less predictable text, while a lower value
# corresponds to less surprising or more predictable text.
# @param eta The learning rate used to update `mu` based on the error between the target and
# observed surprisal of the sampled word. A larger learning rate will cause `mu` to be
# updated more quickly, while a smaller learning rate will result in slower updates.
# @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy
# (`2 * tau`) and is updated in the algorithm based on the error between the target
# and observed surprisal.
# LLAMA_API llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx,
# llama_token_data_array * candidates, float tau, float eta, float * mu);
def llama_sample_token_mirostat_v2(
ctx: llama_context_p,
candidates, # type: _Pointer[llama_token_data_array]
@ -824,7 +889,8 @@ _lib.llama_sample_token_mirostat_v2.restype = llama_token
# @details Selects the token with the highest probability.
# LLAMA_API llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_data_array * candidates);
# LLAMA_API llama_token llama_sample_token_greedy(struct llama_context * ctx,
# llama_token_data_array * candidates);
def llama_sample_token_greedy(
ctx: llama_context_p,
candidates, # type: _Pointer[llama_token_data_array]
@ -840,7 +906,8 @@ _lib.llama_sample_token_greedy.restype = llama_token
# @details Randomly selects a token from the candidates based on their probabilities.
# LLAMA_API llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates);
# LLAMA_API llama_token llama_sample_token(struct llama_context * ctx,
# llama_token_data_array * candidates);
def llama_sample_token(
ctx: llama_context_p,
candidates, # type: _Pointer[llama_token_data_array]

View file

@ -13,6 +13,32 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ===========================================================================
#
# This file is adapted from
# https://github.com/abetlen/llama-cpp-python/blob/main/llama_cpp/llama_types.py
#
# MIT License
#
# Copyright (c) 2023 Andrei Betlen
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# This would makes sure Python is aware there is more than one sub-package within bigdl,
# physically located elsewhere.