LLM: fix and update related license in llama pybinding (#8250)
This commit is contained in:
parent
141febec1f
commit
3a9aa23835
3 changed files with 212 additions and 96 deletions
|
|
@ -14,6 +14,33 @@
|
|||
# limitations under the License.
|
||||
#
|
||||
|
||||
# ===========================================================================
|
||||
#
|
||||
# This file is adapted from
|
||||
# https://github.com/abetlen/llama-cpp-python/blob/main/llama_cpp/llama.py
|
||||
#
|
||||
# MIT License
|
||||
#
|
||||
# Copyright (c) 2023 Andrei Betlen
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
|
||||
# This would makes sure Python is aware there is more than one sub-package within bigdl,
|
||||
# physically located elsewhere.
|
||||
# Otherwise there would be module not found error in non-pip's setting as Python would
|
||||
|
|
@ -27,7 +54,7 @@ import math
|
|||
import multiprocessing
|
||||
from typing import List, Optional, Union, Generator, Sequence, Iterator, Deque, Tuple
|
||||
from collections import deque, OrderedDict
|
||||
|
||||
from bigdl.llm.utils.common import invalidInputError
|
||||
from . import llama_cpp
|
||||
from .llama_types import *
|
||||
|
||||
|
|
@ -61,8 +88,7 @@ class LlamaCache:
|
|||
def __getitem__(self, key: Sequence[int]) -> "LlamaState":
|
||||
key = tuple(key)
|
||||
_key = self._find_longest_prefix_key(key)
|
||||
if _key is None:
|
||||
raise KeyError(f"Key not found")
|
||||
invalidInputError(_key is not None, "Key not found.")
|
||||
value = self.cache_state[_key]
|
||||
self.cache_state.move_to_end(_key)
|
||||
return value
|
||||
|
|
@ -122,7 +148,8 @@ class Llama:
|
|||
Args:
|
||||
model_path: Path to the model.
|
||||
n_ctx: Maximum context size.
|
||||
n_parts: Number of parts to split the model into. If -1, the number of parts is automatically determined.
|
||||
n_parts: Number of parts to split the model into. If -1, the number of parts
|
||||
is automatically determined.
|
||||
seed: Random seed. 0 for random.
|
||||
f16_kv: Use half-precision for key/value cache.
|
||||
logits_all: Return logits for all tokens, not just the last token.
|
||||
|
|
@ -130,10 +157,12 @@ class Llama:
|
|||
use_mmap: Use mmap if possible.
|
||||
use_mlock: Force the system to keep the model in RAM.
|
||||
embedding: Embedding mode only.
|
||||
n_threads: Number of threads to use. If None, the number of threads is automatically determined.
|
||||
n_threads: Number of threads to use. If None, the number of threads is
|
||||
automatically determined.
|
||||
n_batch: Maximum number of prompt tokens to batch together when calling llama_eval.
|
||||
last_n_tokens_size: Maximum number of tokens to keep in the last_n_tokens deque.
|
||||
lora_base: Optional path to base model, useful if using a quantized base model and you want to apply LoRA to an f16 model.
|
||||
lora_base: Optional path to base model, useful if using a quantized base model and
|
||||
you want to apply LoRA to an f16 model.
|
||||
lora_path: Path to a LoRA file to apply to the model.
|
||||
verbose: Print verbose output to stderr.
|
||||
|
||||
|
|
@ -169,18 +198,17 @@ class Llama:
|
|||
self.lora_base = lora_base
|
||||
self.lora_path = lora_path
|
||||
|
||||
### DEPRECATED ###
|
||||
# DEPRECATED
|
||||
self.n_parts = n_parts
|
||||
### DEPRECATED ###
|
||||
# DEPRECATED
|
||||
|
||||
if not os.path.exists(model_path):
|
||||
raise ValueError(f"Model path does not exist: {model_path}")
|
||||
invalidInputError(os.path.exists(model_path), f"Model path does not exist: {model_path}.")
|
||||
|
||||
self.ctx = llama_cpp.llama_init_from_file(
|
||||
self.model_path.encode("utf-8"), self.params
|
||||
)
|
||||
|
||||
assert self.ctx is not None
|
||||
invalidInputError(self.ctx is not None, "The attribute `ctx` of `Llama` object is None.")
|
||||
|
||||
if self.lora_path:
|
||||
if llama_cpp.llama_apply_lora_from_file(
|
||||
|
|
@ -191,9 +219,9 @@ class Llama:
|
|||
else llama_cpp.c_char_p(0),
|
||||
llama_cpp.c_int(self.n_threads),
|
||||
):
|
||||
raise RuntimeError(
|
||||
f"Failed to apply LoRA from lora path: {self.lora_path} to base path: {self.lora_base}"
|
||||
)
|
||||
invalidInputError(False,
|
||||
"Failed to apply LoRA from lora path: "
|
||||
f"{self.lora_path} to base path: {self.lora_base}")
|
||||
|
||||
if self.verbose:
|
||||
print(llama_cpp.llama_print_system_info().decode("utf-8"), file=sys.stderr)
|
||||
|
|
@ -233,7 +261,7 @@ class Llama:
|
|||
Returns:
|
||||
A list of tokens.
|
||||
"""
|
||||
assert self.ctx is not None
|
||||
invalidInputError(self.ctx is not None, "The attribute `ctx` of `Llama` object is None.")
|
||||
n_ctx = llama_cpp.llama_n_ctx(self.ctx)
|
||||
tokens = (llama_cpp.llama_token * int(n_ctx))()
|
||||
n_tokens = llama_cpp.llama_tokenize(
|
||||
|
|
@ -253,10 +281,8 @@ class Llama:
|
|||
llama_cpp.c_int(n_tokens),
|
||||
llama_cpp.c_bool(add_bos),
|
||||
)
|
||||
if n_tokens < 0:
|
||||
raise RuntimeError(
|
||||
f'Failed to tokenize: text="{text}" n_tokens={n_tokens}'
|
||||
)
|
||||
invalidInputError(n_tokens >= 0,
|
||||
f'Failed to tokenize: text="{text}" n_tokens={n_tokens}')
|
||||
return list(tokens[:n_tokens])
|
||||
|
||||
def detokenize(self, tokens: List[int]) -> bytes:
|
||||
|
|
@ -268,7 +294,7 @@ class Llama:
|
|||
Returns:
|
||||
The detokenized string.
|
||||
"""
|
||||
assert self.ctx is not None
|
||||
invalidInputError(self.ctx is not None, "The attribute `ctx` of `Llama` object is None.")
|
||||
output = b""
|
||||
for token in tokens:
|
||||
output += llama_cpp.llama_token_to_str(
|
||||
|
|
@ -295,7 +321,7 @@ class Llama:
|
|||
Args:
|
||||
tokens: The list of tokens to evaluate.
|
||||
"""
|
||||
assert self.ctx is not None
|
||||
invalidInputError(self.ctx is not None, "The attribute `ctx` of `Llama` object is None.")
|
||||
n_ctx = int(llama_cpp.llama_n_ctx(self.ctx))
|
||||
for i in range(0, len(tokens), self.n_batch):
|
||||
batch = tokens[i: min(len(tokens), i + self.n_batch)]
|
||||
|
|
@ -308,8 +334,7 @@ class Llama:
|
|||
n_past=llama_cpp.c_int(n_past),
|
||||
n_threads=llama_cpp.c_int(self.n_threads),
|
||||
)
|
||||
if int(return_code) != 0:
|
||||
raise RuntimeError(f"llama_eval returned {return_code}")
|
||||
invalidInputError(int(return_code) == 0, f"llama_eval returned {return_code}.")
|
||||
# Save tokens
|
||||
self.eval_tokens.extend(batch)
|
||||
# Save logits
|
||||
|
|
@ -338,8 +363,9 @@ class Llama:
|
|||
mirostat_eta: llama_cpp.c_float,
|
||||
penalize_nl: bool = True,
|
||||
):
|
||||
assert self.ctx is not None
|
||||
assert len(self.eval_logits) > 0
|
||||
invalidInputError(self.ctx is not None, "The attribute `ctx` of `Llama` object is None.")
|
||||
invalidInputError(len(self.eval_logits) > 0,
|
||||
"The attribute `eval_logits` of `Llama` object is None.")
|
||||
n_vocab = self.n_vocab()
|
||||
n_ctx = self.n_ctx()
|
||||
top_k = llama_cpp.c_int(n_vocab) if top_k.value <= 0 else top_k
|
||||
|
|
@ -467,7 +493,7 @@ class Llama:
|
|||
Returns:
|
||||
The sampled token.
|
||||
"""
|
||||
assert self.ctx is not None
|
||||
invalidInputError(self.ctx is not None, "The attribute `ctx` of `Llama` object is None.")
|
||||
last_n_tokens_data = [llama_cpp.llama_token(0)] * max(
|
||||
0, self.last_n_tokens_size - len(self.eval_tokens)
|
||||
) + list(self.eval_tokens)[-self.last_n_tokens_size:]
|
||||
|
|
@ -509,7 +535,8 @@ class Llama:
|
|||
Examples:
|
||||
>>> llama = Llama("models/ggml-7b.bin")
|
||||
>>> tokens = llama.tokenize(b"Hello, world!")
|
||||
>>> for token in llama.generate(tokens, top_k=40, top_p=0.95, temp=1.0, repeat_penalty=1.1):
|
||||
>>> for token in llama.generate(tokens, top_k=40, top_p=0.95,
|
||||
>>> temp=1.0, repeat_penalty=1.1):
|
||||
... print(llama.detokenize([token]))
|
||||
|
||||
Args:
|
||||
|
|
@ -523,7 +550,7 @@ class Llama:
|
|||
Yields:
|
||||
The generated tokens.
|
||||
"""
|
||||
assert self.ctx is not None
|
||||
invalidInputError(self.ctx is not None, "The attribute `ctx` of `Llama` object is None.")
|
||||
|
||||
if reset and len(self.eval_tokens) > 0:
|
||||
longest_prefix = 0
|
||||
|
|
@ -577,13 +604,11 @@ class Llama:
|
|||
Returns:
|
||||
An embedding object.
|
||||
"""
|
||||
assert self.ctx is not None
|
||||
invalidInputError(self.ctx is not None, "The attribute `ctx` of `Llama` object is None.")
|
||||
model_name: str = model if model is not None else self.model_path
|
||||
|
||||
if self.params.embedding == False:
|
||||
raise RuntimeError(
|
||||
"Llama model must be created with embedding=True to call this method"
|
||||
)
|
||||
invalidInputError(self.params.embedding,
|
||||
"Llama model must be created with embedding=True to call this method.")
|
||||
|
||||
if self.verbose:
|
||||
llama_cpp.llama_reset_timings(self.ctx)
|
||||
|
|
@ -657,7 +682,7 @@ class Llama:
|
|||
mirostat_eta: float = 0.1,
|
||||
model: Optional[str] = None,
|
||||
) -> Union[Iterator[Completion], Iterator[CompletionChunk]]:
|
||||
assert self.ctx is not None
|
||||
invalidInputError(self.ctx is not None, "The attribute `ctx` of `Llama` object is None.")
|
||||
completion_id: str = f"cmpl-{str(uuid.uuid4())}"
|
||||
created: int = int(time.time())
|
||||
completion_tokens: List[int] = []
|
||||
|
|
@ -673,10 +698,9 @@ class Llama:
|
|||
if self.verbose:
|
||||
llama_cpp.llama_reset_timings(self.ctx)
|
||||
|
||||
if len(prompt_tokens) + max_tokens > int(llama_cpp.llama_n_ctx(self.ctx)):
|
||||
raise ValueError(
|
||||
f"Requested tokens exceed context window of {llama_cpp.llama_n_ctx(self.ctx)}"
|
||||
)
|
||||
invalidInputError(len(prompt_tokens) + max_tokens <= int(llama_cpp.llama_n_ctx(self.ctx)),
|
||||
"Requested tokens exceed context window of "
|
||||
f"{llama_cpp.llama_n_ctx(self.ctx)}.")
|
||||
|
||||
if stop != []:
|
||||
stop_sequences = [s.encode("utf-8") for s in stop]
|
||||
|
|
@ -684,9 +708,8 @@ class Llama:
|
|||
stop_sequences = []
|
||||
|
||||
if logprobs is not None and self.params.logits_all is False:
|
||||
raise ValueError(
|
||||
"logprobs is not supported for models created with logits_all=False"
|
||||
)
|
||||
invalidInputError(False,
|
||||
"logprobs is not supported for models created with logits_all=False")
|
||||
|
||||
if self.cache:
|
||||
try:
|
||||
|
|
@ -1294,9 +1317,9 @@ class Llama:
|
|||
n_threads=self.n_threads,
|
||||
lora_base=self.lora_base,
|
||||
lora_path=self.lora_path,
|
||||
### DEPRECATED ###
|
||||
# DEPRECATED
|
||||
n_parts=self.n_parts,
|
||||
### DEPRECATED ###
|
||||
# DEPRECATED
|
||||
)
|
||||
|
||||
def __setstate__(self, state):
|
||||
|
|
@ -1321,12 +1344,11 @@ class Llama:
|
|||
)
|
||||
|
||||
def save_state(self) -> LlamaState:
|
||||
assert self.ctx is not None
|
||||
invalidInputError(self.ctx is not None, "The attribute `ctx` of `Llama` object is None.")
|
||||
state_size = llama_cpp.llama_get_state_size(self.ctx)
|
||||
llama_state = (llama_cpp.c_uint8 * int(state_size))()
|
||||
n_bytes = llama_cpp.llama_copy_state_data(self.ctx, llama_state)
|
||||
if int(n_bytes) > int(state_size):
|
||||
raise RuntimeError("Failed to copy llama state data")
|
||||
invalidInputError(int(n_bytes) <= int(state_size), "Failed to copy llama state data.")
|
||||
llama_state_compact = (llama_cpp.c_uint8 * int(n_bytes))()
|
||||
llama_cpp.ctypes.memmove(llama_state_compact, llama_state, int(n_bytes))
|
||||
if self.verbose:
|
||||
|
|
@ -1342,26 +1364,27 @@ class Llama:
|
|||
)
|
||||
|
||||
def load_state(self, state: LlamaState) -> None:
|
||||
assert self.ctx is not None
|
||||
invalidInputError(self.ctx is not None, "The attribute `ctx` of `Llama` object is None.")
|
||||
self.eval_tokens = state.eval_tokens.copy()
|
||||
self.eval_logits = state.eval_logits.copy()
|
||||
state_size = state.llama_state_size
|
||||
if llama_cpp.llama_set_state_data(self.ctx, state.llama_state) != state_size:
|
||||
raise RuntimeError("Failed to set llama state data")
|
||||
invalidInputError(llama_cpp.llama_set_state_data(self.ctx,
|
||||
state.llama_state) == state_size,
|
||||
"Failed to set llama state data.")
|
||||
|
||||
def n_ctx(self) -> int:
|
||||
"""Return the context window size."""
|
||||
assert self.ctx is not None
|
||||
invalidInputError(self.ctx is not None, "The attribute `ctx` of `Llama` object is None.")
|
||||
return llama_cpp.llama_n_ctx(self.ctx)
|
||||
|
||||
def n_embd(self) -> int:
|
||||
"""Return the embedding size."""
|
||||
assert self.ctx is not None
|
||||
invalidInputError(self.ctx is not None, "The attribute `ctx` of `Llama` object is None.")
|
||||
return llama_cpp.llama_n_embd(self.ctx)
|
||||
|
||||
def n_vocab(self) -> int:
|
||||
"""Return the vocabulary size."""
|
||||
assert self.ctx is not None
|
||||
invalidInputError(self.ctx is not None, "The attribute `ctx` of `Llama` object is None.")
|
||||
return llama_cpp.llama_n_vocab(self.ctx)
|
||||
|
||||
@staticmethod
|
||||
|
|
|
|||
|
|
@ -13,6 +13,32 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# ===========================================================================
|
||||
#
|
||||
# This file is adapted from
|
||||
# https://github.com/abetlen/llama-cpp-python/blob/main/llama_cpp/llama_cpp.py
|
||||
#
|
||||
# MIT License
|
||||
#
|
||||
# Copyright (c) 2023 Andrei Betlen
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
|
||||
# This would makes sure Python is aware there is more than one sub-package within bigdl,
|
||||
# physically located elsewhere.
|
||||
|
|
@ -36,6 +62,7 @@ from ctypes import (
|
|||
c_size_t,
|
||||
)
|
||||
import pathlib
|
||||
from bigdl.llm.utils.common import invalidInputError
|
||||
|
||||
|
||||
# Load the library
|
||||
|
|
@ -48,7 +75,7 @@ def _load_shared_library(lib_base_name: str):
|
|||
elif sys.platform == "win32":
|
||||
lib_ext = ".dll"
|
||||
else:
|
||||
raise RuntimeError("Unsupported platform")
|
||||
invalidInputError(False, "Unsupported platform.")
|
||||
|
||||
# Construct the paths to the possible shared library names
|
||||
_base_path = pathlib.Path(__file__).parent.parent.parent.parent.resolve()
|
||||
|
|
@ -81,11 +108,9 @@ def _load_shared_library(lib_base_name: str):
|
|||
try:
|
||||
return ctypes.CDLL(str(_lib_path), **cdll_args)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to load shared library '{_lib_path}': {e}")
|
||||
invalidInputError(False, f"Failed to load shared library '{_lib_path}': {e}.")
|
||||
|
||||
raise FileNotFoundError(
|
||||
f"Shared library with base name '{lib_base_name}' not found"
|
||||
)
|
||||
invalidInputError(False, f"Shared library with base name '{lib_base_name}' not found.")
|
||||
|
||||
|
||||
# Specify the base name of the shared library to load
|
||||
|
|
@ -300,7 +325,8 @@ _lib.llama_free.restype = None
|
|||
|
||||
# TODO: not great API - very likely to change
|
||||
# Returns 0 on success
|
||||
# nthread - how many threads to use. If <=0, will use std::thread::hardware_concurrency(), else the number given
|
||||
# nthread - how many threads to use. If <=0, will use std::thread::hardware_concurrency(),
|
||||
# else the number given
|
||||
# LLAMA_API int llama_model_quantize(
|
||||
# const char * fname_inp,
|
||||
# const char * fname_out,
|
||||
|
|
@ -399,7 +425,8 @@ _lib.llama_set_state_data.restype = c_size_t
|
|||
|
||||
|
||||
# Save/load session file
|
||||
# LLAMA_API bool llama_load_session_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out);
|
||||
# LLAMA_API bool llama_load_session_file(struct llama_context * ctx, const char * path_session,
|
||||
# llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out);
|
||||
def llama_load_session_file(
|
||||
ctx: llama_context_p,
|
||||
path_session: bytes,
|
||||
|
|
@ -422,7 +449,8 @@ _lib.llama_load_session_file.argtypes = [
|
|||
_lib.llama_load_session_file.restype = c_size_t
|
||||
|
||||
|
||||
# LLAMA_API bool llama_save_session_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count);
|
||||
# LLAMA_API bool llama_save_session_file(struct llama_context * ctx, const char * path_session,
|
||||
# const llama_token * tokens, size_t n_token_count);
|
||||
def llama_save_session_file(
|
||||
ctx: llama_context_p,
|
||||
path_session: bytes,
|
||||
|
|
@ -601,8 +629,10 @@ _lib.llama_init_candidates.argtypes = [
|
|||
_lib.llama_init_candidates.restype = None
|
||||
|
||||
|
||||
# @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
|
||||
# LLAMA_API void llama_sample_repetition_penalty(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float penalty);
|
||||
# @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858,
|
||||
# with negative logit fix.
|
||||
# LLAMA_API void llama_sample_repetition_penalty(struct llama_context * ctx, llama_token_data_array
|
||||
# * candidates, const llama_token * last_tokens, size_t last_tokens_size, float penalty);
|
||||
def llama_sample_repetition_penalty(
|
||||
ctx: llama_context_p,
|
||||
candidates, # type: _Pointer[llama_token_data_array]
|
||||
|
|
@ -625,8 +655,11 @@ _lib.llama_sample_repetition_penalty.argtypes = [
|
|||
_lib.llama_sample_repetition_penalty.restype = None
|
||||
|
||||
|
||||
# @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
|
||||
# LLAMA_API void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float alpha_frequency, float alpha_presence);
|
||||
# @details Frequency and presence penalties described in OpenAI API
|
||||
# https://platform.openai.com/docs/api-reference/parameter-details.
|
||||
# LLAMA_API void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx,
|
||||
# llama_token_data_array * candidates, const llama_token * last_tokens,
|
||||
# size_t last_tokens_size, float alpha_frequency, float alpha_presence);
|
||||
def llama_sample_frequency_and_presence_penalties(
|
||||
ctx: llama_context_p,
|
||||
candidates, # type: _Pointer[llama_token_data_array]
|
||||
|
|
@ -656,8 +689,10 @@ _lib.llama_sample_frequency_and_presence_penalties.argtypes = [
|
|||
_lib.llama_sample_frequency_and_presence_penalties.restype = None
|
||||
|
||||
|
||||
# @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
|
||||
# LLAMA_API void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates);
|
||||
# @details Sorts candidate tokens by their logits in descending order and calculate probabilities
|
||||
# based on logits.
|
||||
# LLAMA_API void llama_sample_softmax(struct llama_context * ctx,
|
||||
# llama_token_data_array * candidates);
|
||||
def llama_sample_softmax(
|
||||
ctx: llama_context_p, candidates # type: _Pointer[llama_token_data]
|
||||
):
|
||||
|
|
@ -671,8 +706,10 @@ _lib.llama_sample_softmax.argtypes = [
|
|||
_lib.llama_sample_softmax.restype = None
|
||||
|
||||
|
||||
# @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
|
||||
# LLAMA_API void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * candidates, int k, size_t min_keep);
|
||||
# @details Top-K sampling described in academic paper
|
||||
# "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
|
||||
# LLAMA_API void llama_sample_top_k(struct llama_context * ctx,
|
||||
# llama_token_data_array * candidates, int k, size_t min_keep);
|
||||
def llama_sample_top_k(
|
||||
ctx: llama_context_p,
|
||||
candidates, # type: _Pointer[llama_token_data_array]
|
||||
|
|
@ -691,8 +728,10 @@ _lib.llama_sample_top_k.argtypes = [
|
|||
_lib.llama_sample_top_k.restype = None
|
||||
|
||||
|
||||
# @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
|
||||
# LLAMA_API void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep);
|
||||
# @details Nucleus sampling described in academic paper
|
||||
# "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
|
||||
# LLAMA_API void llama_sample_top_p(struct llama_context * ctx,
|
||||
# llama_token_data_array * candidates, float p, size_t min_keep);
|
||||
def llama_sample_top_p(
|
||||
ctx: llama_context_p,
|
||||
candidates, # type: _Pointer[llama_token_data_array]
|
||||
|
|
@ -711,8 +750,10 @@ _lib.llama_sample_top_p.argtypes = [
|
|||
_lib.llama_sample_top_p.restype = None
|
||||
|
||||
|
||||
# @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
|
||||
# LLAMA_API void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep);
|
||||
# @details Tail Free Sampling described in
|
||||
# https://www.trentonbricken.com/Tail-Free-Sampling/.
|
||||
# LLAMA_API void llama_sample_tail_free(struct llama_context * ctx,
|
||||
# llama_token_data_array * candidates, float z, size_t min_keep);
|
||||
def llama_sample_tail_free(
|
||||
ctx: llama_context_p,
|
||||
candidates, # type: _Pointer[llama_token_data_array]
|
||||
|
|
@ -731,8 +772,10 @@ _lib.llama_sample_tail_free.argtypes = [
|
|||
_lib.llama_sample_tail_free.restype = None
|
||||
|
||||
|
||||
# @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
|
||||
# LLAMA_API void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep);
|
||||
# @details Locally Typical Sampling implementation described in the paper
|
||||
# https://arxiv.org/abs/2202.00666.
|
||||
# LLAMA_API void llama_sample_typical(struct llama_context * ctx,
|
||||
# llama_token_data_array * candidates, float p, size_t min_keep);
|
||||
def llama_sample_typical(
|
||||
ctx: llama_context_p,
|
||||
candidates, # type: _Pointer[llama_token_data_array]
|
||||
|
|
@ -751,7 +794,8 @@ _lib.llama_sample_typical.argtypes = [
|
|||
_lib.llama_sample_typical.restype = None
|
||||
|
||||
|
||||
# LLAMA_API void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array * candidates, float temp);
|
||||
# LLAMA_API void llama_sample_temperature(struct llama_context * ctx,
|
||||
# llama_token_data_array * candidates, float temp);
|
||||
def llama_sample_temperature(
|
||||
ctx: llama_context_p,
|
||||
candidates, # type: _Pointer[llama_token_data_array]
|
||||
|
|
@ -768,13 +812,25 @@ _lib.llama_sample_temperature.argtypes = [
|
|||
_lib.llama_sample_temperature.restype = None
|
||||
|
||||
|
||||
# @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
|
||||
# @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
|
||||
# @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
|
||||
# @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
|
||||
# @param m The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm.
|
||||
# @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
|
||||
# LLAMA_API llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int m, float * mu);
|
||||
# @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966.
|
||||
# Uses tokens instead of words.
|
||||
# @param candidates A vector of `llama_token_data` containing the candidate tokens,
|
||||
# their probabilities (p), and log-odds (logit) for the current position in the generated text.
|
||||
# @param tau The target cross-entropy (or surprise) value you want to achieve for the generated
|
||||
# text. A higher value corresponds to more surprising or less predictable text, while a lower
|
||||
# value corresponds to less surprising or more predictable text.
|
||||
# @param eta The learning rate used to update `mu` based on the error between the target and
|
||||
# observed surprisal of the sampled word. A larger learning rate will cause `mu` to be
|
||||
# updated more quickly, while a smaller learning rate will result in slower updates.
|
||||
# @param m The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value
|
||||
# that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`.
|
||||
# In the paper, they use `m = 100`, but you can experiment with different values to see
|
||||
# how it affects the performance of the algorithm.
|
||||
# @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy
|
||||
# (`2 * tau`) and is updated in the algorithm based on the error between the target and
|
||||
# observed surprisal.
|
||||
# LLAMA_API llama_token llama_sample_token_mirostat(struct llama_context * ctx,
|
||||
# llama_token_data_array * candidates, float tau, float eta, int m, float * mu);
|
||||
def llama_sample_token_mirostat(
|
||||
ctx: llama_context_p,
|
||||
candidates, # type: _Pointer[llama_token_data_array]
|
||||
|
|
@ -797,12 +853,21 @@ _lib.llama_sample_token_mirostat.argtypes = [
|
|||
_lib.llama_sample_token_mirostat.restype = llama_token
|
||||
|
||||
|
||||
# @details Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
|
||||
# @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
|
||||
# @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
|
||||
# @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
|
||||
# @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
|
||||
# LLAMA_API llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, float * mu);
|
||||
# @details Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966.
|
||||
# Uses tokens instead of words.
|
||||
# @param candidates A vector of `llama_token_data` containing the candidate tokens,
|
||||
# their probabilities (p), and log-odds (logit) for the current position in the generated text.
|
||||
# @param tau The target cross-entropy (or surprise) value you want to achieve for the generated
|
||||
# text. A higher value corresponds to more surprising or less predictable text, while a lower value
|
||||
# corresponds to less surprising or more predictable text.
|
||||
# @param eta The learning rate used to update `mu` based on the error between the target and
|
||||
# observed surprisal of the sampled word. A larger learning rate will cause `mu` to be
|
||||
# updated more quickly, while a smaller learning rate will result in slower updates.
|
||||
# @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy
|
||||
# (`2 * tau`) and is updated in the algorithm based on the error between the target
|
||||
# and observed surprisal.
|
||||
# LLAMA_API llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx,
|
||||
# llama_token_data_array * candidates, float tau, float eta, float * mu);
|
||||
def llama_sample_token_mirostat_v2(
|
||||
ctx: llama_context_p,
|
||||
candidates, # type: _Pointer[llama_token_data_array]
|
||||
|
|
@ -824,7 +889,8 @@ _lib.llama_sample_token_mirostat_v2.restype = llama_token
|
|||
|
||||
|
||||
# @details Selects the token with the highest probability.
|
||||
# LLAMA_API llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_data_array * candidates);
|
||||
# LLAMA_API llama_token llama_sample_token_greedy(struct llama_context * ctx,
|
||||
# llama_token_data_array * candidates);
|
||||
def llama_sample_token_greedy(
|
||||
ctx: llama_context_p,
|
||||
candidates, # type: _Pointer[llama_token_data_array]
|
||||
|
|
@ -840,7 +906,8 @@ _lib.llama_sample_token_greedy.restype = llama_token
|
|||
|
||||
|
||||
# @details Randomly selects a token from the candidates based on their probabilities.
|
||||
# LLAMA_API llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates);
|
||||
# LLAMA_API llama_token llama_sample_token(struct llama_context * ctx,
|
||||
# llama_token_data_array * candidates);
|
||||
def llama_sample_token(
|
||||
ctx: llama_context_p,
|
||||
candidates, # type: _Pointer[llama_token_data_array]
|
||||
|
|
|
|||
|
|
@ -13,6 +13,32 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# ===========================================================================
|
||||
#
|
||||
# This file is adapted from
|
||||
# https://github.com/abetlen/llama-cpp-python/blob/main/llama_cpp/llama_types.py
|
||||
#
|
||||
# MIT License
|
||||
#
|
||||
# Copyright (c) 2023 Andrei Betlen
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
|
||||
# This would makes sure Python is aware there is more than one sub-package within bigdl,
|
||||
# physically located elsewhere.
|
||||
|
|
|
|||
Loading…
Reference in a new issue