[LLM] Unify Langchain Native and Transformers LLM API (#8752)
* deprecate BigDLNativeTransformers and add specific LMEmbedding methods
* deprecate and add LM methods for LangChain LLMs
* add native params to native LangChain
* new implementation for embeddings
* move UT from BigdlNative to causal LLM
* rename embeddings API and update examples to align with the new usage
* docqa example hot-fix
* add more API docs
* add LangChain UT for StarCoder
* support model_kwargs for transformers methods when calling CausalLM, and add UT
* UT fix for transformers embeddings
* update LangChain causal classes to support transformers
* remove model_family in README doc
* add model_families params to support more models
* update API docs and remove ChatGLM embeddings for now
* remove ChatGLM embeddings in examples
* new UT refactor to add Bloom and transformers LLaMA UTs
* disable LLaMA transformers embedding UT
Parent: 5582872744
Commit: d2926c7672
11 changed files with 747 additions and 75 deletions
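In short, this commit deprecates the generic `BigdlNativeLLM`/`BigdlNativeEmbeddings` wrappers in favor of per-model classes (`LlamaLLM`, `BloomLLM`, `LlamaEmbeddings`, ...) that can load either a converted native INT4 checkpoint or, with `native=False`, a Hugging Face checkpoint through the transformers path. A minimal usage sketch based on the classes added below (the paths are placeholders):

```python
from bigdl.llm.langchain.llms import LlamaLLM
from bigdl.llm.langchain.embeddings import LlamaEmbeddings

# Native (ggml) INT4: point to a converted BigDL-LLM binary checkpoint
llm = LlamaLLM(model_path='/path/to/converted/model.bin')
embeddings = LlamaEmbeddings(model_path='/path/to/converted/model.bin')

# Transformers INT4: set native=False and point to a Hugging Face repo id or folder
llm = LlamaLLM(model_path='/path/to/hf/checkpoint', native=False,
               model_kwargs={'trust_remote_code': True})
embeddings = LlamaEmbeddings(model_path='/path/to/hf/checkpoint', native=False,
                             model_kwargs={'trust_remote_code': True})
```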
.github/workflows/llm_unit_tests.yml (vendored): 12 changes
@@ -77,6 +77,8 @@ jobs:
       run: |
         echo "SPEECH_DATASET_PATH=${DATASET_DIR}/librispeech_asr_dummy" >> "$GITHUB_ENV"
+        echo "LLAMA_ORIGIN_PATH=${ORIGIN_DIR}/llama-7b-hf" >> "$GITHUB_ENV"
+        echo "BLOOM_ORIGIN_PATH=${ORIGIN_DIR}/bloom-7b1" >> "$GITHUB_ENV"
         echo "ORIGINAL_CHATGLM2_6B_PATH=${ORIGIN_DIR}/chatglm2-6b" >> "$GITHUB_ENV"
         echo "ORIGINAL_REPLIT_CODE_PATH=${ORIGIN_DIR}/replit-code-v1-3b" >> "$GITHUB_ENV"
         echo "ORIGINAL_WHISPER_TINY_PATH=${ORIGIN_DIR}/whisper-tiny" >> "$GITHUB_ENV"

@@ -143,6 +145,16 @@ jobs:
         echo "wget -r -nH --no-verbose --cut-dirs=1 $LLM_FTP_URL/llm/whisper-tiny -P $ORIGIN_DIR"
         wget -r -nH --no-verbose --cut-dirs=1 $LLM_FTP_URL/llm/whisper-tiny -P $ORIGIN_DIR
       fi
+      if [ ! -d $LLAMA_ORIGIN_PATH ]; then
+        echo "Directory $LLAMA_ORIGIN_PATH not found. Downloading from FTP server..."
+        echo "wget --no-verbose $LLM_FTP_URL/llm/llama-7b-hf -P $ORIGIN_DIR"
+        wget -r -nH --no-verbose --cut-dirs=1 $LLM_FTP_URL/llm/llama-7b-hf -P $ORIGIN_DIR
+      fi
+      if [ ! -d $BLOOM_ORIGIN_PATH ]; then
+        echo "Directory $BLOOM_ORIGIN_PATH not found. Downloading from FTP server..."
+        echo "wget --no-verbose $LLM_FTP_URL/llm/bloom-7b1 -P $ORIGIN_DIR"
+        wget -r -nH --no-verbose --cut-dirs=1 $LLM_FTP_URL/llm/bloom-7b1 -P $ORIGIN_DIR
+      fi
       if [ ! -d $SPEECH_DATASET_PATH ]; then
         echo "Directory $SPEECH_DATASET_PATH not found. Downloading from FTP server..."
         echo "wget -r -nH --no-verbose --cut-dirs=2 $LLM_FTP_URL/llm/datasets/librispeech_asr_dummy -P $DATASET_DIR"
@@ -154,19 +154,23 @@ You may run the models using the LangChain API in `bigdl-llm`.

 - **Using native INT4 format**

-  You may also convert Hugging Face *Transformers* models into *native INT4* format (currently only *llama*/*bloom*/*gptneox*/*starcoder* model family is supported), and then run the converted models using the LangChain API as follows.
+  You may also convert Hugging Face *Transformers* models into *native INT4* format, and then run the converted models using the LangChain API as follows.

-  >**Note**: Currently only llama/bloom/gptneox/starcoder model family is supported; for other models, you may use the Transformers INT4 format as described above).
+  >**Notes**:
+
+  >* Currently only llama/bloom/gptneox/starcoder/chatglm model families are supported; for other models, you may use the Hugging Face `transformers` INT4 format as described above).
+
+  >* You may choose the corresponding API developed for specific native models to load the converted model.

   ```python
-  from bigdl.llm.langchain.llms import BigdlNativeLLM
-  from bigdl.llm.langchain.embeddings import BigdlNativeEmbeddings
+  from bigdl.llm.langchain.llms import LlamaLLM
+  from bigdl.llm.langchain.embeddings import LlamaEmbeddings
   from langchain.chains.question_answering import load_qa_chain

-  embeddings = BigdlNativeEmbeddings(model_path='/path/to/converted/model.bin',
-                                     model_family="llama",...)
-  bigdl_llm = BigdlNativeLLM(model_path='/path/to/converted/model.bin',
-                             model_family="llama",...)
+  #switch to ChatGLMEmbeddings/GptneoxEmbeddings/BloomEmbeddings/StarcoderEmbeddings to load other models
+  embeddings = LlamaEmbeddings(model_path='/path/to/converted/model.bin')
+  #switch to ChatGLMLLM/GptneoxLLM/BloomLLM/StarcoderLLM to load other models
+  bigdl_llm = LlamaLLM(model_path='/path/to/converted/model.bin')

   doc_chain = load_qa_chain(bigdl_llm, ...)
   doc_chain.run(...)
   ```
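For the `transformers` INT4 route referenced in the note above, the same LangChain integration goes through the `TransformersLLM`/`TransformersEmbeddings` classes re-exported in this commit; a minimal sketch, assuming a locally available Hugging Face checkpoint (the model id is a placeholder):

```python
from bigdl.llm.langchain.llms import TransformersLLM
from bigdl.llm.langchain.embeddings import TransformersEmbeddings

# Load a Hugging Face checkpoint with BigDL-LLM transformers INT4 optimizations
llm = TransformersLLM.from_model_id(
    model_id='/path/to/hf/checkpoint',            # placeholder repo id or local folder
    model_kwargs={'trust_remote_code': True},
)
embeddings = TransformersEmbeddings.from_model_id(
    model_id='/path/to/hf/checkpoint',
    model_kwargs={'trust_remote_code': True},
)
```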
@@ -31,21 +31,19 @@ from langchain.chains.question_answering import load_qa_chain
 from langchain.callbacks.manager import CallbackManager
 from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler

-from bigdl.llm.langchain.llms import BigdlNativeLLM
-from bigdl.llm.langchain.embeddings import BigdlNativeEmbeddings
+from bigdl.llm.langchain.llms import *
+from bigdl.llm.langchain.embeddings import *


 def main(args):

     input_path = args.input_path
     model_path = args.model_path
     model_family = args.model_family
     query = args.question
     n_ctx = args.n_ctx
     n_threads=args.thread_num

     callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])

     # split texts of input doc
@@ -54,15 +52,35 @@ def main(args):
     text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
     texts = text_splitter.split_text(input_doc)

+    model_family_to_embeddings = {
+        "llama": LlamaEmbeddings,
+        "gptneox": GptneoxEmbeddings,
+        "bloom": BloomEmbeddings,
+        "starcoder": StarcoderEmbeddings
+    }
+
+    model_family_to_llm = {
+        "llama": LlamaLLM,
+        "gptneox": GptneoxLLM,
+        "bloom": BloomLLM,
+        "starcoder": StarcoderLLM
+    }
+
+    if model_family in model_family_to_embeddings and model_family in model_family_to_llm:
+        llm_embeddings = model_family_to_embeddings[model_family]
+        langchain_llm = model_family_to_llm[model_family]
+    else:
+        raise ValueError(f"Unknown model family: {model_family}")
+
     # create embeddings and store into vectordb
-    embeddings = BigdlNativeEmbeddings(model_path=model_path, model_family=model_family, n_threads=n_threads, n_ctx=n_ctx)
+    embeddings = llm_embeddings(model_path=model_path, n_threads=n_threads, n_ctx=n_ctx)
     docsearch = Chroma.from_texts(texts, embeddings, metadatas=[{"source": str(i)} for i in range(len(texts))]).as_retriever()

-    #get relavant texts
+    # get relavant texts
     docs = docsearch.get_relevant_documents(query)

-    bigdl_llm = BigdlNativeLLM(
-        model_path=model_path, model_family=model_family, n_ctx=n_ctx, n_threads=n_threads, callback_manager=callback_manager
+    bigdl_llm = langchain_llm(
+        model_path=model_path, n_ctx=n_ctx, n_threads=n_threads, callback_manager=callback_manager
     )

     doc_chain = load_qa_chain(
@@ -73,9 +91,9 @@ def main(args):

 if __name__ == '__main__':
-    parser = argparse.ArgumentParser(description='BigdlNativeLLM Langchain QA over Docs Example')
+    parser = argparse.ArgumentParser(description='BigDLCausalLM Langchain QA over Docs Example')
     parser.add_argument('-x','--model-family', type=str, required=True,
-                        choices=["llama", "bloom", "gptneox"],
+                        choices=["llama", "bloom", "gptneox", "chatglm", "starcoder"],
                         help='the model family')
     parser.add_argument('-m','--model-path', type=str, required=True,
                         help='the path to the converted llm model')
@@ -21,7 +21,7 @@
 import argparse

-from bigdl.llm.langchain.llms import BigdlNativeLLM
+from bigdl.llm.langchain.llms import *
 from langchain import PromptTemplate, LLMChain
 from langchain.callbacks.manager import CallbackManager
 from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
@@ -39,11 +39,23 @@ def main(args):
     # Callbacks support token-wise streaming
     callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])

+    model_family_to_llm = {
+        "llama": LlamaLLM,
+        "gptneox": GptneoxLLM,
+        "bloom": BloomLLM,
+        "starcoder": StarcoderLLM,
+        "chatglm": ChatGLMLLM
+    }
+
+    if model_family in model_family_to_llm:
+        langchain_llm = model_family_to_llm[model_family]
+    else:
+        raise ValueError(f"Unknown model family: {model_family}")
+
     # Verbose is required to pass to the callback manager
-    llm = BigdlNativeLLM(
+    llm = langchain_llm(
         model_path=model_path,
-        model_family=model_family,
         n_threads=n_threads,
         callback_manager=callback_manager,
         verbose=True
@@ -55,9 +67,9 @@ def main(args):

 if __name__ == '__main__':
-    parser = argparse.ArgumentParser(description='BigdlNativeLLM Langchain Streaming Chat Example')
+    parser = argparse.ArgumentParser(description='BigDLCausalLM Langchain Streaming Chat Example')
     parser.add_argument('-x','--model-family', type=str, required=True,
-                        choices=["llama", "bloom", "gptneox"],
+                        choices=["llama", "bloom", "gptneox", "chatglm", "starcoder"],
                         help='the model family')
     parser.add_argument('-m','--model-path', type=str, required=True,
                         help='the path to the converted llm model')
@@ -23,7 +23,7 @@

 from langchain import LLMChain, PromptTemplate
-from bigdl.llm.langchain.llms import BigdlNativeLLM
+from bigdl.llm.langchain.llms import *
 from langchain.memory import ConversationBufferWindowMemory
 from langchain.callbacks.manager import CallbackManager
 from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
@@ -35,7 +35,6 @@ import argparse
 def prepare_chain(args):

     model_path = args.model_path
     model_family = args.model_family
     n_threads = args.thread_num
     n_ctx = args.context_size

@@ -48,11 +47,23 @@ def prepare_chain(args):
     A:"""
     prompt = PromptTemplate(input_variables=["history", "human_input"], template=template)

-    # We use our BigdlNativeLLM to subsititute OpenAI web-required API
+    # We use our BigDLCausalLLM to subsititute OpenAI web-required API
+    model_family_to_llm = {
+        "llama": LlamaLLM,
+        "gptneox": GptneoxLLM,
+        "bloom": BloomLLM,
+        "starcoder": StarcoderLLM,
+        "chatglm": ChatGLMLLM
+    }
+
+    if model_family in model_family_to_llm:
+        langchain_llm = model_family_to_llm[model_family]
+    else:
+        raise ValueError(f"Unknown model family: {model_family}")
+
     callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
-    llm = BigdlNativeLLM(
+    llm = langchain_llm(
         model_path=model_path,
-        model_family=model_family,
         n_threads=n_threads,
         callback_manager=callback_manager,
         verbose=True,

@@ -114,8 +125,9 @@ def main(args):

 if __name__ == '__main__':
-    parser = argparse.ArgumentParser(description='BigdlNativeLLM Langchain Voice Assistant Example')
+    parser = argparse.ArgumentParser(description='BigDLCausalLM Langchain Voice Assistant Example')
     parser.add_argument('-x','--model-family', type=str, required=True,
+                        choices=["llama", "bloom", "gptneox", "chatglm", "starcoder"],
                         help='the model family')
     parser.add_argument('-m','--model-path', type=str, required=True,
                         help='the path to the converted llm model')
@@ -19,10 +19,14 @@
 # Otherwise there would be module not found error in non-pip's setting as Python would
 # only search the first bigdl package and end up finding only one sub-package.

-from .bigdlllm import BigdlNativeEmbeddings
+from .bigdlllm import *
 from .transformersembeddings import TransformersEmbeddings

 __all__ = [
     "BigdlNativeEmbeddings",
+    "LlamaEmbeddings",
+    "BloomEmbeddings",
+    "GptneoxEmbeddings",
+    "StarcoderEmbeddings",
     "TransformersEmbeddings"
 ]
@@ -45,12 +45,14 @@
 # THE SOFTWARE.

 """Wrapper around BigdlNative embedding models."""
+import logging
 import importlib
 from typing import Any, Dict, List, Optional

 from pydantic import BaseModel, Extra, Field, root_validator

 from langchain.embeddings.base import Embeddings
+from .transformersembeddings import TransformersEmbeddings


 class BigdlNativeEmbeddings(BaseModel, Embeddings):
@@ -63,18 +65,25 @@ class BigdlNativeEmbeddings(BaseModel, Embeddings):
             llama = BigdlNativeEmbeddings(model_path="/path/to/model.bin")
     """

+    logging.warning("BigdlNativeEmbeddings has been deprecated, "
+                    "please switch to the new LMEmbeddings API for sepcific models.")
+
     model_family: str = "llama"
-    """the model family"""
+    """The model family: currently supports llama, gptneox, bloom and starcoder."""

     family_info = {
         'llama': {'module': "bigdl.llm.models", 'class': "Llama"},
         'bloom': {'module': "bigdl.llm.models", 'class': "Bloom"},
         'gptneox': {'module': "bigdl.llm.models", 'class': "Gptneox"},
         'starcoder': {'module':"bigdl.llm.models", 'class': "Starcoder"},
     } #: :meta private:
-    """info necessary for different model family initiation and configure"""
+    """Info necessary for different model family initiation and configure."""

     client: Any  #: :meta private:
+    """The actual model."""

     model_path: str  # TODO: missing doc
+    """Path to the converted BigDL-LLM optimized ggml binary checkpoint."""

     n_ctx: int = Field(512, alias="n_ctx")
     """Token context window."""
@@ -159,7 +168,7 @@ class BigdlNativeEmbeddings(BaseModel, Embeddings):
             )
         except Exception as e:
             raise ValueError(
-                f"Could not load Llama model from path: {model_path}. "
+                f"Could not load model from path: {model_path}. "
                 f"Please make sure the model family {model_family} matches "
                 "the model you want to load."
                 f"Received error {e}"
@@ -190,3 +199,186 @@ class BigdlNativeEmbeddings(BaseModel, Embeddings):
         """
         embedding = self.client.embed(text)
         return list(map(float, embedding))
+
+
+class _BaseEmbeddings(BaseModel, Embeddings):
+    """Wrapper around bigdl-llm embedding models.
+
+    param model_path: If running with ``native int4``, the path should be converted BigDL-LLM
+    optimized ggml binary checkpoint, which should be converted by ``bigdl.llm.llm_convert``.
+    If running with ``transformers int4``, the path should be the huggingface repo id
+    to be downloaded or the huggingface checkpoint folder.
+
+    Example:
+        .. code-block:: python
+
+            from bigdl.llm.langchain.embeddings import LlamaEmbeddings
+            llama = LlamaEmbeddings(model_path="/path/to/model.bin")
+    """
+
+    ggml_model: str = None
+    ggml_module: str = None
+
+    native: bool = True
+    """Load model to either BigDL-LLM optimized Transformers or Native (ggml) int4."""
+
+    client: Any  #: :meta private:
+    """The actual model."""
+
+    model_kwargs: Optional[dict] = None
+    """Key word arguments to pass to the Transformers model."""
+
+    encode_kwargs: Optional[dict] = None
+    """Key word arguments to pass when calling the `encode` method of the Transformers model."""
+
+    kwargs: Any
+    """Additional key word arguments passed to TransformersLLM."""
+
+    model_path: str
+    """Path to the loading model file.
+    If native, the path shoule be converted BigDL-LLM optimized ggml binary checkpoint.
+    If transformers, the path should be the huggingface repo id to be downloaded
+    or the huggingface checkpoint folder."""
+
+    n_ctx: int = Field(512, alias="n_ctx")
+    """Token context window."""
+
+    n_parts: int = Field(-1, alias="n_parts")
+    """Number of parts to split the model into.
+    If -1, the number of parts is automatically determined."""
+
+    seed: int = Field(-1, alias="seed")
+    """Seed. If -1, a random seed is used."""
+
+    f16_kv: bool = Field(True, alias="f16_kv")
+    """Use half-precision for key/value cache."""
+
+    logits_all: bool = Field(False, alias="logits_all")
+    """Return logits for all tokens, not just the last token."""
+
+    vocab_only: bool = Field(False, alias="vocab_only")
+    """Only load the vocabulary, no weights."""
+
+    use_mlock: bool = Field(False, alias="use_mlock")
+    """Force system to keep model in RAM."""
+
+    n_threads: Optional[int] = Field(2, alias="n_threads")
+    """Number of threads to use."""
+
+    n_batch: Optional[int] = Field(512, alias="n_batch")
+    """Number of tokens to process in parallel.
+    Should be a number between 1 and n_ctx."""
+
+    n_gpu_layers: Optional[int] = Field(0, alias="n_gpu_layers")
+    """Number of layers to be loaded into gpu memory. Default None."""
+
+    class Config:
+        """Configuration for this pydantic object."""
+
+        extra = Extra.forbid
+
+    @root_validator()
+    def validate_environment(cls, values: Dict) -> Dict:
+        """Validate that bigdl-llm library is installed."""
+
+        native = values["native"]
+        model_path = values["model_path"]
+        model_kwargs = values["model_kwargs"]
+        kwargs = values["kwargs"]
+        model_param_names = [
+            "n_ctx",
+            "n_parts",
+            "seed",
+            "f16_kv",
+            "logits_all",
+            "vocab_only",
+            "use_mlock",
+            "n_threads",
+            "n_batch",
+        ]
+        model_params = {k: values[k] for k in model_param_names}
+        # For backwards compatibility, only include if non-null.
+        if values["n_gpu_layers"] is not None:
+            model_params["n_gpu_layers"] = values["n_gpu_layers"]
+
+        try:
+            module = importlib.import_module(values["ggml_module"])
+            class_ = getattr(module, values["ggml_model"])
+
+            if native:
+                values["client"] = class_(model_path, embedding=True, **model_params)
+            else:
+                kwargs = {} if kwargs is None else kwargs
+                values["client"] = TransformersEmbeddings.from_model_id(model_path, model_kwargs,
+                                                                        **kwargs)
+
+            # from bigdl.llm.ggml.model.llama import Llama
+            # values["client"] = Llama(model_path, embedding=True, **model_params)
+
+        except ImportError:
+            raise ModuleNotFoundError(
+                "Could not import bigdl-llm library. "
+                "Please install the bigdl-llm library to "
+                "use this embedding model: pip install bigdl-llm"
+            )
+        except Exception as e:
+            raise ValueError(
+                f"Could not load model from path: {model_path}. "
+                f"Please make sure the model embedding class matches "
+                "the model you want to load."
+                f"Received error {e}"
+            )
+
+        return values
+
+    def embed_documents(self, texts: List[str]) -> List[List[float]]:
+        """Embed a list of documents using the optimized int4 model.
+
+        Args:
+            texts: The list of texts to embed.
+
+        Returns:
+            List of embeddings, one for each text.
+        """
+        if self.native:
+            embeddings = [self.client.embed(text) for text in texts]
+            return [list(map(float, e)) for e in embeddings]
+        else:
+            return self.client.embed_documents(texts)
+
+    def embed_query(self, text: str) -> List[float]:
+        """Embed a query using the optimized int4 model.
+
+        Args:
+            text: The text to embed.
+
+        Returns:
+            Embeddings for the text.
+        """
+        if self.native:
+            embedding = self.client.embed(text)
+            return list(map(float, embedding))
+        else:
+            return self.client.embed_query(text)
+
+
+class LlamaEmbeddings(_BaseEmbeddings):
+    ggml_model = "Llama"
+    ggml_module = "bigdl.llm.models"
+
+
+class BloomEmbeddings(_BaseEmbeddings):
+    ggml_model = "Bloom"
+    ggml_module = "bigdl.llm.models"
+
+
+class GptneoxEmbeddings(_BaseEmbeddings):
+    ggml_model = "Gptneox"
+    ggml_module = "bigdl.llm.models"
+
+
+class StarcoderEmbeddings(_BaseEmbeddings):
+    ggml_model = "Starcoder"
+    ggml_module = "bigdl.llm.models"
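As a usage note for the embedding classes added above: the `native` flag selects between the converted ggml checkpoint and the transformers path. A small sketch, mirroring what the unit tests further below do (paths are placeholders):

```python
from bigdl.llm.langchain.embeddings import BloomEmbeddings

# Native ggml INT4 checkpoint produced by bigdl.llm.llm_convert
ggml_emb = BloomEmbeddings(model_path='/path/to/converted/model.bin')

# Hugging Face checkpoint routed through TransformersEmbeddings under the hood
hf_emb = BloomEmbeddings(model_path='/path/to/hf/checkpoint',
                         native=False,
                         model_kwargs={'trust_remote_code': True})

query_vec = hf_emb.embed_query("This is a test document.")
doc_vecs = hf_emb.embed_documents(["This is a test document."])
```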
@@ -23,18 +23,28 @@
 from typing import Dict, Type
 from langchain.llms.base import BaseLLM

-from .bigdlllm import BigdlNativeLLM
+from .bigdlllm import *
 from .transformersllm import TransformersLLM
 from .transformerspipelinellm import TransformersPipelineLLM

 __all__ = [
     "BigdlNativeLLM",
+    "LlamaLLM",
+    "BloomLLM",
+    "GptneoxLLM",
+    "ChatGLMLLM",
+    "StarcoderLLM",
     "TransformersLLM",
     "TransformersPipelineLLM"
 ]

 type_to_cls_dict: Dict[str, Type[BaseLLM]] = {
     "BigdlNativeLLM": BigdlNativeLLM,
+    "LlamaLLM": LlamaLLM,
+    "BloomLLM": BloomLLM,
+    "GptneoxLLM": GptneoxLLM,
+    "ChatGLMLLM": ChatGLMLLM,
+    "StarcoderLLM": StarcoderLLM,
     "TransformersPipelineLLM": TransformersPipelineLLM,
     "TransformersLLM": TransformersLLM
 }
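The `type_to_cls_dict` registry above makes it possible to resolve an LLM wrapper from a configuration string; a hypothetical lookup, assuming the registry is importable from `bigdl.llm.langchain.llms` and the model path is a placeholder:

```python
from bigdl.llm.langchain.llms import type_to_cls_dict

# Pick a class by name, e.g. read from a config file
llm_cls = type_to_cls_dict["LlamaLLM"]
llm = llm_cls(model_path='/path/to/converted/model.bin')
print(llm("What is AI?"))
```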
@@ -44,6 +44,7 @@
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 # THE SOFTWARE.

+import logging
 import importlib
 from typing import Any, Dict, Generator, List, Optional
@@ -51,7 +52,7 @@ from pydantic import Field, root_validator

 from langchain.callbacks.manager import CallbackManagerForLLMRun
 from langchain.llms.base import LLM
+from .transformersllm import TransformersLLM


 class BigdlNativeLLM(LLM):
@@ -65,22 +66,26 @@ class BigdlNativeLLM(LLM):
     """

+    logging.warning("BigdlNativeLLM has been deprecated, "
+                    "please switch to the new LLM API for sepcific models.")
+
     model_family: str = "llama"
-    """the model family: currently supports llama, gptneox, and bloom."""
+    """The model family: currently supports llama, gptneox, bloom, starcoder and chatglm."""

     family_info = {
         'llama': {'module': "bigdl.llm.models" , 'class': "Llama"},
         'bloom': {'module': "bigdl.llm.models", 'class': "Bloom"},
         'gptneox': {'module': "bigdl.llm.models", 'class': "Gptneox"},
         'starcoder': {'module':"bigdl.llm.models", 'class': "Starcoder"},
+        'chatglm': {'module':"bigdl.llm.ggml.model.chatglm", 'class': "ChatGLM"},
     } #: :meta private:
-    """info necessary for different model families initiation and configure"""
+    """Info necessary for different model families initiation and configure."""

     client: Any  #: :meta private:
-    """the actual model"""
+    """The actual model."""

     model_path: str
-    """The path to the Llama model file."""
+    """Path to the converted BigDL-LLM optimized ggml binary checkpoint."""

     lora_base: Optional[str] = None
     """The path to the Llama LoRA base model."""
@@ -197,9 +202,9 @@ class BigdlNativeLLM(LLM):

         except ImportError:
             raise ModuleNotFoundError(
-                "Could not import llama-cpp-python library. "
-                "Please install the llama-cpp-python library to "
-                "use this embedding model: pip install llama-cpp-python"
+                "Could not import bigdl-llm library. "
+                "Please install the bigdl-llm library to "
+                "use this embedding model: pip install bigdl-llm"
             )
         except Exception as e:
             raise ValueError(
@@ -351,3 +356,333 @@ class BigdlNativeLLM(LLM):
     def get_num_tokens(self, text: str) -> int:
         tokenized_text = self.client.tokenize(text.encode("utf-8"))
         return len(tokenized_text)
+
+
+class _BaseCausalLM(LLM):
+    """Wrapper around the BigDL-LLM
+
+    Example:
+        .. code-block:: python
+
+            from langchain.llms import LlamaLLM
+            llm = LlamaLLM(model_path="/path/to/llama/model")
+    """
+
+    ggml_model: str = None
+    ggml_module: str = None
+
+    native: bool = True
+    """Load model to either BigDL-LLM optimized Transformers or Native (ggml) int4."""
+
+    client: Any  #: :meta private:
+    """The actual model."""
+
+    model_path: str
+    """Path to the loading model file.
+    If native, the path shoule be converted BigDL-LLM optimized ggml binary checkpoint.
+    If transformers, the path should be the huggingface repo id to be downloaded
+    or the huggingface checkpoint folder."""
+
+    model_kwargs: Optional[dict] = None
+    """Key word arguments passed to the Transformers model."""
+
+    kwargs: Any
+    """Additional key word arguments passed to TransformersLLM."""
+
+    lora_base: Optional[str] = None
+    """The path to the Llama LoRA base model."""
+
+    lora_path: Optional[str] = None
+    """The path to the Llama LoRA. If None, no LoRa is loaded."""
+
+    n_ctx: int = Field(512, alias="n_ctx")
+    """Token context window."""
+
+    n_parts: int = Field(-1, alias="n_parts")
+    """Number of parts to split the model into.
+    If -1, the number of parts is automatically determined."""
+
+    seed: int = Field(-1, alias="seed")
+    """Seed. If -1, a random seed is used."""
+
+    f16_kv: bool = Field(True, alias="f16_kv")
+    """Use half-precision for key/value cache."""
+
+    logits_all: bool = Field(False, alias="logits_all")
+    """Return logits for all tokens, not just the last token."""
+
+    vocab_only: bool = Field(False, alias="vocab_only")
+    """Only load the vocabulary, no weights."""
+
+    use_mlock: bool = Field(False, alias="use_mlock")
+    """Force system to keep model in RAM."""
+
+    n_threads: Optional[int] = Field(2, alias="n_threads")
+    """Number of threads to use."""
+
+    n_batch: Optional[int] = Field(512, alias="n_batch")
+    """Number of tokens to process in parallel.
+    Should be a number between 1 and n_ctx."""
+
+    n_gpu_layers: Optional[int] = Field(0, alias="n_gpu_layers")
+    """Number of layers to be loaded into gpu memory. Default None."""
+
+    suffix: Optional[str] = Field(None)
+    """A suffix to append to the generated text. If None, no suffix is appended."""
+
+    max_tokens: Optional[int] = 256
+    """The maximum number of tokens to generate."""
+
+    temperature: Optional[float] = 0.8
+    """The temperature to use for sampling."""
+
+    top_p: Optional[float] = 0.95
+    """The top-p value to use for sampling."""
+
+    logprobs: Optional[int] = Field(None)
+    """The number of logprobs to return. If None, no logprobs are returned."""
+
+    echo: Optional[bool] = False
+    """Whether to echo the prompt."""
+
+    stop: Optional[List[str]] = []
+    """A list of strings to stop generation when encountered."""
+
+    repeat_penalty: Optional[float] = 1.1
+    """The penalty to apply to repeated tokens."""
+
+    top_k: Optional[int] = 40
+    """The top-k value to use for sampling."""
+
+    last_n_tokens_size: Optional[int] = 64
+    """The number of tokens to look back when applying the repeat_penalty."""
+
+    use_mmap: Optional[bool] = True
+    """Whether to keep the model loaded in RAM"""
+
+    streaming: bool = True
+    """Whether to stream the results, token by token."""
+
+    @root_validator()
+    def validate_environment(cls, values: Dict) -> Dict:
+        """Validate that bigdl-llm is installed, family is supported"""
+
+        native = values["native"]
+        model_path = values["model_path"]
+        model_kwargs = values["model_kwargs"]
+        kwargs = values["kwargs"]
+        model_param_names = [
+            "lora_path",
+            "lora_base",
+            "n_ctx",
+            "n_parts",
+            "seed",
+            "f16_kv",
+            "logits_all",
+            "vocab_only",
+            "use_mlock",
+            "n_threads",
+            "n_batch",
+            "use_mmap",
+            "last_n_tokens_size",
+        ]
+        model_params = {k: values[k] for k in model_param_names}
+        # For backwards compatibility, only include if non-null.
+        if values["n_gpu_layers"] is not None:
+            model_params["n_gpu_layers"] = values["n_gpu_layers"]
+
+        try:
+            module = importlib.import_module(values["ggml_module"])
+            class_ = getattr(module, values["ggml_model"])
+            if native:
+                values["client"] = class_(model_path, **model_params)
+            else:
+                kwargs = {} if kwargs is None else kwargs
+                values["client"] = TransformersLLM.from_model_id(model_path, model_kwargs,
+                                                                 **kwargs)
+
+        except ImportError:
+            raise ModuleNotFoundError(
+                "Could not import bigdl-llm library. "
+                "Please install the bigdl-llm library to "
+                "use this embedding model: pip install bigdl-llm"
+            )
+        except Exception as e:
+            raise ValueError(
+                f"Could not load model from path: {model_path}. "
+                f"Please make sure the model embedding class matches "
+                "the model you want to load."
+                f"Received error {e}"
+            )
+
+        return values
+
+    @property
+    def _default_params(self) -> Dict[str, Any]:
+        """Get the default parameters for calling llama_cpp."""
+        return {
+            "suffix": self.suffix,
+            "max_tokens": self.max_tokens,
+            "temperature": self.temperature,
+            "top_p": self.top_p,
+            "logprobs": self.logprobs,
+            "echo": self.echo,
+            "stop_sequences": self.stop,  # key here is convention among LLM classes
+            "repeat_penalty": self.repeat_penalty,
+            "top_k": self.top_k,
+        }
+
+    @property
+    def _identifying_params(self) -> Dict[str, Any]:
+        """Get the identifying parameters."""
+        return {**{"model_path": self.model_path},
+                **self._default_params}
+
+    @property
+    def _llm_type(self) -> str:
+        """Return type of llm."""
+        return "BigDL"
+
+    def _get_parameters(self, stop: Optional[List[str]] = None) -> Dict[str, Any]:
+        """
+        Performs sanity check, preparing parameters in format needed by llama_cpp.
+
+        Args:
+            stop (Optional[List[str]]): List of stop sequences for llama_cpp.
+
+        Returns:
+            Dictionary containing the combined parameters.
+        """
+
+        # Raise error if stop sequences are in both input and default params
+        if self.stop and stop is not None:
+            raise ValueError("`stop` found in both the input and default params.")
+
+        params = self._default_params
+
+        # llama_cpp expects the "stop" key not this, so we remove it:
+        params.pop("stop_sequences")
+
+        # then sets it as configured, or default to an empty list:
+        params["stop"] = self.stop or stop or []
+
+        return params
+
+    def _call(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs
+    ) -> str:
+        """Call the Llama model and return the output.
+
+        Args:
+            prompt: The prompt to use for generation.
+            stop: A list of strings to stop generation when encountered.
+
+        Returns:
+            The generated text.
+
+        Example:
+            .. code-block:: python
+
+                from langchain.llms import LlamaLLM
+                llm = LlamaLLM(model_path="/path/to/local/llama/model.bin")
+                llm("This is a prompt.")
+        """
+        if self.native:
+            if self.streaming:
+                # If streaming is enabled, we use the stream
+                # method that yields as they are generated
+                # and return the combined strings from the first choices's text:
+                combined_text_output = ""
+                for token in self.stream(prompt=prompt, stop=stop, run_manager=run_manager):
+                    combined_text_output += token["choices"][0]["text"]
+                return combined_text_output
+            else:
+                params = self._get_parameters(stop)
+                result = self.client(prompt=prompt, **params)
+                return result["choices"][0]["text"]
+        else:
+            return self.client._call(prompt, stop, run_manager, **kwargs)
+
+    def stream(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+    ) -> Generator[Dict, None, None]:
+        """Yields results objects as they are generated in real time.
+
+        BETA: this is a beta feature while we figure out the right abstraction.
+        Once that happens, this interface could change.
+
+        It also calls the callback manager's on_llm_new_token event with
+        similar parameters to the OpenAI LLM class method of the same name.
+
+        Args:
+            prompt: The prompts to pass into the model.
+            stop: Optional list of stop words to use when generating.
+
+        Returns:
+            A generator representing the stream of tokens being generated.
+
+        Yields:
+            A dictionary like objects containing a string token and metadata.
+            See llama-cpp-python docs and below for more.
+
+        Example:
+            .. code-block:: python
+
+                from langchain.llms import LlamaLLM
+                llm = LlamaLLM(
+                    model_path="/path/to/local/model.bin",
+                    temperature = 0.5
+                )
+                for chunk in llm.stream("Ask 'Hi, how are you?' like a pirate:'",
+                        stop=["'","\\n"]):
+                    result = chunk["choices"][0]
+                    print(result["text"], end='', flush=True)
+        """
+        params = self._get_parameters(stop)
+        result = self.client(prompt=prompt, stream=True, **params)
+        for chunk in result:
+            token = chunk["choices"][0]["text"]
+            log_probs = chunk["choices"][0].get("logprobs", None)
+            if run_manager:
+                run_manager.on_llm_new_token(
+                    token=token, verbose=self.verbose, log_probs=log_probs
+                )
+            yield chunk
+
+    def get_num_tokens(self, text: str) -> int:
+        tokenized_text = self.client.tokenize(text.encode("utf-8"))
+        return len(tokenized_text)
+
+
+class LlamaLLM(_BaseCausalLM):
+    ggml_model = "Llama"
+    ggml_module = "bigdl.llm.models"
+
+
+class BloomLLM(_BaseCausalLM):
+    ggml_model = "Bloom"
+    ggml_module = "bigdl.llm.models"
+
+
+class GptneoxLLM(_BaseCausalLM):
+    ggml_model = "Gptneox"
+    ggml_module = "bigdl.llm.models"
+
+
+class ChatGLMLLM(_BaseCausalLM):
+    ggml_model = "ChatGLM"
+    ggml_module = "bigdl.llm.ggml.model.chatglm"
+
+
+class StarcoderLLM(_BaseCausalLM):
+    ggml_model = "Starcoder"
+    ggml_module = "bigdl.llm.models"
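A short usage note for the `_BaseCausalLM` subclasses added above, mirroring the unit tests that follow (model paths are placeholders): the same class serves both backends, selected by `native`.

```python
from bigdl.llm.langchain.llms import BloomLLM

# Native ggml INT4 checkpoint; llama.cpp-style generation parameters apply
ggml_llm = BloomLLM(model_path='/path/to/converted/model.bin',
                    max_tokens=32, n_threads=2)
print(ggml_llm("What is AI?"))

# Hugging Face checkpoint routed through TransformersLLM under the hood
hf_llm = BloomLLM(model_path='/path/to/hf/checkpoint',
                  native=False,
                  model_kwargs={'trust_remote_code': True})
print(hf_llm("What is AI?"))
```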
@@ -14,8 +14,8 @@
 # limitations under the License.
 #

-from bigdl.llm.langchain.embeddings import BigdlNativeEmbeddings
-from bigdl.llm.langchain.llms import BigdlNativeLLM
+from bigdl.llm.langchain.embeddings import *
+from bigdl.llm.langchain.llms import *
 import pytest
 from unittest import TestCase
 import os
@@ -26,6 +26,7 @@ class Test_Models_Basics(TestCase):
         self.llama_model_path = os.environ.get('LLAMA_INT4_CKPT_PATH')
         self.bloom_model_path = os.environ.get('BLOOM_INT4_CKPT_PATH')
         self.gptneox_model_path = os.environ.get('GPTNEOX_INT4_CKPT_PATH')
+        self.starcoder_model_path = os.environ.get('STARCODER_INT4_CKPT_PATH')
         thread_num = os.environ.get('THREAD_NUM')
         if thread_num is not None:
             self.n_threads = int(thread_num)
@@ -34,46 +35,64 @@ class Test_Models_Basics(TestCase):

     def test_langchain_llm_embedding_llama(self):
-        bigdl_embeddings = BigdlNativeEmbeddings(
-            model_path=self.llama_model_path,
-            model_family="llama")
+        bigdl_embeddings = LlamaEmbeddings(
+            model_path=self.llama_model_path)
         text = "This is a test document."
         query_result = bigdl_embeddings.embed_query(text)
         doc_result = bigdl_embeddings.embed_documents([text])

     def test_langchain_llm_embedding_gptneox(self):
-        bigdl_embeddings = BigdlNativeEmbeddings(
-            model_path=self.gptneox_model_path,
-            model_family="gptneox")
+        bigdl_embeddings = GptneoxEmbeddings(
+            model_path=self.gptneox_model_path)
         text = "This is a test document."
         query_result = bigdl_embeddings.embed_query(text)
         doc_result = bigdl_embeddings.embed_documents([text])

+    def test_langchain_llm_embedding_bloom(self):
+        bigdl_embeddings = BloomEmbeddings(
+            model_path=self.bloom_model_path)
+        text = "This is a test document."
+        query_result = bigdl_embeddings.embed_query(text)
+        doc_result = bigdl_embeddings.embed_documents([text])
+
+    def test_langchain_llm_embedding_starcoder(self):
+        bigdl_embeddings = StarcoderEmbeddings(
+            model_path=self.starcoder_model_path)
+        text = "This is a test document."
+        query_result = bigdl_embeddings.embed_query(text)
+        doc_result = bigdl_embeddings.embed_documents([text])
+
     def test_langchain_llm_llama(self):
-        llm = BigdlNativeLLM(
-            model_path=self.llama_model_path,
+        llm = LlamaLLM(
+            model_path=self.llama_model_path,
             max_tokens=32,
             n_threads=self.n_threads)
         question = "What is AI?"
         result = llm(question)

     def test_langchain_llm_gptneox(self):
-        llm = BigdlNativeLLM(
+        llm = GptneoxLLM(
             model_path=self.gptneox_model_path,
-            model_family="gptneox",
             max_tokens=32,
             n_threads=self.n_threads)
         question = "What is AI?"
         result = llm(question)

     def test_langchain_llm_bloom(self):
-        llm = BigdlNativeLLM(
-            model_path=self.bloom_model_path,
-            model_family="bloom",
+        llm = BloomLLM(
+            model_path=self.bloom_model_path,
             max_tokens=32,
             n_threads=self.n_threads)
         question = "What is AI?"
         result = llm(question)

+    def test_langchain_llm_starcoder(self):
+        llm = StarcoderLLM(
+            model_path=self.starcoder_model_path,
+            max_tokens=32,
+            n_threads=self.n_threads)
+        question = "What is AI?"
+        result = llm(question)
+
 if __name__ == '__main__':
     pytest.main([__file__])
@@ -14,8 +14,10 @@
 # limitations under the License.
 #

-from bigdl.llm.langchain.llms import TransformersLLM, TransformersPipelineLLM
-from bigdl.llm.langchain.embeddings import TransformersEmbeddings
+from bigdl.llm.langchain.llms import TransformersLLM, TransformersPipelineLLM, \
+    LlamaLLM, BloomLLM
+from bigdl.llm.langchain.embeddings import TransformersEmbeddings, LlamaEmbeddings, \
+    BloomEmbeddings


 from langchain.document_loaders import WebBaseLoader
@@ -37,12 +39,15 @@ class Test_Langchain_Transformers_API(TestCase):
     def setUp(self):
         self.auto_model_path = os.environ.get('ORIGINAL_CHATGLM2_6B_PATH')
         self.auto_causal_model_path = os.environ.get('ORIGINAL_REPLIT_CODE_PATH')
+        self.llama_model_path = os.environ.get('LLAMA_ORIGIN_PATH')
+        self.bloom_model_path = os.environ.get('BLOOM_ORIGIN_PATH')
         thread_num = os.environ.get('THREAD_NUM')
         if thread_num is not None:
             self.n_threads = int(thread_num)
         else:
             self.n_threads = 2

     def test_pipeline_llm(self):
         texts = 'def hello():\n print("hello world")\n'
         bigdl_llm = TransformersPipelineLLM.from_model_id(model_id=self.auto_causal_model_path, task='text-generation', model_kwargs={'trust_remote_code': True})
@@ -51,17 +56,37 @@ class Test_Langchain_Transformers_API(TestCase):
         res = "hello()" in output
         self.assertTrue(res)

+    def test_causalLM_embeddings(self):
+        bigdl_embeddings = BloomEmbeddings(model_path=self.bloom_model_path, model_kwargs={'trust_remote_code': True}, native=False)
+        text = "This is a test document."
+        query_result = bigdl_embeddings.embed_query(text)
+        doc_result = bigdl_embeddings.embed_documents([text])
+
+        bigdl_llm = BloomLLM(model_path=self.bloom_model_path, model_kwargs={'trust_remote_code': True}, native=False)
+        res = bigdl_llm(text)
+
+    """
+    def test_transformers_llama_embeddings(self):
+        bigdl_embeddings = TransformersEmbeddings.from_model_id(model_id=self.llama_model_path, model_kwargs={'trust_remote_code': True})
+        text = "This is a test document."
+        query_result = bigdl_embeddings.embed_query(text)
+        doc_result = bigdl_embeddings.embed_documents([text])
+
+        bigdl_llm = TransformersLLM.from_model_id(model_id=self.llama_model_path, model_kwargs={'trust_remote_code': True})
+        res = bigdl_llm(text)
+    """
+
     def test_qa_chain(self):
         texts = '''
 AI is a machine’s ability to perform the cognitive functions
 we associate with human minds, such as perceiving, reasoning,
 learning, interacting with an environment, problem solving,
 and even exercising creativity. You’ve probably interacted
 with AI even if you didn’t realize it—voice assistants like Siri
 and Alexa are founded on AI technology, as are some customer
 service chatbots that pop up to help you navigate websites.
 '''
         text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
         texts = text_splitter.split_text(texts)
         query = 'What is AI?'
@@ -76,6 +101,35 @@ service chatbots that pop up to help you navigate websites.
         output = doc_chain.run(input_documents=docs, question=query)
         res = "AI" in output
         self.assertTrue(res)

+    """
+    def test_qa_chain_causalLM(self):
+        texts = '''
+AI is a machine’s ability to perform the cognitive functions
+we associate with human minds, such as perceiving, reasoning,
+learning, interacting with an environment, problem solving,
+and even exercising creativity. You’ve probably interacted
+with AI even if you didn’t realize it—voice assistants like Siri
+and Alexa are founded on AI technology, as are some customer
+service chatbots that pop up to help you navigate websites.
+'''
+        text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
+        texts = text_splitter.split_text(texts)
+        query = 'What is AI?'
+        embeddings = LlamaEmbeddings(model_path=self.llama_model_path, model_kwargs={'trust_remote_code': True}, native=False)
+
+        docsearch = Chroma.from_texts(texts, embeddings, metadatas=[{"source": str(i)} for i in range(len(texts))]).as_retriever()
+
+        #get relavant texts
+        docs = docsearch.get_relevant_documents(query)
+        bigdl_llm = LlamaLLM(model_path=self.llama_model_path, model_kwargs={'trust_remote_code': True}, native=False)
+        doc_chain = load_qa_chain(bigdl_llm, chain_type="stuff", prompt=QA_PROMPT)
+        output = doc_chain.run(input_documents=docs, question=query)
+        res = "AI" in output
+        self.assertTrue(res)
+    """
+
 if __name__ == '__main__':
     pytest.main([__file__])