LLM: fix and update related license in llama pybinding (#8250)

parent 141febec1f
commit 3a9aa23835

3 changed files with 212 additions and 96 deletions
Changed file 1 of 3 (adapted from llama-cpp-python's llama.py):

@@ -14,6 +14,33 @@
 # limitations under the License.
 #
 
+# ===========================================================================
+#
+# This file is adapted from
+# https://github.com/abetlen/llama-cpp-python/blob/main/llama_cpp/llama.py
+#
+# MIT License
+#
+# Copyright (c) 2023 Andrei Betlen
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
 # This would makes sure Python is aware there is more than one sub-package within bigdl,
 # physically located elsewhere.
 # Otherwise there would be module not found error in non-pip's setting as Python would
@@ -27,7 +54,7 @@ import math
 import multiprocessing
 from typing import List, Optional, Union, Generator, Sequence, Iterator, Deque, Tuple
 from collections import deque, OrderedDict
-
+from bigdl.llm.utils.common import invalidInputError
 from . import llama_cpp
 from .llama_types import *
 

@@ -61,8 +88,7 @@ class LlamaCache:
     def __getitem__(self, key: Sequence[int]) -> "LlamaState":
         key = tuple(key)
         _key = self._find_longest_prefix_key(key)
-        if _key is None:
-            raise KeyError(f"Key not found")
+        invalidInputError(_key is not None, "Key not found.")
         value = self.cache_state[_key]
         self.cache_state.move_to_end(_key)
         return value
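Note: across all three files this commit replaces bare raise and assert statements with bigdl's invalidInputError helper, imported above as `from bigdl.llm.utils.common import invalidInputError`. The stand-in below is only a sketch of the semantics the new code appears to rely on (the call raises when its first argument is falsy); the real helper's exception type and full signature are not shown in this diff.

    # Illustrative stand-in, not the actual bigdl implementation.
    def invalidInputError(condition, err_msg):
        """Raise with err_msg when condition does not hold (assumed behavior)."""
        if not condition:
            raise ValueError(err_msg)  # assumption: the real helper may raise a different type

    # The two-line guard in LlamaCache.__getitem__ becomes a single call:
    #   before: if _key is None: raise KeyError(f"Key not found")
    #   after:  invalidInputError(_key is not None, "Key not found.")
    try:
        invalidInputError(False, "Key not found.")
    except ValueError as err:
        print(err)  # -> Key not found.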
@@ -122,7 +148,8 @@ class Llama:
         Args:
             model_path: Path to the model.
             n_ctx: Maximum context size.
-            n_parts: Number of parts to split the model into. If -1, the number of parts is automatically determined.
+            n_parts: Number of parts to split the model into. If -1, the number of parts
+            is automatically determined.
             seed: Random seed. 0 for random.
             f16_kv: Use half-precision for key/value cache.
             logits_all: Return logits for all tokens, not just the last token.

@@ -130,10 +157,12 @@ class Llama:
             use_mmap: Use mmap if possible.
             use_mlock: Force the system to keep the model in RAM.
             embedding: Embedding mode only.
-            n_threads: Number of threads to use. If None, the number of threads is automatically determined.
+            n_threads: Number of threads to use. If None, the number of threads is
+            automatically determined.
             n_batch: Maximum number of prompt tokens to batch together when calling llama_eval.
             last_n_tokens_size: Maximum number of tokens to keep in the last_n_tokens deque.
-            lora_base: Optional path to base model, useful if using a quantized base model and you want to apply LoRA to an f16 model.
+            lora_base: Optional path to base model, useful if using a quantized base model and
+            you want to apply LoRA to an f16 model.
             lora_path: Path to a LoRA file to apply to the model.
             verbose: Print verbose output to stderr.
 

@@ -169,18 +198,17 @@ class Llama:
         self.lora_base = lora_base
         self.lora_path = lora_path
 
-        ### DEPRECATED ###
+        # DEPRECATED
         self.n_parts = n_parts
-        ### DEPRECATED ###
+        # DEPRECATED
 
-        if not os.path.exists(model_path):
-            raise ValueError(f"Model path does not exist: {model_path}")
+        invalidInputError(os.path.exists(model_path), f"Model path does not exist: {model_path}.")
 
         self.ctx = llama_cpp.llama_init_from_file(
             self.model_path.encode("utf-8"), self.params
         )
 
-        assert self.ctx is not None
+        invalidInputError(self.ctx is not None, "The attribute `ctx` of `Llama` object is None.")
 
         if self.lora_path:
             if llama_cpp.llama_apply_lora_from_file(

@@ -191,9 +219,9 @@ class Llama:
                 else llama_cpp.c_char_p(0),
                 llama_cpp.c_int(self.n_threads),
             ):
-                raise RuntimeError(
-                    f"Failed to apply LoRA from lora path: {self.lora_path} to base path: {self.lora_base}"
-                )
+                invalidInputError(False,
+                                  "Failed to apply LoRA from lora path: "
+                                  f"{self.lora_path} to base path: {self.lora_base}")
 
         if self.verbose:
             print(llama_cpp.llama_print_system_info().decode("utf-8"), file=sys.stderr)
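Where the original code raised unconditionally inside an error branch (as in the LoRA hunk above), the new code passes a constant False condition so the helper always fails. A minimal illustration, reusing the stand-in sketched earlier; the paths and the boolean are hypothetical values, not taken from the diff:

    lora_path, lora_base = "/tmp/lora.bin", "/tmp/base.bin"  # hypothetical paths
    apply_failed = True  # stands in for a non-zero return from llama_apply_lora_from_file
    if apply_failed:
        invalidInputError(False,
                          "Failed to apply LoRA from lora path: "
                          f"{lora_path} to base path: {lora_base}")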
@@ -233,7 +261,7 @@ class Llama:
         Returns:
             A list of tokens.
         """
-        assert self.ctx is not None
+        invalidInputError(self.ctx is not None, "The attribute `ctx` of `Llama` object is None.")
         n_ctx = llama_cpp.llama_n_ctx(self.ctx)
         tokens = (llama_cpp.llama_token * int(n_ctx))()
         n_tokens = llama_cpp.llama_tokenize(

@@ -253,10 +281,8 @@ class Llama:
                 llama_cpp.c_int(n_tokens),
                 llama_cpp.c_bool(add_bos),
             )
-            if n_tokens < 0:
-                raise RuntimeError(
-                    f'Failed to tokenize: text="{text}" n_tokens={n_tokens}'
-                )
+            invalidInputError(n_tokens >= 0,
+                              f'Failed to tokenize: text="{text}" n_tokens={n_tokens}')
         return list(tokens[:n_tokens])
 
     def detokenize(self, tokens: List[int]) -> bytes:

@@ -268,7 +294,7 @@ class Llama:
         Returns:
             The detokenized string.
         """
-        assert self.ctx is not None
+        invalidInputError(self.ctx is not None, "The attribute `ctx` of `Llama` object is None.")
         output = b""
         for token in tokens:
             output += llama_cpp.llama_token_to_str(

@@ -295,10 +321,10 @@ class Llama:
         Args:
             tokens: The list of tokens to evaluate.
         """
-        assert self.ctx is not None
+        invalidInputError(self.ctx is not None, "The attribute `ctx` of `Llama` object is None.")
         n_ctx = int(llama_cpp.llama_n_ctx(self.ctx))
         for i in range(0, len(tokens), self.n_batch):
-            batch = tokens[i : min(len(tokens), i + self.n_batch)]
+            batch = tokens[i: min(len(tokens), i + self.n_batch)]
             n_past = min(n_ctx - len(batch), len(self.eval_tokens))
             n_tokens = len(batch)
             return_code = llama_cpp.llama_eval(

@@ -308,8 +334,7 @@ class Llama:
                 n_past=llama_cpp.c_int(n_past),
                 n_threads=llama_cpp.c_int(self.n_threads),
             )
-            if int(return_code) != 0:
-                raise RuntimeError(f"llama_eval returned {return_code}")
+            invalidInputError(int(return_code) == 0, f"llama_eval returned {return_code}.")
             # Save tokens
             self.eval_tokens.extend(batch)
             # Save logits

@@ -338,8 +363,9 @@ class Llama:
         mirostat_eta: llama_cpp.c_float,
         penalize_nl: bool = True,
     ):
-        assert self.ctx is not None
-        assert len(self.eval_logits) > 0
+        invalidInputError(self.ctx is not None, "The attribute `ctx` of `Llama` object is None.")
+        invalidInputError(len(self.eval_logits) > 0,
+                          "The attribute `eval_logits` of `Llama` object is None.")
         n_vocab = self.n_vocab()
         n_ctx = self.n_ctx()
         top_k = llama_cpp.c_int(n_vocab) if top_k.value <= 0 else top_k

@@ -467,10 +493,10 @@ class Llama:
         Returns:
             The sampled token.
         """
-        assert self.ctx is not None
+        invalidInputError(self.ctx is not None, "The attribute `ctx` of `Llama` object is None.")
         last_n_tokens_data = [llama_cpp.llama_token(0)] * max(
             0, self.last_n_tokens_size - len(self.eval_tokens)
-        ) + list(self.eval_tokens)[-self.last_n_tokens_size :]
+        ) + list(self.eval_tokens)[-self.last_n_tokens_size:]
         return self._sample(
             last_n_tokens_data=(llama_cpp.llama_token * self.last_n_tokens_size)(
                 *last_n_tokens_data

@@ -509,7 +535,8 @@ class Llama:
         Examples:
             >>> llama = Llama("models/ggml-7b.bin")
             >>> tokens = llama.tokenize(b"Hello, world!")
-            >>> for token in llama.generate(tokens, top_k=40, top_p=0.95, temp=1.0, repeat_penalty=1.1):
+            >>> for token in llama.generate(tokens, top_k=40, top_p=0.95,
+            >>>                             temp=1.0, repeat_penalty=1.1):
             ...     print(llama.detokenize([token]))
 
         Args:

@@ -523,7 +550,7 @@ class Llama:
         Yields:
             The generated tokens.
         """
-        assert self.ctx is not None
+        invalidInputError(self.ctx is not None, "The attribute `ctx` of `Llama` object is None.")
 
         if reset and len(self.eval_tokens) > 0:
             longest_prefix = 0

@@ -577,13 +604,11 @@ class Llama:
         Returns:
             An embedding object.
         """
-        assert self.ctx is not None
+        invalidInputError(self.ctx is not None, "The attribute `ctx` of `Llama` object is None.")
         model_name: str = model if model is not None else self.model_path
 
-        if self.params.embedding == False:
-            raise RuntimeError(
-                "Llama model must be created with embedding=True to call this method"
-            )
+        invalidInputError(self.params.embedding,
+                          "Llama model must be created with embedding=True to call this method.")
 
         if self.verbose:
             llama_cpp.llama_reset_timings(self.ctx)

@@ -645,7 +670,7 @@ class Llama:
         top_p: float = 0.95,
         logprobs: Optional[int] = None,
         echo: bool = False,
-        stop: Optional[Union[str, List[str]]] = [],
+        stop: Optional[Union[str, List[str]]]=[],
         frequency_penalty: float = 0.0,
         presence_penalty: float = 0.0,
         repeat_penalty: float = 1.1,

@@ -657,7 +682,7 @@ class Llama:
         mirostat_eta: float = 0.1,
         model: Optional[str] = None,
     ) -> Union[Iterator[Completion], Iterator[CompletionChunk]]:
-        assert self.ctx is not None
+        invalidInputError(self.ctx is not None, "The attribute `ctx` of `Llama` object is None.")
         completion_id: str = f"cmpl-{str(uuid.uuid4())}"
         created: int = int(time.time())
         completion_tokens: List[int] = []

@@ -673,10 +698,9 @@ class Llama:
         if self.verbose:
             llama_cpp.llama_reset_timings(self.ctx)
 
-        if len(prompt_tokens) + max_tokens > int(llama_cpp.llama_n_ctx(self.ctx)):
-            raise ValueError(
-                f"Requested tokens exceed context window of {llama_cpp.llama_n_ctx(self.ctx)}"
-            )
+        invalidInputError(len(prompt_tokens) + max_tokens <= int(llama_cpp.llama_n_ctx(self.ctx)),
+                          "Requested tokens exceed context window of "
+                          f"{llama_cpp.llama_n_ctx(self.ctx)}.")
 
         if stop != []:
             stop_sequences = [s.encode("utf-8") for s in stop]

@@ -684,9 +708,8 @@ class Llama:
             stop_sequences = []
 
         if logprobs is not None and self.params.logits_all is False:
-            raise ValueError(
-                "logprobs is not supported for models created with logits_all=False"
-            )
+            invalidInputError(False,
+                              "logprobs is not supported for models created with logits_all=False")
 
         if self.cache:
             try:

@@ -1023,7 +1046,7 @@ class Llama:
         top_p: float = 0.95,
         logprobs: Optional[int] = None,
         echo: bool = False,
-        stop: Optional[Union[str, List[str]]] = [],
+        stop: Optional[Union[str, List[str]]]=[],
         frequency_penalty: float = 0.0,
         presence_penalty: float = 0.0,
         repeat_penalty: float = 1.1,

@@ -1092,7 +1115,7 @@ class Llama:
         top_p: float = 0.95,
         logprobs: Optional[int] = None,
         echo: bool = False,
-        stop: Optional[Union[str, List[str]]] = [],
+        stop: Optional[Union[str, List[str]]]=[],
         frequency_penalty: float = 0.0,
         presence_penalty: float = 0.0,
         repeat_penalty: float = 1.1,

@@ -1212,7 +1235,7 @@ class Llama:
         top_p: float = 0.95,
         top_k: int = 40,
         stream: bool = False,
-        stop: Optional[Union[str, List[str]]] = [],
+        stop: Optional[Union[str, List[str]]]=[],
         max_tokens: int = 256,
         presence_penalty: float = 0.0,
         frequency_penalty: float = 0.0,

@@ -1294,9 +1317,9 @@ class Llama:
             n_threads=self.n_threads,
             lora_base=self.lora_base,
             lora_path=self.lora_path,
-            ### DEPRECATED ###
+            # DEPRECATED
             n_parts=self.n_parts,
-            ### DEPRECATED ###
+            # DEPRECATED
         )
 
     def __setstate__(self, state):

@@ -1321,12 +1344,11 @@ class Llama:
         )
 
     def save_state(self) -> LlamaState:
-        assert self.ctx is not None
+        invalidInputError(self.ctx is not None, "The attribute `ctx` of `Llama` object is None.")
         state_size = llama_cpp.llama_get_state_size(self.ctx)
         llama_state = (llama_cpp.c_uint8 * int(state_size))()
         n_bytes = llama_cpp.llama_copy_state_data(self.ctx, llama_state)
-        if int(n_bytes) > int(state_size):
-            raise RuntimeError("Failed to copy llama state data")
+        invalidInputError(int(n_bytes) <= int(state_size), "Failed to copy llama state data.")
         llama_state_compact = (llama_cpp.c_uint8 * int(n_bytes))()
         llama_cpp.ctypes.memmove(llama_state_compact, llama_state, int(n_bytes))
         if self.verbose:

@@ -1342,26 +1364,27 @@ class Llama:
         )
 
     def load_state(self, state: LlamaState) -> None:
-        assert self.ctx is not None
+        invalidInputError(self.ctx is not None, "The attribute `ctx` of `Llama` object is None.")
         self.eval_tokens = state.eval_tokens.copy()
         self.eval_logits = state.eval_logits.copy()
         state_size = state.llama_state_size
-        if llama_cpp.llama_set_state_data(self.ctx, state.llama_state) != state_size:
-            raise RuntimeError("Failed to set llama state data")
+        invalidInputError(llama_cpp.llama_set_state_data(self.ctx,
+                                                         state.llama_state) == state_size,
+                          "Failed to set llama state data.")
 
     def n_ctx(self) -> int:
         """Return the context window size."""
-        assert self.ctx is not None
+        invalidInputError(self.ctx is not None, "The attribute `ctx` of `Llama` object is None.")
         return llama_cpp.llama_n_ctx(self.ctx)
 
     def n_embd(self) -> int:
         """Return the embedding size."""
-        assert self.ctx is not None
+        invalidInputError(self.ctx is not None, "The attribute `ctx` of `Llama` object is None.")
         return llama_cpp.llama_n_embd(self.ctx)
 
     def n_vocab(self) -> int:
         """Return the vocabulary size."""
-        assert self.ctx is not None
+        invalidInputError(self.ctx is not None, "The attribute `ctx` of `Llama` object is None.")
         return llama_cpp.llama_n_vocab(self.ctx)
 
     @staticmethod
Changed file 2 of 3 (adapted from llama-cpp-python's llama_cpp.py):

@@ -13,6 +13,32 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+# ===========================================================================
+#
+# This file is adapted from
+# https://github.com/abetlen/llama-cpp-python/blob/main/llama_cpp/llama_cpp.py
+#
+# MIT License
+#
+# Copyright (c) 2023 Andrei Betlen
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
 
 # This would makes sure Python is aware there is more than one sub-package within bigdl,
 # physically located elsewhere.

@@ -36,6 +62,7 @@ from ctypes import (
     c_size_t,
 )
 import pathlib
+from bigdl.llm.utils.common import invalidInputError
 
 
 # Load the library
@@ -48,7 +75,7 @@ def _load_shared_library(lib_base_name: str):
     elif sys.platform == "win32":
         lib_ext = ".dll"
     else:
-        raise RuntimeError("Unsupported platform")
+        invalidInputError(False, "Unsupported platform.")
 
     # Construct the paths to the possible shared library names
     _base_path = pathlib.Path(__file__).parent.parent.parent.parent.resolve()

@@ -81,11 +108,9 @@ def _load_shared_library(lib_base_name: str):
             try:
                 return ctypes.CDLL(str(_lib_path), **cdll_args)
             except Exception as e:
-                raise RuntimeError(f"Failed to load shared library '{_lib_path}': {e}")
+                invalidInputError(False, f"Failed to load shared library '{_lib_path}': {e}.")
 
-    raise FileNotFoundError(
-        f"Shared library with base name '{lib_base_name}' not found"
-    )
+    invalidInputError(False, f"Shared library with base name '{lib_base_name}' not found.")
 
 
 # Specify the base name of the shared library to load
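The two hunks above rewrite the failure paths of _load_shared_library. The simplified sketch below shows the resulting control flow using only pieces visible in this diff; it is not the actual bigdl implementation (the candidate-path search and cdll_args handling are omitted, and the non-Windows extension choice is an assumption), and it again assumes invalidInputError raises on a False condition.

    import ctypes
    import pathlib
    import sys

    try:
        from bigdl.llm.utils.common import invalidInputError  # real helper when available
    except ImportError:  # fall back to the stand-in semantics sketched earlier
        def invalidInputError(condition, err_msg):
            if not condition:
                raise ValueError(err_msg)

    def _load_shared_library_sketch(lib_base_name: str):
        # Resolve the platform-specific extension; unsupported platforms fail here.
        if sys.platform.startswith("linux"):
            lib_ext = ".so"  # assumption: branch not shown in the diff hunks above
        elif sys.platform == "win32":
            lib_ext = ".dll"
        else:
            invalidInputError(False, "Unsupported platform.")

        # Single candidate path for illustration; the real code searches several.
        _lib_path = pathlib.Path(__file__).parent / f"lib{lib_base_name}{lib_ext}"
        if _lib_path.exists():
            try:
                return ctypes.CDLL(str(_lib_path))
            except Exception as e:
                invalidInputError(False, f"Failed to load shared library '{_lib_path}': {e}.")

        invalidInputError(False, f"Shared library with base name '{lib_base_name}' not found.")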
@@ -300,7 +325,8 @@ _lib.llama_free.restype = None
 
 # TODO: not great API - very likely to change
 # Returns 0 on success
-# nthread - how many threads to use. If <=0, will use std::thread::hardware_concurrency(), else the number given
+# nthread - how many threads to use. If <=0, will use std::thread::hardware_concurrency(),
+# else the number given
 # LLAMA_API int llama_model_quantize(
 #         const char * fname_inp,
 #         const char * fname_out,

@@ -399,7 +425,8 @@ _lib.llama_set_state_data.restype = c_size_t
 
 
 # Save/load session file
-# LLAMA_API bool llama_load_session_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out);
+# LLAMA_API bool llama_load_session_file(struct llama_context * ctx, const char * path_session,
+# llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out);
 def llama_load_session_file(
     ctx: llama_context_p,
     path_session: bytes,

@@ -422,7 +449,8 @@ _lib.llama_load_session_file.argtypes = [
 _lib.llama_load_session_file.restype = c_size_t
 
 
-# LLAMA_API bool llama_save_session_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count);
+# LLAMA_API bool llama_save_session_file(struct llama_context * ctx, const char * path_session,
+# const llama_token * tokens, size_t n_token_count);
 def llama_save_session_file(
     ctx: llama_context_p,
     path_session: bytes,

@@ -601,8 +629,10 @@ _lib.llama_init_candidates.argtypes = [
 _lib.llama_init_candidates.restype = None
 
 
-# @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
-# LLAMA_API void llama_sample_repetition_penalty(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float penalty);
+# @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858,
+# with negative logit fix.
+# LLAMA_API void llama_sample_repetition_penalty(struct llama_context * ctx, llama_token_data_array
+# * candidates, const llama_token * last_tokens, size_t last_tokens_size, float penalty);
 def llama_sample_repetition_penalty(
     ctx: llama_context_p,
     candidates,  # type: _Pointer[llama_token_data_array]

@@ -625,8 +655,11 @@ _lib.llama_sample_repetition_penalty.argtypes = [
 _lib.llama_sample_repetition_penalty.restype = None
 
 
-# @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
-# LLAMA_API void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float alpha_frequency, float alpha_presence);
+# @details Frequency and presence penalties described in OpenAI API
+# https://platform.openai.com/docs/api-reference/parameter-details.
+# LLAMA_API void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx,
+# llama_token_data_array * candidates, const llama_token * last_tokens,
+# size_t last_tokens_size, float alpha_frequency, float alpha_presence);
 def llama_sample_frequency_and_presence_penalties(
     ctx: llama_context_p,
     candidates,  # type: _Pointer[llama_token_data_array]

@@ -656,8 +689,10 @@ _lib.llama_sample_frequency_and_presence_penalties.argtypes = [
 _lib.llama_sample_frequency_and_presence_penalties.restype = None
 
 
-# @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
-# LLAMA_API void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates);
+# @details Sorts candidate tokens by their logits in descending order and calculate probabilities
+# based on logits.
+# LLAMA_API void llama_sample_softmax(struct llama_context * ctx,
+# llama_token_data_array * candidates);
 def llama_sample_softmax(
     ctx: llama_context_p, candidates  # type: _Pointer[llama_token_data]
 ):

@@ -671,8 +706,10 @@ _lib.llama_sample_softmax.argtypes = [
 _lib.llama_sample_softmax.restype = None
 
 
-# @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
-# LLAMA_API void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * candidates, int k, size_t min_keep);
+# @details Top-K sampling described in academic paper
+# "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
+# LLAMA_API void llama_sample_top_k(struct llama_context * ctx,
+# llama_token_data_array * candidates, int k, size_t min_keep);
 def llama_sample_top_k(
     ctx: llama_context_p,
     candidates,  # type: _Pointer[llama_token_data_array]

@@ -691,8 +728,10 @@ _lib.llama_sample_top_k.argtypes = [
 _lib.llama_sample_top_k.restype = None
 
 
-# @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
-# LLAMA_API void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep);
+# @details Nucleus sampling described in academic paper
+# "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
+# LLAMA_API void llama_sample_top_p(struct llama_context * ctx,
+# llama_token_data_array * candidates, float p, size_t min_keep);
 def llama_sample_top_p(
     ctx: llama_context_p,
     candidates,  # type: _Pointer[llama_token_data_array]

@@ -711,8 +750,10 @@ _lib.llama_sample_top_p.argtypes = [
 _lib.llama_sample_top_p.restype = None
 
 
-# @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
-# LLAMA_API void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep);
+# @details Tail Free Sampling described in
+# https://www.trentonbricken.com/Tail-Free-Sampling/.
+# LLAMA_API void llama_sample_tail_free(struct llama_context * ctx,
+# llama_token_data_array * candidates, float z, size_t min_keep);
 def llama_sample_tail_free(
     ctx: llama_context_p,
     candidates,  # type: _Pointer[llama_token_data_array]

@@ -731,8 +772,10 @@ _lib.llama_sample_tail_free.argtypes = [
 _lib.llama_sample_tail_free.restype = None
 
 
-# @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
-# LLAMA_API void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep);
+# @details Locally Typical Sampling implementation described in the paper
+# https://arxiv.org/abs/2202.00666.
+# LLAMA_API void llama_sample_typical(struct llama_context * ctx,
+# llama_token_data_array * candidates, float p, size_t min_keep);
 def llama_sample_typical(
     ctx: llama_context_p,
     candidates,  # type: _Pointer[llama_token_data_array]

@@ -751,7 +794,8 @@ _lib.llama_sample_typical.argtypes = [
 _lib.llama_sample_typical.restype = None
 
 
-# LLAMA_API void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array * candidates, float temp);
+# LLAMA_API void llama_sample_temperature(struct llama_context * ctx,
+# llama_token_data_array * candidates, float temp);
 def llama_sample_temperature(
     ctx: llama_context_p,
     candidates,  # type: _Pointer[llama_token_data_array]

@@ -768,13 +812,25 @@ _lib.llama_sample_temperature.argtypes = [
 _lib.llama_sample_temperature.restype = None
 
 
-# @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
-# @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
-# @param tau  The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
-# @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
-# @param m The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm.
-# @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
-# LLAMA_API llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int m, float * mu);
+# @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966.
+# Uses tokens instead of words.
+# @param candidates A vector of `llama_token_data` containing the candidate tokens,
+# their probabilities (p), and log-odds (logit) for the current position in the generated text.
+# @param tau  The target cross-entropy (or surprise) value you want to achieve for the generated
+# text. A higher value corresponds to more surprising or less predictable text, while a lower
+# value corresponds to less surprising or more predictable text.
+# @param eta The learning rate used to update `mu` based on the error between the target and
+# observed surprisal of the sampled word. A larger learning rate will cause `mu` to be
+# updated more quickly, while a smaller learning rate will result in slower updates.
+# @param m The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value
+# that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`.
+# In the paper, they use `m = 100`, but you can experiment with different values to see
+# how it affects the performance of the algorithm.
+# @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy
+# (`2 * tau`) and is updated in the algorithm based on the error between the target and
+# observed surprisal.
+# LLAMA_API llama_token llama_sample_token_mirostat(struct llama_context * ctx,
+# llama_token_data_array * candidates, float tau, float eta, int m, float * mu);
 def llama_sample_token_mirostat(
     ctx: llama_context_p,
     candidates,  # type: _Pointer[llama_token_data_array]

@@ -797,12 +853,21 @@ _lib.llama_sample_token_mirostat.argtypes = [
 _lib.llama_sample_token_mirostat.restype = llama_token
 
 
-# @details Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
-# @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
-# @param tau  The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
-# @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
-# @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
-# LLAMA_API llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, float * mu);
+# @details Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966.
+# Uses tokens instead of words.
+# @param candidates A vector of `llama_token_data` containing the candidate tokens,
+# their probabilities (p), and log-odds (logit) for the current position in the generated text.
+# @param tau  The target cross-entropy (or surprise) value you want to achieve for the generated
+# text. A higher value corresponds to more surprising or less predictable text, while a lower value
+# corresponds to less surprising or more predictable text.
+# @param eta The learning rate used to update `mu` based on the error between the target and
+# observed surprisal of the sampled word. A larger learning rate will cause `mu` to be
+# updated more quickly, while a smaller learning rate will result in slower updates.
+# @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy
+# (`2 * tau`) and is updated in the algorithm based on the error between the target
+# and observed surprisal.
+# LLAMA_API llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx,
+# llama_token_data_array * candidates, float tau, float eta, float * mu);
 def llama_sample_token_mirostat_v2(
     ctx: llama_context_p,
     candidates,  # type: _Pointer[llama_token_data_array]

@@ -824,7 +889,8 @@ _lib.llama_sample_token_mirostat_v2.restype = llama_token
 
 
 # @details Selects the token with the highest probability.
-# LLAMA_API llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_data_array * candidates);
+# LLAMA_API llama_token llama_sample_token_greedy(struct llama_context * ctx,
+# llama_token_data_array * candidates);
 def llama_sample_token_greedy(
     ctx: llama_context_p,
     candidates,  # type: _Pointer[llama_token_data_array]

@@ -840,7 +906,8 @@ _lib.llama_sample_token_greedy.restype = llama_token
 
 
 # @details Randomly selects a token from the candidates based on their probabilities.
-# LLAMA_API llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates);
+# LLAMA_API llama_token llama_sample_token(struct llama_context * ctx,
+# llama_token_data_array * candidates);
 def llama_sample_token(
     ctx: llama_context_p,
     candidates,  # type: _Pointer[llama_token_data_array]
Changed file 3 of 3 (adapted from llama-cpp-python's llama_types.py):

@@ -13,6 +13,32 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+# ===========================================================================
+#
+# This file is adapted from
+# https://github.com/abetlen/llama-cpp-python/blob/main/llama_cpp/llama_types.py
+#
+# MIT License
+#
+# Copyright (c) 2023 Andrei Betlen
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
 
 # This would makes sure Python is aware there is more than one sub-package within bigdl,
 # physically located elsewhere.