add langchain gpu example (#10277)
* first draft
* fix
* add readme for transformer_int4_gpu
* fix doc
* check device_map
* add arc ut test
* fix ut test
* fix langchain ut
* Refine README
* fix gpu mem too high
* fix ut test

Co-authored-by: Ariadne <wyn2000330@126.com>
parent 5dbbe1a826
commit fc7f10cd12

7 changed files with 255 additions and 0 deletions
14  .github/workflows/llm_unit_tests.yml (vendored)
@@ -346,3 +346,17 @@ jobs:
            source /home/arda/intel/oneapi/setvars.sh
          fi
          bash python/llm/test/run-llm-example-tests-gpu.sh

      - name: Run LLM langchain test
        shell: bash
        run: |
          pip install -U langchain==0.0.184
          pip install -U chromadb==0.3.25
          pip install -U pandas==2.0.3
          # Specific oneapi position on arc ut test machines
          if [[ '${{ matrix.pytorch-version }}' == '2.1' ]]; then
            source /opt/intel/oneapi/setvars.sh
          elif [[ '${{ matrix.pytorch-version }}' == '2.0' ]]; then
            source /home/arda/intel/oneapi/setvars.sh
          fi
          bash python/llm/test/run-llm-langchain-tests-gpu.sh
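For reference, the new CI step can also be reproduced outside the workflow. A minimal sketch, assuming oneAPI is installed under the default `/opt/intel` location and the repository root is the working directory:

```bash
# Dependencies pinned by the workflow for the LangChain tests
pip install -U langchain==0.0.184 chromadb==0.3.25 pandas==2.0.3

# Initialize oneAPI (adjust the path to your installation)
source /opt/intel/oneapi/setvars.sh

# Run the new LangChain GPU test suite
bash python/llm/test/run-llm-langchain-tests-gpu.sh
```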
@@ -0,0 +1,93 @@
# Langchain examples

The examples in this folder show how to use [LangChain](https://www.langchain.com/) with `bigdl-llm` on Intel GPU.

### 1. Install bigdl-llm
Follow the instructions in the [GPU Install Guide](https://bigdl.readthedocs.io/en/latest/doc/LLM/Overview/install_gpu.html) to install `bigdl-llm`.

### 2. Install Required Dependencies for LangChain Examples
```bash
pip install langchain==0.0.184
pip install -U chromadb==0.3.25
pip install -U pandas==2.0.3
```

### 3. Configure OneAPI Environment Variables
#### 3.1 Configurations for Linux
```bash
source /opt/intel/oneapi/setvars.sh
```
#### 3.2 Configurations for Windows
```cmd
call "C:\Program Files (x86)\Intel\oneAPI\setvars.bat"
```
> Note: Please make sure you are using **CMD** (**Anaconda Prompt** if using conda) to run the command, as PowerShell is not supported.
### 4. Runtime Configurations
For optimal performance, it is recommended to set several environment variables. Please check out the suggestions based on your device.
#### 4.1 Configurations for Linux
<details>

<summary>For Intel Arc™ A-Series Graphics and Intel Data Center GPU Flex Series</summary>

```bash
export USE_XETLA=OFF
export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1
```

</details>

<details>

<summary>For Intel Data Center GPU Max Series</summary>

```bash
export LD_PRELOAD=${LD_PRELOAD}:${CONDA_PREFIX}/lib/libtcmalloc.so
export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1
export ENABLE_SDP_FUSION=1
```
> Note: `libtcmalloc.so` can be installed with `conda install -c conda-forge -y gperftools=2.10`.

</details>

#### 4.2 Configurations for Windows
<details>

<summary>For Intel iGPU</summary>

```cmd
set SYCL_CACHE_PERSISTENT=1
set BIGDL_LLM_XMX_DISABLED=1
```

</details>

<details>

<summary>For Intel Arc™ A300-Series or Pro A60</summary>

```cmd
set SYCL_CACHE_PERSISTENT=1
```

</details>

<details>

<summary>For other Intel dGPU Series</summary>

There is no need to set further environment variables.

</details>

> Note: The first time each model runs on Intel iGPU/Intel Arc™ A300-Series or Pro A60, it may take several minutes to compile.
### 5. Run the examples

#### 5.1 Streaming Chat

```bash
python chat.py -m MODEL_PATH -q QUESTION
```
Arguments info:
- `-m MODEL_PATH`: **required**, path to the model
- `-q QUESTION`: question to ask. Default is `What is AI?`.
@@ -0,0 +1,65 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# This makes sure Python is aware there is more than one sub-package within bigdl,
# physically located elsewhere.
# Otherwise there would be a module-not-found error in a non-pip setting, as Python would
# only search the first bigdl package and end up finding only one sub-package.
import argparse

from bigdl.llm.langchain.llms import TransformersLLM, TransformersPipelineLLM
from langchain import PromptTemplate, LLMChain
from langchain import HuggingFacePipeline


def main(args):
    question = args.question
    model_path = args.model_path
    template = """{question}"""

    prompt = PromptTemplate(template=template, input_variables=["question"])

    # Alternatively, a pipeline-based LLM could be used:
    # llm = TransformersPipelineLLM.from_model_id(
    #     model_id=model_path,
    #     task="text-generation",
    #     model_kwargs={"temperature": 0, "max_length": 64, "trust_remote_code": True},
    #     device_map='xpu'
    # )

    # Load the model with BigDL-LLM low-bit optimizations and place it on the Intel GPU ('xpu')
    llm = TransformersLLM.from_model_id(
        model_id=model_path,
        model_kwargs={"temperature": 0, "max_length": 64, "trust_remote_code": True},
        device_map='xpu'
    )

    llm_chain = LLMChain(prompt=prompt, llm=llm)

    output = llm_chain.run(question)
    print("====output=====")
    print(output)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='TransformersLLM Langchain Chat Example')
    parser.add_argument('-m', '--model-path', type=str, required=True,
                        help='the path to the transformers model')
    parser.add_argument('-q', '--question', type=str, default='What is AI?',
                        help='question you want to ask.')
    args = parser.parse_args()

    main(args)
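For reference, a sample invocation of this script, assuming a local Llama 2 7B chat checkpoint (the path below is hypothetical):

```bash
python chat.py -m /models/Llama-2-7b-chat-hf -q "What is AI?"
```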
@@ -7,6 +7,7 @@ This folder contains examples of running BigDL-LLM on Intel GPU:
- [LLM-Finetuning](LLM-Finetuning): running ***finetuning*** (such as LoRA, QLoRA, QA-LoRA, etc.) using BigDL-LLM on Intel GPUs
- [vLLM-Serving](vLLM-Serving): running ***vLLM*** serving framework on Intel GPUs (with BigDL-LLM low-bit optimized models)
- [Deepspeed-AutoTP](Deepspeed-AutoTP): running distributed inference using ***DeepSpeed AutoTP*** (with BigDL-LLM low-bit optimized models) on Intel GPUs
- [LangChain](LangChain): running ***LangChain*** applications on BigDL-LLM
- [PyTorch-Models](PyTorch-Models): running any PyTorch model on BigDL-LLM (with "one-line code change")
- [Speculative-Decoding](Speculative-Decoding): running any ***Hugging Face Transformers*** model with ***self-speculative decoding*** on Intel GPUs
- [ModelScope-Models](ModelScope-Models): running ***ModelScope*** models with BigDL-LLM on Intel GPUs
@@ -89,6 +89,7 @@ class TransformersLLM(LLM):
        cls,
        model_id: str,
        model_kwargs: Optional[dict] = None,
        device_map: str = 'cpu',
        **kwargs: Any,
    ) -> LLM:
        """
@@ -131,6 +132,10 @@ class TransformersLLM(LLM):
        except:
            model = AutoModel.from_pretrained(model_id, load_in_4bit=True, **_model_kwargs)

        # TODO: may refactor this code in the future
        if 'xpu' in device_map:
            model = model.to(device_map)

        if "trust_remote_code" in _model_kwargs:
            _model_kwargs = {
                k: v for k, v in _model_kwargs.items() if k != "trust_remote_code"
@@ -149,6 +154,7 @@ class TransformersLLM(LLM):
        cls,
        model_id: str,
        model_kwargs: Optional[dict] = None,
        device_map: str = 'cpu',
        **kwargs: Any,
    ) -> LLM:
        """
@@ -188,6 +194,10 @@ class TransformersLLM(LLM):
            model = AutoModelForCausalLM.load_low_bit(model_id, **_model_kwargs)
        except:
            model = AutoModel.load_low_bit(model_id, **_model_kwargs)

        # TODO: may refactor this code in the future
        if 'xpu' in device_map:
            model = model.to(device_map)

        if "trust_remote_code" in _model_kwargs:
            _model_kwargs = {
@@ -224,6 +234,7 @@ class TransformersLLM(LLM):
        if self.streaming:
            from transformers import TextStreamer
            input_ids = self.tokenizer.encode(prompt, return_tensors="pt")
            input_ids = input_ids.to(self.model.device)
            streamer = TextStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True)
            if stop is not None:
                from transformers.generation.stopping_criteria import StoppingCriteriaList

@@ -240,6 +251,7 @@ class TransformersLLM(LLM):
            return text
        else:
            input_ids = self.tokenizer.encode(prompt, return_tensors="pt")
            input_ids = input_ids.to(self.model.device)
            if stop is not None:
                from transformers.generation.stopping_criteria import StoppingCriteriaList
                from transformers.tools.agents import StopSequenceCriteria
48  python/llm/test/langchain_gpu/test_transformers_api.py (new file)
@@ -0,0 +1,48 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from bigdl.llm.langchain.llms import TransformersLLM, TransformersPipelineLLM, \
    LlamaLLM, BloomLLM
from bigdl.llm.langchain.embeddings import TransformersEmbeddings, LlamaEmbeddings, \
    BloomEmbeddings

import pytest
from unittest import TestCase
import os

device = os.environ['DEVICE']
print(f'Running on {device}')

class Test_Langchain_Transformers_API(TestCase):
    def setUp(self):
        self.llama_model_path = os.environ.get('LLAMA2_7B_ORIGIN_PATH')
        thread_num = os.environ.get('THREAD_NUM')
        if thread_num is not None:
            self.n_threads = int(thread_num)
        else:
            self.n_threads = 2


    def test_bigdl_llm(self):
        texts = 'What is the capital of France?\n\n'
        bigdl_llm = TransformersLLM.from_model_id(model_id=self.llama_model_path, model_kwargs={'trust_remote_code': True}, device_map=device)

        output = bigdl_llm(texts)
        res = "Paris" in output
        self.assertTrue(res)

if __name__ == '__main__':
    pytest.main([__file__])
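The test reads `DEVICE` and `LLAMA2_7B_ORIGIN_PATH` from the environment. A minimal sketch of a local run, assuming a downloaded Llama 2 7B checkpoint (the path below is hypothetical):

```bash
export DEVICE=xpu
export LLAMA2_7B_ORIGIN_PATH=/models/Llama-2-7b-chat-hf   # hypothetical local checkpoint path
python -m pytest -s python/llm/test/langchain_gpu
```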
22  python/llm/test/run-llm-langchain-tests-gpu.sh (new file)
@@ -0,0 +1,22 @@
#!/bin/bash

export ANALYTICS_ZOO_ROOT=${ANALYTICS_ZOO_ROOT}
export LLM_HOME=${ANALYTICS_ZOO_ROOT}/python/llm/src
export LLM_INFERENCE_TEST_DIR=${ANALYTICS_ZOO_ROOT}/python/llm/test/langchain_gpu

export USE_XETLA=OFF
export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1
export DEVICE='xpu'

set -e

echo "# Start testing inference"
start=$(date "+%s")

python -m pytest -s ${LLM_INFERENCE_TEST_DIR}

now=$(date "+%s")
time=$((now-start))

echo "Bigdl-llm langchain gpu tests finished"
echo "Time used:$time seconds"