[LLM] Unify Langchain Native and Transformers LLM API (#8752)

* deprecate BigdlNativeTransformers and add specific LMEmbedding method
* deprecate and add LM methods for langchain llms
* add native params to native langchain
* new implementation for embedding
* move ut from bigdlnative to causal llm
* rename embeddings api and update examples to align with the new usage
* docqa example hot-fix
* add more api docs
* add langchain ut for starcoder
* support model_kwargs for transformer methods when calling causalLM and add ut
* ut fix for transformers embedding
* update for langchain causal supporting transformers
* remove model_family in readme doc
* add model_families params to support more models
* update api docs and remove chatglm embeddings for now
* remove chatglm embeddings in examples
* new refactor for ut to add bloom and transformers llama ut
* disable llama transformers embedding ut
This commit is contained in:
parent 5582872744
commit d2926c7672

11 changed files with 747 additions and 75 deletions
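The change at a glance: a minimal sketch (not taken verbatim from the diff below) of how the deprecated `BigdlNativeLLM`/`BigdlNativeEmbeddings` classes map onto the new per-model LangChain classes; the checkpoint path is a placeholder for a converted native int4 model.

```python
from bigdl.llm.langchain.llms import LlamaLLM               # previously BigdlNativeLLM(..., model_family="llama")
from bigdl.llm.langchain.embeddings import LlamaEmbeddings  # previously BigdlNativeEmbeddings(..., model_family="llama")

# placeholder path to a converted BigDL-LLM native int4 (ggml) checkpoint
model_path = '/path/to/converted/model.bin'

embeddings = LlamaEmbeddings(model_path=model_path)
llm = LlamaLLM(model_path=model_path, n_ctx=512, n_threads=2)
```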
				
			
		
							
								
								
									
.github/workflows/llm_unit_tests.yml (vendored): 12 changes
@@ -77,6 +77,8 @@ jobs:
         run: |
           echo "SPEECH_DATASET_PATH=${DATASET_DIR}/librispeech_asr_dummy" >> "$GITHUB_ENV"
+          echo "LLAMA_ORIGIN_PATH=${ORIGIN_DIR}/llama-7b-hf" >> "$GITHUB_ENV"
+          echo "BLOOM_ORIGIN_PATH=${ORIGIN_DIR}/bloom-7b1" >> "$GITHUB_ENV"
           echo "ORIGINAL_CHATGLM2_6B_PATH=${ORIGIN_DIR}/chatglm2-6b" >> "$GITHUB_ENV"
           echo "ORIGINAL_REPLIT_CODE_PATH=${ORIGIN_DIR}/replit-code-v1-3b" >> "$GITHUB_ENV"
           echo "ORIGINAL_WHISPER_TINY_PATH=${ORIGIN_DIR}/whisper-tiny" >> "$GITHUB_ENV"
@@ -143,6 +145,16 @@ jobs:
             echo "wget -r -nH --no-verbose --cut-dirs=1 $LLM_FTP_URL/llm/whisper-tiny -P $ORIGIN_DIR"
             wget -r -nH --no-verbose --cut-dirs=1 $LLM_FTP_URL/llm/whisper-tiny -P $ORIGIN_DIR
           fi
+          if [ ! -d $LLAMA_ORIGIN_PATH ]; then
+            echo "Directory $LLAMA_ORIGIN_PATH not found. Downloading from FTP server..."
+            echo "wget --no-verbose $LLM_FTP_URL/llm/llama-7b-hf -P $ORIGIN_DIR"
+            wget -r -nH --no-verbose --cut-dirs=1 $LLM_FTP_URL/llm/llama-7b-hf -P $ORIGIN_DIR
+          fi
+          if [ ! -d $BLOOM_ORIGIN_PATH ]; then
+            echo "Directory $BLOOM_ORIGIN_PATH not found. Downloading from FTP server..."
+            echo "wget --no-verbose $LLM_FTP_URL/llm/bloom-7b1 -P $ORIGIN_DIR"
+            wget -r -nH --no-verbose --cut-dirs=1 $LLM_FTP_URL/llm/bloom-7b1 -P $ORIGIN_DIR
+          fi
           if [ ! -d $SPEECH_DATASET_PATH ]; then
             echo "Directory $SPEECH_DATASET_PATH not found. Downloading from FTP server..."
             echo "wget -r -nH --no-verbose --cut-dirs=2 $LLM_FTP_URL/llm/datasets/librispeech_asr_dummy -P $DATASET_DIR"
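The new `LLAMA_ORIGIN_PATH`/`BLOOM_ORIGIN_PATH` variables mirror the existing ones and point the LangChain unit tests at the downloaded checkpoints. A hedged sketch of how a test could read them (the actual test files are not part of this excerpt):

```python
import os

# Assumed pattern only: the workflow above exports these variables via $GITHUB_ENV,
# so a unit test can locate the original llama/bloom checkpoints like this.
llama_origin_path = os.environ.get("LLAMA_ORIGIN_PATH")
bloom_origin_path = os.environ.get("BLOOM_ORIGIN_PATH")
```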
@@ -154,19 +154,23 @@ You may run the models using the LangChain API in `bigdl-llm`.
 
 - **Using native INT4 format**
+
-  You may also convert Hugging Face *Transformers* models into *native INT4* format (currently only *llama*/*bloom*/*gptneox*/*starcoder* model family is supported), and then run the converted models using the LangChain API as follows.
+  You may also convert Hugging Face *Transformers* models into *native INT4* format, and then run the converted models using the LangChain API as follows.
 
-  >**Note**: Currently only llama/bloom/gptneox/starcoder model family is supported; for other models, you may use the Transformers INT4 format as described above).
+  >**Notes**:
+
+   >* Currently only the llama/bloom/gptneox/starcoder/chatglm model families are supported; for other models, you may use the Hugging Face `transformers` INT4 format as described above.
+
+   >* You may choose the corresponding API developed for specific native models to load the converted model.
+
   ```python
-  from bigdl.llm.langchain.llms import BigdlNativeLLM
+  from bigdl.llm.langchain.llms import LlamaLLM
-  from bigdl.llm.langchain.embeddings import BigdlNativeEmbeddings
+  from bigdl.llm.langchain.embeddings import LlamaEmbeddings
   from langchain.chains.question_answering import load_qa_chain
+
-  embeddings = BigdlNativeEmbeddings(model_path='/path/to/converted/model.bin',
-                           model_family="llama",...)
+  #switch to ChatGLMEmbeddings/GptneoxEmbeddings/BloomEmbeddings/StarcoderEmbeddings to load other models
+  embeddings = LlamaEmbeddings(model_path='/path/to/converted/model.bin')
-  bigdl_llm = BigdlNativeLLM(model_path='/path/to/converted/model.bin',
-                       model_family="llama",...)
+  #switch to ChatGLMLLM/GptneoxLLM/BloomLLM/StarcoderLLM to load other models
+  bigdl_llm = LlamaLLM(model_path='/path/to/converted/model.bin')
+
   doc_chain = load_qa_chain(bigdl_llm, ...)
   doc_chain.run(...)
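The note above defers non-native model families to the Hugging Face `transformers` INT4 path. A hedged sketch of that route, using the `TransformersLLM`/`TransformersEmbeddings` classes this PR keeps exporting; the positional `from_model_id(model_path, model_kwargs)` call mirrors how the new wrappers below invoke it, and the path/repo id is a placeholder.

```python
from bigdl.llm.langchain.llms import TransformersLLM
from bigdl.llm.langchain.embeddings import TransformersEmbeddings

# placeholder: a Hugging Face repo id or a local checkpoint folder
model_id = '/path/to/hf/checkpoint-or-repo-id'

embeddings = TransformersEmbeddings.from_model_id(model_id, {})
llm = TransformersLLM.from_model_id(model_id, {})
```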
@@ -31,9 +31,8 @@ from langchain.chains.question_answering import load_qa_chain
 from langchain.callbacks.manager import CallbackManager
 from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
 
-from bigdl.llm.langchain.llms import BigdlNativeLLM
-from bigdl.llm.langchain.embeddings import BigdlNativeEmbeddings
+from bigdl.llm.langchain.llms import *
+from bigdl.llm.langchain.embeddings import *
-
 
 
 def main(args):
@@ -45,7 +44,6 @@ def main(args):
     n_ctx = args.n_ctx
     n_threads=args.thread_num
 
-
     callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
 
     # split texts of input doc
@@ -54,15 +52,35 @@ def main(args):
     text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
     texts = text_splitter.split_text(input_doc)
 
+    model_family_to_embeddings = {
+        "llama": LlamaEmbeddings,
+        "gptneox": GptneoxEmbeddings,
+        "bloom": BloomEmbeddings,
+        "starcoder": StarcoderEmbeddings
+    }
+
+    model_family_to_llm = {
+        "llama": LlamaLLM,
+        "gptneox": GptneoxLLM,
+        "bloom": BloomLLM,
+        "starcoder": StarcoderLLM
+    }
+
+    if model_family in model_family_to_embeddings and model_family in model_family_to_llm:
+        llm_embeddings = model_family_to_embeddings[model_family]
+        langchain_llm = model_family_to_llm[model_family]
+    else:
+        raise ValueError(f"Unknown model family: {model_family}")
+
     # create embeddings and store into vectordb
-    embeddings = BigdlNativeEmbeddings(model_path=model_path, model_family=model_family, n_threads=n_threads, n_ctx=n_ctx)
+    embeddings = llm_embeddings(model_path=model_path, n_threads=n_threads, n_ctx=n_ctx)
     docsearch = Chroma.from_texts(texts, embeddings, metadatas=[{"source": str(i)} for i in range(len(texts))]).as_retriever()
 
-    #get relavant texts
+    # get relevant texts
     docs = docsearch.get_relevant_documents(query)
 
-    bigdl_llm = BigdlNativeLLM(
-        model_path=model_path, model_family=model_family, n_ctx=n_ctx, n_threads=n_threads, callback_manager=callback_manager
+    bigdl_llm = langchain_llm(
+        model_path=model_path, n_ctx=n_ctx, n_threads=n_threads, callback_manager=callback_manager
     )
 
     doc_chain = load_qa_chain(
@@ -73,9 +91,9 @@ def main(args):
 
 
 if __name__ == '__main__':
-    parser = argparse.ArgumentParser(description='BigdlNativeLLM Langchain QA over Docs Example')
+    parser = argparse.ArgumentParser(description='BigDLCausalLM Langchain QA over Docs Example')
     parser.add_argument('-x','--model-family', type=str, required=True,
-                        choices=["llama", "bloom", "gptneox"],
+                        choices=["llama", "bloom", "gptneox", "chatglm", "starcoder"],
                         help='the model family')
     parser.add_argument('-m','--model-path', type=str, required=True,
                         help='the path to the converted llm model')
@@ -21,7 +21,7 @@
 
 import argparse
 
-from bigdl.llm.langchain.llms import BigdlNativeLLM
+from bigdl.llm.langchain.llms import *
 from langchain import PromptTemplate, LLMChain
 from langchain.callbacks.manager import CallbackManager
 from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
@@ -40,10 +40,22 @@ def main(args):
     # Callbacks support token-wise streaming
     callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
 
+    model_family_to_llm = {
+        "llama": LlamaLLM,
+        "gptneox": GptneoxLLM,
+        "bloom": BloomLLM,
+        "starcoder": StarcoderLLM,
+        "chatglm": ChatGLMLLM
+    }
+
+    if model_family in model_family_to_llm:
+        langchain_llm = model_family_to_llm[model_family]
+    else:
+        raise ValueError(f"Unknown model family: {model_family}")
+
     # Verbose is required to pass to the callback manager
-    llm = BigdlNativeLLM(
+    llm = langchain_llm(
         model_path=model_path,
-        model_family=model_family,
         n_threads=n_threads,
         callback_manager=callback_manager, 
         verbose=True
@@ -55,9 +67,9 @@ def main(args):
 
 
 if __name__ == '__main__':
-    parser = argparse.ArgumentParser(description='BigdlNativeLLM Langchain Streaming Chat Example')
+    parser = argparse.ArgumentParser(description='BigDLCausalLM Langchain Streaming Chat Example')
     parser.add_argument('-x','--model-family', type=str, required=True,
-                        choices=["llama", "bloom", "gptneox"],
+                        choices=["llama", "bloom", "gptneox", "chatglm", "starcoder"],
                         help='the model family')
     parser.add_argument('-m','--model-path', type=str, required=True,
                         help='the path to the converted llm model')
@@ -23,7 +23,7 @@
 
 
 from langchain import LLMChain, PromptTemplate
-from bigdl.llm.langchain.llms import BigdlNativeLLM
+from bigdl.llm.langchain.llms import *
 from langchain.memory import ConversationBufferWindowMemory
 from langchain.callbacks.manager import CallbackManager
 from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
@@ -35,7 +35,6 @@ import argparse
 def prepare_chain(args):
 
     model_path = args.model_path
-    model_family = args.model_family
     n_threads = args.thread_num
     n_ctx = args.context_size
 
@@ -48,11 +47,23 @@ def prepare_chain(args):
     A:"""
     prompt = PromptTemplate(input_variables=["history", "human_input"], template=template)
 
-    # We use our BigdlNativeLLM to substitute OpenAI web-required API
+    # We use our BigDLCausalLM to substitute OpenAI web-required API
+    model_family_to_llm = {
+        "llama": LlamaLLM,
+        "gptneox": GptneoxLLM,
+        "bloom": BloomLLM,
+        "starcoder": StarcoderLLM,
+        "chatglm": ChatGLMLLM
+    }
+
+    if model_family in model_family_to_llm:
+        langchain_llm = model_family_to_llm[model_family]
+    else:
+        raise ValueError(f"Unknown model family: {model_family}")
+
     callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
-    llm = BigdlNativeLLM(
+    llm = langchain_llm(
             model_path=model_path,
-            model_family=model_family,
             n_threads=n_threads,
             callback_manager=callback_manager,
             verbose=True,
@@ -114,8 +125,9 @@ def main(args):
 
 
 if __name__ == '__main__':
-    parser = argparse.ArgumentParser(description='BigdlNativeLLM Langchain Voice Assistant Example')
+    parser = argparse.ArgumentParser(description='BigDLCausalLM Langchain Voice Assistant Example')
     parser.add_argument('-x','--model-family', type=str, required=True,
+                        choices=["llama", "bloom", "gptneox", "chatglm", "starcoder"],
                         help='the model family')
     parser.add_argument('-m','--model-path', type=str, required=True,
                         help='the path to the converted llm model')
@@ -19,10 +19,14 @@
 # Otherwise there would be module not found error in non-pip's setting as Python would
 # only search the first bigdl package and end up finding only one sub-package.
 
-from .bigdlllm import BigdlNativeEmbeddings
+from .bigdlllm import *
 from .transformersembeddings import TransformersEmbeddings
 
 __all__ = [
     "BigdlNativeEmbeddings",
+    "LlamaEmbeddings",
+    "BloomEmbeddings",
+    "GptneoxEmbeddings",
+    "StarcoderEmbeddings",
     "TransformersEmbeddings"
 ]
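With the wildcard import and the widened `__all__`, the per-model embedding classes become importable straight from the package; a quick sanity check under that assumption:

```python
from bigdl.llm.langchain.embeddings import (
    LlamaEmbeddings,
    BloomEmbeddings,
    GptneoxEmbeddings,
    StarcoderEmbeddings,
)
```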
@@ -45,12 +45,14 @@
 # THE SOFTWARE.
 
 """Wrapper around BigdlNative embedding models."""
+import logging
 import importlib
 from typing import Any, Dict, List, Optional
 
 from pydantic import BaseModel, Extra, Field, root_validator
 
 from langchain.embeddings.base import Embeddings
+from .transformersembeddings import TransformersEmbeddings
 
 
 class BigdlNativeEmbeddings(BaseModel, Embeddings):
@@ -63,18 +65,25 @@ class BigdlNativeEmbeddings(BaseModel, Embeddings):
             llama = BigdlNativeEmbeddings(model_path="/path/to/model.bin")
     """
 
+    logging.warning("BigdlNativeEmbeddings has been deprecated, "
+                    "please switch to the new LMEmbeddings API for specific models.")
+
     model_family: str = "llama"
-    """the model family"""
+    """The model family: currently supports llama, gptneox, bloom and starcoder."""
 
     family_info = {
         'llama': {'module': "bigdl.llm.models", 'class': "Llama"},
         'bloom': {'module': "bigdl.llm.models", 'class': "Bloom"},
         'gptneox': {'module': "bigdl.llm.models", 'class': "Gptneox"},
+        'starcoder': {'module': "bigdl.llm.models", 'class': "Starcoder"},
     }  #: :meta private:
-    """info necessary for different model family initiation and configure"""
+    """Info necessary for different model family initiation and configuration."""
 
     client: Any  #: :meta private:
+    """The actual model."""
 
     model_path: str  # TODO: missing doc
+    """Path to the converted BigDL-LLM optimized ggml binary checkpoint."""
 
     n_ctx: int = Field(512, alias="n_ctx")
     """Token context window."""
@@ -159,7 +168,7 @@ class BigdlNativeEmbeddings(BaseModel, Embeddings):
             )
         except Exception as e:
             raise ValueError(
-                f"Could not load Llama model from path: {model_path}. "
+                f"Could not load model from path: {model_path}. "
                 f"Please make sure the model family {model_family} matches "
                 "the model you want to load."
                 f"Received error {e}"
@@ -190,3 +199,186 @@ class BigdlNativeEmbeddings(BaseModel, Embeddings):
         """
         embedding = self.client.embed(text)
         return list(map(float, embedding))
+
+
+class _BaseEmbeddings(BaseModel, Embeddings):
+    """Wrapper around bigdl-llm embedding models.
+
+    param model_path: If running with ``native int4``, the path should be the converted BigDL-LLM
+          optimized ggml binary checkpoint, which should be converted by ``bigdl.llm.llm_convert``.
+          If running with ``transformers int4``, the path should be the huggingface repo id
+          to be downloaded or the huggingface checkpoint folder.
+
+    Example:
+        .. code-block:: python
+
+            from bigdl.llm.langchain.embeddings import LlamaEmbeddings
+            llama = LlamaEmbeddings(model_path="/path/to/model.bin")
+    """
+
+    ggml_model: str = None
+    ggml_module: str = None
+
+    native: bool = True
+    """Load model to either BigDL-LLM optimized Transformers or Native (ggml) int4."""
+
+    client: Any  #: :meta private:
+    """The actual model."""
+
+    model_kwargs: Optional[dict] = None
+    """Keyword arguments to pass to the Transformers model."""
+
+    encode_kwargs: Optional[dict] = None
+    """Keyword arguments to pass when calling the `encode` method of the Transformers model."""
+
+    kwargs: Any
+    """Additional keyword arguments passed to TransformersEmbeddings."""
+
+    model_path: str
+    """Path to the model file to load.
+    If native, the path should be a converted BigDL-LLM optimized ggml binary checkpoint.
+    If transformers, the path should be the huggingface repo id to be downloaded
+    or the huggingface checkpoint folder."""
+
+    n_ctx: int = Field(512, alias="n_ctx")
+    """Token context window."""
+
+    n_parts: int = Field(-1, alias="n_parts")
+    """Number of parts to split the model into.
+    If -1, the number of parts is automatically determined."""
+
+    seed: int = Field(-1, alias="seed")
+    """Seed. If -1, a random seed is used."""
+
+    f16_kv: bool = Field(True, alias="f16_kv")
+    """Use half-precision for key/value cache."""
+
+    logits_all: bool = Field(False, alias="logits_all")
+    """Return logits for all tokens, not just the last token."""
+
+    vocab_only: bool = Field(False, alias="vocab_only")
+    """Only load the vocabulary, no weights."""
+
+    use_mlock: bool = Field(False, alias="use_mlock")
+    """Force system to keep model in RAM."""
+
+    n_threads: Optional[int] = Field(2, alias="n_threads")
+    """Number of threads to use."""
+
+    n_batch: Optional[int] = Field(512, alias="n_batch")
+    """Number of tokens to process in parallel.
+    Should be a number between 1 and n_ctx."""
+
+    n_gpu_layers: Optional[int] = Field(0, alias="n_gpu_layers")
+    """Number of layers to be loaded into gpu memory. Default None."""
+
+    class Config:
+        """Configuration for this pydantic object."""
+
+        extra = Extra.forbid
+
+    @root_validator()
+    def validate_environment(cls, values: Dict) -> Dict:
+        """Validate that the bigdl-llm library is installed."""
+
+        native = values["native"]
+        model_path = values["model_path"]
+        model_kwargs = values["model_kwargs"]
+        kwargs = values["kwargs"]
+        model_param_names = [
+            "n_ctx",
+            "n_parts",
+            "seed",
+            "f16_kv",
+            "logits_all",
+            "vocab_only",
+            "use_mlock",
+            "n_threads",
+            "n_batch",
+        ]
+        model_params = {k: values[k] for k in model_param_names}
+        # For backwards compatibility, only include if non-null.
+        if values["n_gpu_layers"] is not None:
+            model_params["n_gpu_layers"] = values["n_gpu_layers"]
+
+        try:
+            module = importlib.import_module(values["ggml_module"])
+            class_ = getattr(module, values["ggml_model"])
+
+            if native:
+                values["client"] = class_(model_path, embedding=True, **model_params)
+            else:
+                kwargs = {} if kwargs is None else kwargs
+                values["client"] = TransformersEmbeddings.from_model_id(model_path, model_kwargs,
+                                                                        **kwargs)
+
+            # from bigdl.llm.ggml.model.llama import Llama
+            # values["client"] = Llama(model_path, embedding=True, **model_params)
+
+        except ImportError:
+            raise ModuleNotFoundError(
+                "Could not import bigdl-llm library. "
+                "Please install the bigdl-llm library to "
+                "use this embedding model: pip install bigdl-llm"
+            )
+        except Exception as e:
+            raise ValueError(
+                f"Could not load model from path: {model_path}. "
+                f"Please make sure the model embedding class matches "
+                "the model you want to load."
+                f"Received error {e}"
+            )
+
+        return values
+
+    def embed_documents(self, texts: List[str]) -> List[List[float]]:
+        """Embed a list of documents using the optimized int4 model.
+
+        Args:
+            texts: The list of texts to embed.
+
+        Returns:
+            List of embeddings, one for each text.
+        """
+        if self.native:
+            embeddings = [self.client.embed(text) for text in texts]
+            return [list(map(float, e)) for e in embeddings]
+        else:
+            return self.client.embed_documents(texts)
+
+    def embed_query(self, text: str) -> List[float]:
+        """Embed a query using the optimized int4 model.
+
+        Args:
+            text: The text to embed.
+
+        Returns:
+            Embeddings for the text.
+        """
+        if self.native:
+            embedding = self.client.embed(text)
+            return list(map(float, embedding))
+        else:
+            return self.client.embed_query(text)
+
+
+class LlamaEmbeddings(_BaseEmbeddings):
+    ggml_model = "Llama"
+    ggml_module = "bigdl.llm.models"
+
+
+class BloomEmbeddings(_BaseEmbeddings):
+    ggml_model = "Bloom"
+    ggml_module = "bigdl.llm.models"
+
+
+class GptneoxEmbeddings(_BaseEmbeddings):
+    ggml_model = "Gptneox"
+    ggml_module = "bigdl.llm.models"
+
+
+class StarcoderEmbeddings(_BaseEmbeddings):
+    ggml_model = "Starcoder"
+    ggml_module = "bigdl.llm.models"
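The `native` flag is what ties the two back ends together in the new embeddings classes. A hedged sketch of both modes, assuming a converted ggml checkpoint for the native path and a Hugging Face checkpoint folder for the transformers path (both paths are placeholders):

```python
from bigdl.llm.langchain.embeddings import LlamaEmbeddings

# Native (ggml) int4: model_path is a checkpoint converted with bigdl.llm.llm_convert
native_emb = LlamaEmbeddings(model_path='/path/to/bigdl_llm_llama_q4_0.bin')

# BigDL-LLM optimized transformers int4: native=False routes construction through
# TransformersEmbeddings.from_model_id, so model_path is a repo id or checkpoint folder
transformers_emb = LlamaEmbeddings(model_path='/path/to/hf/checkpoint', native=False, model_kwargs={})

# Both modes expose the standard LangChain embedding methods
query_vector = native_emb.embed_query("What is BigDL-LLM?")
doc_vectors = native_emb.embed_documents(["first chunk", "second chunk"])
```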
@@ -23,18 +23,28 @@
 from typing import Dict, Type
 from langchain.llms.base import BaseLLM
 
-from .bigdlllm import BigdlNativeLLM
+from .bigdlllm import *
 from .transformersllm import TransformersLLM
 from .transformerspipelinellm import TransformersPipelineLLM
 
 __all__ = [
     "BigdlNativeLLM",
+    "LlamaLLM",
+    "BloomLLM",
+    "GptneoxLLM",
+    "ChatGLMLLM",
+    "StarcoderLLM",
     "TransformersLLM",
     "TransformersPipelineLLM"
 ]
 
 type_to_cls_dict: Dict[str, Type[BaseLLM]] = {
     "BigdlNativeLLM": BigdlNativeLLM,
+    "LlamaLLM": LlamaLLM,
+    "BloomLLM": BloomLLM,
+    "GptneoxLLM": GptneoxLLM,
+    "ChatGLMLLM": ChatGLMLLM,
+    "StarcoderLLM": StarcoderLLM,
     "TransformersPipelineLLM": TransformersPipelineLLM,
     "TransformersLLM": TransformersLLM
 }
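The exported `type_to_cls_dict` now covers every native LLM class, which gives callers a string-keyed registry; a small lookup sketch (the model path is a placeholder):

```python
from bigdl.llm.langchain.llms import type_to_cls_dict

# Resolve a class by its registered name, then construct it as usual
llm_cls = type_to_cls_dict["LlamaLLM"]
llm = llm_cls(model_path='/path/to/converted/model.bin')
```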
@@ -44,6 +44,7 @@
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 # THE SOFTWARE.
 
+import logging
 import importlib
 from typing import Any, Dict, Generator, List, Optional
 
@@ -51,7 +52,7 @@ from pydantic import Field, root_validator
 
 from langchain.callbacks.manager import CallbackManagerForLLMRun
 from langchain.llms.base import LLM
+from .transformersllm import TransformersLLM
-
 
 
 class BigdlNativeLLM(LLM):
@@ -65,22 +66,26 @@ class BigdlNativeLLM(LLM):
     """
 
 
-    model_family: str = "llama"
-    """the model family: currently supports llama, gptneox, and bloom."""
+    logging.warning("BigdlNativeLLM has been deprecated, "
+                    "please switch to the new LLM API for specific models.")
+
+    model_family: str = "llama"
+    """The model family: currently supports llama, gptneox, bloom, starcoder and chatglm."""
 
     family_info = {
         'llama': {'module': "bigdl.llm.models" , 'class': "Llama"},
         'bloom': {'module': "bigdl.llm.models", 'class': "Bloom"},
         'gptneox': {'module': "bigdl.llm.models", 'class': "Gptneox"},
+        'starcoder': {'module': "bigdl.llm.models", 'class': "Starcoder"},
+        'chatglm': {'module': "bigdl.llm.ggml.model.chatglm", 'class': "ChatGLM"},
     }  #: :meta private:
-    """info necessary for different model families initiation and configure"""
+    """Info necessary for different model families initiation and configuration."""
 
     client: Any  #: :meta private:
-    """the actual model"""
+    """The actual model."""
 
     model_path: str
-    """The path to the Llama model file."""
+    """Path to the converted BigDL-LLM optimized ggml binary checkpoint."""
 
     lora_base: Optional[str] = None
     """The path to the Llama LoRA base model."""
@@ -197,9 +202,9 @@ class BigdlNativeLLM(LLM):
 
         except ImportError:
             raise ModuleNotFoundError(
-                "Could not import llama-cpp-python library. "
-                "Please install the llama-cpp-python library to "
-                "use this embedding model: pip install llama-cpp-python"
+                "Could not import bigdl-llm library. "
+                "Please install the bigdl-llm library to "
+                "use this embedding model: pip install bigdl-llm"
             )
         except Exception as e:
             raise ValueError(
@@ -351,3 +356,333 @@ class BigdlNativeLLM(LLM):
     def get_num_tokens(self, text: str) -> int:
         tokenized_text = self.client.tokenize(text.encode("utf-8"))
         return len(tokenized_text)
+
+
+class _BaseCausalLM(LLM):
+    """Wrapper around the BigDL-LLM causal language models.
+
+    Example:
+        .. code-block:: python
+
+            from langchain.llms import LlamaLLM
+            llm = LlamaLLM(model_path="/path/to/llama/model")
+    """
+
+    ggml_model: str = None
+    ggml_module: str = None
+
+    native: bool = True
+    """Load model to either BigDL-LLM optimized Transformers or Native (ggml) int4."""
+
+    client: Any  #: :meta private:
+    """The actual model."""
+
+    model_path: str
+    """Path to the model file to load.
+    If native, the path should be a converted BigDL-LLM optimized ggml binary checkpoint.
+    If transformers, the path should be the huggingface repo id to be downloaded
+    or the huggingface checkpoint folder."""
+
+    model_kwargs: Optional[dict] = None
+    """Keyword arguments passed to the Transformers model."""
+
+    kwargs: Any
+    """Additional keyword arguments passed to TransformersLLM."""
+
+    lora_base: Optional[str] = None
+    """The path to the Llama LoRA base model."""
+
+    lora_path: Optional[str] = None
+    """The path to the Llama LoRA. If None, no LoRa is loaded."""
+
+    n_ctx: int = Field(512, alias="n_ctx")
+    """Token context window."""
+
+    n_parts: int = Field(-1, alias="n_parts")
+    """Number of parts to split the model into.
+    If -1, the number of parts is automatically determined."""
+
+    seed: int = Field(-1, alias="seed")
+    """Seed. If -1, a random seed is used."""
+
+    f16_kv: bool = Field(True, alias="f16_kv")
+    """Use half-precision for key/value cache."""
+
+    logits_all: bool = Field(False, alias="logits_all")
+    """Return logits for all tokens, not just the last token."""
+
+    vocab_only: bool = Field(False, alias="vocab_only")
+    """Only load the vocabulary, no weights."""
+
+    use_mlock: bool = Field(False, alias="use_mlock")
+    """Force system to keep model in RAM."""
+
+    n_threads: Optional[int] = Field(2, alias="n_threads")
+    """Number of threads to use."""
+
+    n_batch: Optional[int] = Field(512, alias="n_batch")
+    """Number of tokens to process in parallel.
+    Should be a number between 1 and n_ctx."""
+
+    n_gpu_layers: Optional[int] = Field(0, alias="n_gpu_layers")
+    """Number of layers to be loaded into gpu memory. Default None."""
+
+    suffix: Optional[str] = Field(None)
+    """A suffix to append to the generated text. If None, no suffix is appended."""
+
+    max_tokens: Optional[int] = 256
+    """The maximum number of tokens to generate."""
+
+    temperature: Optional[float] = 0.8
+    """The temperature to use for sampling."""
+
+    top_p: Optional[float] = 0.95
+    """The top-p value to use for sampling."""
+
+    logprobs: Optional[int] = Field(None)
+    """The number of logprobs to return. If None, no logprobs are returned."""
+
+    echo: Optional[bool] = False
+    """Whether to echo the prompt."""
+
+    stop: Optional[List[str]] = []
+    """A list of strings to stop generation when encountered."""
+
+    repeat_penalty: Optional[float] = 1.1
+    """The penalty to apply to repeated tokens."""
+
+    top_k: Optional[int] = 40
+    """The top-k value to use for sampling."""
+
+    last_n_tokens_size: Optional[int] = 64
+    """The number of tokens to look back when applying the repeat_penalty."""
+
+    use_mmap: Optional[bool] = True
+    """Whether to keep the model loaded in RAM."""
+
+    streaming: bool = True
+    """Whether to stream the results, token by token."""
+
+    @root_validator()
+    def validate_environment(cls, values: Dict) -> Dict:
+        """Validate that bigdl-llm is installed and the model family is supported."""
+
+        native = values["native"]
+        model_path = values["model_path"]
+        model_kwargs = values["model_kwargs"]
+        kwargs = values["kwargs"]
+        model_param_names = [
+            "lora_path",
+            "lora_base",
+            "n_ctx",
+            "n_parts",
+            "seed",
+            "f16_kv",
+            "logits_all",
+            "vocab_only",
+            "use_mlock",
+            "n_threads",
+            "n_batch",
+            "use_mmap",
+            "last_n_tokens_size",
+        ]
+        model_params = {k: values[k] for k in model_param_names}
+        # For backwards compatibility, only include if non-null.
+        if values["n_gpu_layers"] is not None:
+            model_params["n_gpu_layers"] = values["n_gpu_layers"]
+
+        try:
+            module = importlib.import_module(values["ggml_module"])
+            class_ = getattr(module, values["ggml_model"])
+            if native:
+                values["client"] = class_(model_path, **model_params)
+            else:
+                kwargs = {} if kwargs is None else kwargs
+                values["client"] = TransformersLLM.from_model_id(model_path, model_kwargs,
+                                                                 **kwargs)
+
+        except ImportError:
+            raise ModuleNotFoundError(
+                "Could not import bigdl-llm library. "
+                "Please install the bigdl-llm library to "
+                "use this embedding model: pip install bigdl-llm"
+            )
+        except Exception as e:
+            raise ValueError(
+                f"Could not load model from path: {model_path}. "
+                f"Please make sure the model embedding class matches "
+                "the model you want to load."
+                f"Received error {e}"
+            )
+
+        return values
+
+    @property
+    def _default_params(self) -> Dict[str, Any]:
+        """Get the default parameters for calling llama_cpp."""
+        return {
+            "suffix": self.suffix,
+            "max_tokens": self.max_tokens,
+            "temperature": self.temperature,
+            "top_p": self.top_p,
+            "logprobs": self.logprobs,
+            "echo": self.echo,
+            "stop_sequences": self.stop,  # key here is convention among LLM classes
+            "repeat_penalty": self.repeat_penalty,
+            "top_k": self.top_k,
+        }
+
+    @property
+    def _identifying_params(self) -> Dict[str, Any]:
+        """Get the identifying parameters."""
+        return {**{"model_path": self.model_path},
+                **self._default_params}
+
+    @property
+    def _llm_type(self) -> str:
+        """Return type of llm."""
+        return "BigDL"
+
+    def _get_parameters(self, stop: Optional[List[str]] = None) -> Dict[str, Any]:
+        """
+        Performs sanity check, preparing parameters in format needed by llama_cpp.
+
+        Args:
+            stop (Optional[List[str]]): List of stop sequences for llama_cpp.
+
+        Returns:
+            Dictionary containing the combined parameters.
+        """
+
+        # Raise error if stop sequences are in both input and default params
+        if self.stop and stop is not None:
+            raise ValueError("`stop` found in both the input and default params.")
+
+        params = self._default_params
+
+        # llama_cpp expects the "stop" key not this, so we remove it:
+        params.pop("stop_sequences")
+
+        # then set it as configured, or default to an empty list:
+        params["stop"] = self.stop or stop or []
+
+        return params
+
+    def _call(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs
+    ) -> str:
+        """Call the Llama model and return the output.
+
+        Args:
+            prompt: The prompt to use for generation.
+            stop: A list of strings to stop generation when encountered.
+
+        Returns:
+            The generated text.
+
+        Example:
+            .. code-block:: python
+
+                from langchain.llms import LlamaLLM
 | 
				
			||||||
 | 
					                llm = LlamaLLM(model_path="/path/to/local/llama/model.bin")
 | 
				
			||||||
 | 
					                llm("This is a prompt.")
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        if self.native:
 | 
				
			||||||
 | 
					            if self.streaming:
 | 
				
			||||||
 | 
					                # If streaming is enabled, we use the stream
 | 
				
			||||||
 | 
					                # method that yields as they are generated
 | 
				
			||||||
 | 
					                # and return the combined strings from the first choices's text:
 | 
				
			||||||
 | 
					                combined_text_output = ""
 | 
				
			||||||
 | 
					                for token in self.stream(prompt=prompt, stop=stop, run_manager=run_manager):
 | 
				
			||||||
 | 
					                    combined_text_output += token["choices"][0]["text"]
 | 
				
			||||||
 | 
					                return combined_text_output
 | 
				
			||||||
 | 
					            else:
 | 
				
			||||||
 | 
					                params = self._get_parameters(stop)
 | 
				
			||||||
 | 
					                result = self.client(prompt=prompt, **params)
 | 
				
			||||||
 | 
					                return result["choices"][0]["text"]
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            return self.client._call(prompt, stop, run_manager, **kwargs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def stream(
 | 
				
			||||||
 | 
					        self,
 | 
				
			||||||
 | 
					        prompt: str,
 | 
				
			||||||
 | 
					        stop: Optional[List[str]] = None,
 | 
				
			||||||
 | 
					        run_manager: Optional[CallbackManagerForLLMRun] = None,
 | 
				
			||||||
 | 
					    ) -> Generator[Dict, None, None]:
 | 
				
			||||||
 | 
					        """Yields results objects as they are generated in real time.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        BETA: this is a beta feature while we figure out the right abstraction.
 | 
				
			||||||
 | 
					        Once that happens, this interface could change.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        It also calls the callback manager's on_llm_new_token event with
 | 
				
			||||||
 | 
					        similar parameters to the OpenAI LLM class method of the same name.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Args:
 | 
				
			||||||
 | 
					            prompt: The prompts to pass into the model.
 | 
				
			||||||
 | 
					            stop: Optional list of stop words to use when generating.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Returns:
 | 
				
			||||||
 | 
					            A generator representing the stream of tokens being generated.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Yields:
 | 
				
			||||||
 | 
					            A dictionary like objects containing a string token and metadata.
 | 
				
			||||||
 | 
					            See llama-cpp-python docs and below for more.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Example:
 | 
				
			||||||
 | 
					            .. code-block:: python
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                from langchain.llms import LlamaLLM
 | 
				
			||||||
 | 
					                llm = LlamaLLM(
 | 
				
			||||||
 | 
					                    model_path="/path/to/local/model.bin",
 | 
				
			||||||
 | 
					                    temperature = 0.5
 | 
				
			||||||
 | 
					                )
 | 
				
			||||||
 | 
					                for chunk in llm.stream("Ask 'Hi, how are you?' like a pirate:'",
 | 
				
			||||||
 | 
					                        stop=["'","\\n"]):
 | 
				
			||||||
 | 
					                    result = chunk["choices"][0]
 | 
				
			||||||
 | 
					                    print(result["text"], end='', flush=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        params = self._get_parameters(stop)
 | 
				
			||||||
 | 
					        result = self.client(prompt=prompt, stream=True, **params)
 | 
				
			||||||
 | 
					        for chunk in result:
 | 
				
			||||||
 | 
					            token = chunk["choices"][0]["text"]
 | 
				
			||||||
 | 
					            log_probs = chunk["choices"][0].get("logprobs", None)
 | 
				
			||||||
 | 
					            if run_manager:
 | 
				
			||||||
 | 
					                run_manager.on_llm_new_token(
 | 
				
			||||||
 | 
					                    token=token, verbose=self.verbose, log_probs=log_probs
 | 
				
			||||||
 | 
					                )
 | 
				
			||||||
 | 
					            yield chunk
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def get_num_tokens(self, text: str) -> int:
 | 
				
			||||||
 | 
					        tokenized_text = self.client.tokenize(text.encode("utf-8"))
 | 
				
			||||||
 | 
					        return len(tokenized_text)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class LlamaLLM(_BaseCausalLM):
 | 
				
			||||||
 | 
					    ggml_model = "Llama"
 | 
				
			||||||
 | 
					    ggml_module = "bigdl.llm.models"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class BloomLLM(_BaseCausalLM):
 | 
				
			||||||
 | 
					    ggml_model = "Bloom"
 | 
				
			||||||
 | 
					    ggml_module = "bigdl.llm.models"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class GptneoxLLM(_BaseCausalLM):
 | 
				
			||||||
 | 
					    ggml_model = "Gptneox"
 | 
				
			||||||
 | 
					    ggml_module = "bigdl.llm.models"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class ChatGLMLLM(_BaseCausalLM):
 | 
				
			||||||
 | 
					    ggml_model = "ChatGLM"
 | 
				
			||||||
 | 
					    ggml_module = "bigdl.llm.ggml.model.chatglm"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class StarcoderLLM(_BaseCausalLM):
 | 
				
			||||||
 | 
					    ggml_model = "Starcoder"
 | 
				
			||||||
 | 
					    ggml_module = "bigdl.llm.models"
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
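For orientation, here is a minimal usage sketch of the model-specific classes defined above, covering both the native (ggml) path and the Transformers path; the checkpoint paths are placeholders and the keyword values simply mirror the unit tests shown in the hunks that follow.

# Sketch only: paths are placeholders, argument values mirror the unit tests below.
from bigdl.llm.langchain.llms import LlamaLLM, BloomLLM

# Native (ggml) mode: wraps the converted INT4 checkpoint directly.
llm = LlamaLLM(model_path="/path/to/ggml-llama-q4_0.bin",
               max_tokens=32,
               n_threads=2)
print(llm("What is AI?"))

# Transformers mode: native=False delegates to TransformersLLM.from_model_id.
llm = BloomLLM(model_path="/path/to/bloom-7b1",
               model_kwargs={'trust_remote_code': True},
               native=False)
print(llm("What is AI?"))

The next hunks update the native-model LangChain unit tests to use these classes.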
@@ -14,8 +14,8 @@
 # limitations under the License.
 #
 
-from bigdl.llm.langchain.embeddings import BigdlNativeEmbeddings
-from bigdl.llm.langchain.llms import BigdlNativeLLM
+from bigdl.llm.langchain.embeddings import *
+from bigdl.llm.langchain.llms import *
 import pytest
 from unittest import TestCase
 import os
@@ -26,6 +26,7 @@ class Test_Models_Basics(TestCase):
         self.llama_model_path = os.environ.get('LLAMA_INT4_CKPT_PATH')
         self.bloom_model_path = os.environ.get('BLOOM_INT4_CKPT_PATH')
         self.gptneox_model_path = os.environ.get('GPTNEOX_INT4_CKPT_PATH')
+        self.starcoder_model_path = os.environ.get('STARCODER_INT4_CKPT_PATH')
         thread_num = os.environ.get('THREAD_NUM')
         if thread_num is not None:
             self.n_threads = int(thread_num)
@@ -34,23 +35,35 @@ class Test_Models_Basics(TestCase):
 
         
     def test_langchain_llm_embedding_llama(self):
-        bigdl_embeddings = BigdlNativeEmbeddings(
-            model_path=self.llama_model_path,
-            model_family="llama")
+        bigdl_embeddings = LlamaEmbeddings(
+            model_path=self.llama_model_path)
         text = "This is a test document."
         query_result = bigdl_embeddings.embed_query(text)
         doc_result = bigdl_embeddings.embed_documents([text])
     
     def test_langchain_llm_embedding_gptneox(self):
-        bigdl_embeddings = BigdlNativeEmbeddings(
-            model_path=self.gptneox_model_path,
-            model_family="gptneox")
+        bigdl_embeddings = GptneoxEmbeddings(
+            model_path=self.gptneox_model_path)
+        text = "This is a test document."
+        query_result = bigdl_embeddings.embed_query(text)
+        doc_result = bigdl_embeddings.embed_documents([text])
+
+    def test_langchain_llm_embedding_bloom(self):
+        bigdl_embeddings = BloomEmbeddings(
+            model_path=self.bloom_model_path)
+        text = "This is a test document."
+        query_result = bigdl_embeddings.embed_query(text)
+        doc_result = bigdl_embeddings.embed_documents([text])
+
+    def test_langchain_llm_embedding_starcoder(self):
+        bigdl_embeddings = StarcoderEmbeddings(
+            model_path=self.starcoder_model_path)
         text = "This is a test document."
         query_result = bigdl_embeddings.embed_query(text)
         doc_result = bigdl_embeddings.embed_documents([text])
         
     def test_langchain_llm_llama(self):
-        llm = BigdlNativeLLM(
+        llm = LlamaLLM(
            model_path=self.llama_model_path,
             max_tokens=32,
             n_threads=self.n_threads)
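The hunk above swaps the deprecated BigdlNativeEmbeddings for per-model embedding classes; a condensed sketch of the new call pattern follows (the checkpoint path is a placeholder).

from bigdl.llm.langchain.embeddings import LlamaEmbeddings

# The model family is now implied by the class, so no model_family argument is needed.
emb = LlamaEmbeddings(model_path="/path/to/ggml-llama-q4_0.bin")
query_vec = emb.embed_query("This is a test document.")
doc_vecs = emb.embed_documents(["This is a test document."])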
@@ -58,18 +71,24 @@ class Test_Models_Basics(TestCase):
         result = llm(question)
         
     def test_langchain_llm_gptneox(self):
-        llm = BigdlNativeLLM(
+        llm = GptneoxLLM(
             model_path=self.gptneox_model_path,
-            model_family="gptneox", 
             max_tokens=32,
             n_threads=self.n_threads)
         question = "What is AI?"
         result = llm(question)
         
     def test_langchain_llm_bloom(self):
-        llm = BigdlNativeLLM(
+        llm = BloomLLM(
             model_path=self.bloom_model_path,
-            model_family="bloom",
+            max_tokens=32,
+            n_threads=self.n_threads)
+        question = "What is AI?"
+        result = llm(question)
+
+    def test_langchain_llm_starcoder(self):
+        llm = StarcoderLLM(
+            model_path=self.starcoder_model_path,
             max_tokens=32,
             n_threads=self.n_threads)
         question = "What is AI?"
@@ -14,8 +14,10 @@
 # limitations under the License.
 #
 
-from bigdl.llm.langchain.llms import TransformersLLM, TransformersPipelineLLM
-from bigdl.llm.langchain.embeddings import TransformersEmbeddings
+from bigdl.llm.langchain.llms import TransformersLLM, TransformersPipelineLLM, \
+    LlamaLLM, BloomLLM
+from bigdl.llm.langchain.embeddings import TransformersEmbeddings, LlamaEmbeddings, \
+    BloomEmbeddings
 
 
 from langchain.document_loaders import WebBaseLoader
@@ -37,12 +39,15 @@ class Test_Langchain_Transformers_API(TestCase):
     def setUp(self):
         self.auto_model_path = os.environ.get('ORIGINAL_CHATGLM2_6B_PATH')
         self.auto_causal_model_path = os.environ.get('ORIGINAL_REPLIT_CODE_PATH')
+        self.llama_model_path = os.environ.get('LLAMA_ORIGIN_PATH')
+        self.bloom_model_path = os.environ.get('BLOOM_ORIGIN_PATH')
         thread_num = os.environ.get('THREAD_NUM')
         if thread_num is not None:
             self.n_threads = int(thread_num)
         else:
             self.n_threads = 2         
+
 
     def test_pipeline_llm(self):
         texts = 'def hello():\n  print("hello world")\n'
         bigdl_llm = TransformersPipelineLLM.from_model_id(model_id=self.auto_causal_model_path, task='text-generation', model_kwargs={'trust_remote_code': True})
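For reference, the pipeline-style entry point exercised by test_pipeline_llm above can be sketched as follows; the model_id path is a placeholder for a local Hugging Face checkpoint.

from bigdl.llm.langchain.llms import TransformersPipelineLLM

# Load a local checkpoint through a text-generation pipeline and run a prompt.
llm = TransformersPipelineLLM.from_model_id(
    model_id="/path/to/replit-code-v1-3b",
    task='text-generation',
    model_kwargs={'trust_remote_code': True})
print(llm('def hello():\n  print("hello world")\n'))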
@@ -52,16 +57,36 @@ class Test_Langchain_Transformers_API(TestCase):
         self.assertTrue(res)
 
 
+    def test_causalLM_embeddings(self):
+        bigdl_embeddings = BloomEmbeddings(model_path=self.bloom_model_path, model_kwargs={'trust_remote_code': True}, native=False)
+        text = "This is a test document."
+        query_result = bigdl_embeddings.embed_query(text)
+        doc_result = bigdl_embeddings.embed_documents([text])
+
+        bigdl_llm = BloomLLM(model_path=self.bloom_model_path, model_kwargs={'trust_remote_code': True}, native=False)
+        res = bigdl_llm(text)
+
+    """
+    def test_transformers_llama_embeddings(self):
+        bigdl_embeddings = TransformersEmbeddings.from_model_id(model_id=self.llama_model_path, model_kwargs={'trust_remote_code': True})
+        text = "This is a test document."
+        query_result = bigdl_embeddings.embed_query(text)
+        doc_result = bigdl_embeddings.embed_documents([text])
+
+        bigdl_llm = TransformersLLM.from_model_id(model_id=self.llama_model_path, model_kwargs={'trust_remote_code': True})
+        res = bigdl_llm(text)
+    """
+
     def test_qa_chain(self):
         texts = '''
-AI is a machine’s ability to perform the cognitive functions 
-we associate with human minds, such as perceiving, reasoning, 
-learning, interacting with an environment, problem solving,
-and even exercising creativity. You’ve probably interacted 
-with AI even if you didn’t realize it—voice assistants like Siri 
-and Alexa are founded on AI technology, as are some customer 
-service chatbots that pop up to help you navigate websites.
-        '''
+            AI is a machine’s ability to perform the cognitive functions 
+            we associate with human minds, such as perceiving, reasoning, 
+            learning, interacting with an environment, problem solving,
+            and even exercising creativity. You’ve probably interacted 
+            with AI even if you didn’t realize it—voice assistants like Siri 
+            and Alexa are founded on AI technology, as are some customer 
+            service chatbots that pop up to help you navigate websites.
+            '''
         text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
         texts = text_splitter.split_text(texts)
         query = 'What is AI?'
@@ -77,5 +102,34 @@ service chatbots that pop up to help you navigate websites.
         res = "AI" in output
         self.assertTrue(res)
 
+    
+    """
+    def test_qa_chain_causalLM(self):
+        texts = '''
+            AI is a machine’s ability to perform the cognitive functions 
+            we associate with human minds, such as perceiving, reasoning, 
+            learning, interacting with an environment, problem solving,
+            and even exercising creativity. You’ve probably interacted 
+            with AI even if you didn’t realize it—voice assistants like Siri 
+            and Alexa are founded on AI technology, as are some customer 
+            service chatbots that pop up to help you navigate websites.
+            '''
+        text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
+        texts = text_splitter.split_text(texts)
+        query = 'What is AI?'
+        embeddings = LlamaEmbeddings(model_path=self.llama_model_path, model_kwargs={'trust_remote_code': True}, native=False)
+
+        docsearch = Chroma.from_texts(texts, embeddings, metadatas=[{"source": str(i)} for i in range(len(texts))]).as_retriever()
+
+        #get relavant texts
+        docs = docsearch.get_relevant_documents(query)
+        bigdl_llm = LlamaLLM(model_path=self.llama_model_path, model_kwargs={'trust_remote_code': True}, native=False)
+        doc_chain = load_qa_chain(bigdl_llm, chain_type="stuff", prompt=QA_PROMPT)
+        output = doc_chain.run(input_documents=docs, question=query)
+        res = "AI" in output
+        self.assertTrue(res)
+    """
+
+
 if __name__ == '__main__':
     pytest.main([__file__])
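The disabled test above outlines the retrieval-QA flow behind these changes; a minimal sketch of the same chain using the Transformers-backed classes is given below. The model_id paths are placeholders, the sample text is shortened, and the default load_qa_chain prompt is used in place of the test's QA_PROMPT.

from langchain.chains.question_answering import load_qa_chain
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import Chroma
from bigdl.llm.langchain.llms import TransformersLLM
from bigdl.llm.langchain.embeddings import TransformersEmbeddings

# Split the source text, embed it, and retrieve chunks relevant to the query.
text = ("AI is a machine's ability to perform the cognitive functions "
        "we associate with human minds, such as perceiving and reasoning.")
chunks = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0).split_text(text)
embeddings = TransformersEmbeddings.from_model_id(
    model_id="/path/to/llama-7b-hf", model_kwargs={'trust_remote_code': True})
retriever = Chroma.from_texts(chunks, embeddings).as_retriever()
docs = retriever.get_relevant_documents("What is AI?")

# Run a "stuff" QA chain over the retrieved documents.
llm = TransformersLLM.from_model_id(
    model_id="/path/to/llama-7b-hf", model_kwargs={'trust_remote_code': True})
chain = load_qa_chain(llm, chain_type="stuff")
print(chain.run(input_documents=docs, question="What is AI?"))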