add langchain vllm interface (#11121)

* done

* fix

* fix

* add vllm

* add langchain vllm examples

* add docs

* temp
Guancheng Fu 2024-05-24 17:19:27 +08:00 committed by GitHub
parent 63e95698eb
commit fabc395d0d
5 changed files with 371 additions and 18 deletions


@@ -82,7 +82,7 @@ If the service has booted successfully, you should see the output similar to th
vLLM supports utilizing multiple cards through tensor parallelism.
You can refer to this [documentation](https://ipex-llm.readthedocs.io/en/latest/doc/LLM/Quickstart/vLLM_quickstart.html#about-tensor-paralle) on how to utilize the `tensor-parallel` feature and start the service.
You can refer to this [documentation](https://ipex-llm.readthedocs.io/en/latest/doc/LLM/Quickstart/vLLM_quickstart.html#about-tensor-parallel) on how to utilize the `tensor-parallel` feature and start the service.
#### Verify
After the service has started successfully, you can send a test request using `curl`. Here, `YOUR_MODEL` should be set to the `served_model_name` specified in your booting script, e.g. `Qwen1.5`.
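
For reference, the same test request can also be sent from Python. Below is a minimal sketch, assuming the service listens on the default `localhost:8000` and exposes vLLM's OpenAI-compatible `/v1/completions` endpoint (the prompt is only an example):

```python
import requests

response = requests.post(
    "http://localhost:8000/v1/completions",  # adjust host/port to your deployment
    headers={"Content-Type": "application/json"},
    json={
        "model": "YOUR_MODEL",   # must match served_model_name, e.g. "Qwen1.5"
        "prompt": "San Francisco is a",
        "max_tokens": 128,
        "temperature": 0,
    },
)
print(response.json())
```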


@@ -5,15 +5,7 @@ The examples in this folder show how to use [LangChain](https://www.langchain.c
### 1. Install ipex-llm
Follow the instructions in [GPU Install Guide](https://ipex-llm.readthedocs.io/en/latest/doc/LLM/Overview/install_gpu.html) to install ipex-llm
### 2. Install Required Dependencies for langchain examples.
```bash
pip install langchain==0.0.184
pip install -U chromadb==0.3.25
pip install -U pandas==2.0.3
```
### 3. Configures OneAPI environment variables for Linux
### 2. Configures OneAPI environment variables for Linux
> [!NOTE]
> Skip this step if you are running on Windows.
@@ -24,9 +16,9 @@ This is a required step on Linux for APT or offline installed oneAPI. Skip this
source /opt/intel/oneapi/setvars.sh
```
### 4. Runtime Configurations
### 3. Runtime Configurations
For optimal performance, it is recommended to set several environment variables. Please check out the suggestions based on your device.
#### 4.1 Configurations for Linux
#### 3.1 Configurations for Linux
<details>
<summary>For Intel Arc™ A-Series Graphics and Intel Data Center GPU Flex Series</summary>
@@ -63,7 +55,7 @@ export BIGDL_LLM_XMX_DISABLED=1
</details>
#### 4.2 Configurations for Windows
#### 3.2 Configurations for Windows
<details>
<summary>For Intel iGPU</summary>
@@ -88,9 +80,18 @@ set SYCL_CACHE_PERSISTENT=1
> [!NOTE]
> The first time each model runs on Intel iGPU/Intel Arc™ A300-Series or Pro A60, it may take several minutes to compile.
### 5. Run the examples
### 4. Run the examples
#### 5.1. Streaming Chat
#### 4.1. Streaming Chat
Install dependencies:
```bash
pip install langchain==0.0.184
pip install -U pandas==2.0.3
```
Then execute:
```bash
python chat.py -m MODEL_PATH -q QUESTION
@@ -99,7 +100,16 @@ arguments info:
- `-m MODEL_PATH`: **required**, path to the model
- `-q QUESTION`: question to ask. Default is `What is AI?`.
#### 5.2. RAG (Retrieval Augmented Generation)
#### 4.2. RAG (Retrieval Augmented Generation)
Install dependencies:
```bash
pip install langchain==0.0.184
pip install -U chromadb==0.3.25
pip install -U pandas==2.0.3
```
Then execute:
```bash
python rag.py -m <path_to_model> [-q QUESTION] [-i INPUT_PATH]
@@ -110,16 +120,65 @@ arguments info:
- `-i INPUT_PATH`: path to the input doc.
#### 5.2. Low Bit
#### 4.3. Low Bit
The low_bit example ([low_bit.py](./low_bit.py)) showcases how to use LangChain with a low-bit optimized model.
With `save_low_bit`, the weights of the low-bit model are saved into the target folder.
> Note: `save_low_bit` only saves the weights of the model.
> Users can either copy the tokenizer files into the target folder or specify `tokenizer_id` during initialization. A minimal sketch of this save/load flow is shown after the run instructions below.
Install dependencies:
```bash
pip install langchain==0.0.184
pip install -U pandas==2.0.3
```
Then execute:
```bash
python low_bit.py -m <path_to_model> -t <path_to_target> [-q <your question>]
```
**Runtime Arguments Explained**:
- `-m MODEL_PATH`: **Required**, the path to the model
- `-t TARGET_PATH`: **Required**, the path to save the low_bit model
- `-q QUESTION`: the question to ask
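
Below is a minimal sketch (not low_bit.py itself; the paths are placeholders) of what the save/load flow looks like with the `ipex_llm.transformers` API: `save_low_bit` persists only the quantized weights, so the tokenizer is saved alongside them, and `load_low_bit` reloads the weights without re-quantizing.

```python
from ipex_llm.transformers import AutoModelForCausalLM
from transformers import AutoTokenizer

model_path = "MODEL_PATH"    # original Hugging Face model (placeholder)
target_path = "TARGET_PATH"  # folder for the low-bit model (placeholder)

# Quantize to sym_int4 while loading, then persist only the low-bit weights
model = AutoModelForCausalLM.from_pretrained(model_path,
                                             load_in_low_bit="sym_int4",
                                             trust_remote_code=True)
model.save_low_bit(target_path)

# save_low_bit does not include the tokenizer, so save (or copy) it into the same folder
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
tokenizer.save_pretrained(target_path)

# Later runs can reload the quantized weights directly, without re-quantizing
model = AutoModelForCausalLM.load_low_bit(target_path, trust_remote_code=True)
```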
#### 4.4 vLLM
The vLLM example ([vllm.py](./vllm.py)) showcases how to use LangChain with the IPEX-LLM integrated vLLM engine.
Install dependencies:
```bash
pip install "langchain<0.2"
```
Besides, you should also install the IPEX-LLM integrated vLLM according to the instructions listed [here](https://ipex-llm.readthedocs.io/en/latest/doc/LLM/Quickstart/vLLM_quickstart.html#install-vllm).
**Runtime Arguments Explained**:
- `-m MODEL_PATH`: **Required**, the path to the model
- `-q QUESTION`: the question
- `-t MAX_TOKENS`: max tokens to generate, default 128
- `-p TENSOR_PARALLEL_SIZE`: number of cards to use for generation, default 1
- `-l LOAD_IN_LOW_BIT`: low-bit format for quantization, default `sym_int4`
##### Single card
The following command shows how to run the example using one card:
```bash
python ./vllm.py -m YOUR_MODEL_PATH -q "What is AI?" -t 128 -p 1 -l sym_int4
```
##### Multiple cards
To use the `-p TENSOR_PARALLEL_SIZE` option, you will need to use our Docker image `intelanalytics/ipex-llm-serving-xpu:latest`. For how to use the image, check this [guide](https://ipex-llm.readthedocs.io/en/latest/doc/LLM/DockerGuides/vllm_docker_quickstart.html#multi-card-serving).
The following command shows how to run the example using two cards:
```bash
export CCL_WORKER_COUNT=2
export FI_PROVIDER=shm
export CCL_ATL_TRANSPORT=ofi
export CCL_ZE_IPC_EXCHANGE=sockets
export CCL_ATL_SHM=1
python ./vllm.py -m YOUR_MODEL_PATH -q "What is AI?" -t 128 -p 2 -l sym_int4
```


@@ -0,0 +1,45 @@
from ipex_llm.langchain.vllm.vllm import VLLM
from langchain.chains import LLMChain
from langchain_core.prompts import PromptTemplate
import argparse


def main(args):
    # Build the IPEX-LLM vLLM engine on Intel GPU ("xpu") with low-bit weights
    llm = VLLM(
        model=args.model_path,
        trust_remote_code=True,  # mandatory for hf models
        max_new_tokens=args.max_tokens,
        top_k=10,
        top_p=0.95,
        temperature=0.8,
        max_model_len=2048,
        enforce_eager=True,
        load_in_low_bit=args.load_in_low_bit,
        device="xpu",
        tensor_parallel_size=args.tensor_parallel_size,
    )

    # Direct invocation of the LLM
    print(llm.invoke(args.question))

    # Wrap the same LLM in a simple LangChain prompt + chain
    template = """Question: {question}

Answer: Let's think step by step."""
    prompt = PromptTemplate.from_template(template)

    llm_chain = LLMChain(prompt=prompt, llm=llm)

    print(llm_chain.invoke("Who was the US president in the year the first Pokemon game was released?"))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='LangChain integrated vLLM example')
    parser.add_argument('-m', '--model-path', type=str, required=True,
                        help='the path to transformers model')
    parser.add_argument('-q', '--question', type=str, default='What is the capital of France?',
                        help='question you want to ask.')
    parser.add_argument('-t', '--max-tokens', type=int, default=128,
                        help='max tokens to generate')
    parser.add_argument('-p', '--tensor-parallel-size', type=int, default=1,
                        help="vLLM tensor parallel size")
    parser.add_argument('-l', '--load-in-low-bit', type=str, default='sym_int4',
                        help="low bit format")
    args = parser.parse_args()

    main(args)


@@ -0,0 +1,20 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This is to make sure Python is aware there is more than one sub-package within bigdl,
# physically located elsewhere.
# Otherwise there would be a module-not-found error in a non-pip setting, as Python would
# only search the first bigdl package and end up finding only one sub-package.


@@ -0,0 +1,229 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This file is adapted from
# https://github.com/hwchase17/langchain/blob/master/langchain/llms/llamacpp.py
# The MIT License
# Copyright (c) Harrison Chase
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
from typing import Any, Dict, List, Optional
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import BaseLLM
from langchain_core.outputs import Generation, LLMResult
from langchain_core.pydantic_v1 import Field, root_validator
from langchain_community.llms.openai import BaseOpenAI
from langchain_community.utils.openai import is_openai_v1
class VLLM(BaseLLM):
"""VLLM language model."""
model: str = ""
"""The name or path of a HuggingFace Transformers model."""
tensor_parallel_size: Optional[int] = 1
"""The number of GPUs to use for distributed execution with tensor parallelism."""
trust_remote_code: Optional[bool] = False
"""Trust remote code (e.g., from HuggingFace) when downloading the model
and tokenizer."""
n: int = 1
"""Number of output sequences to return for the given prompt."""
best_of: Optional[int] = None
"""Number of output sequences that are generated from the prompt."""
presence_penalty: float = 0.0
"""Float that penalizes new tokens based on whether they appear in the
generated text so far"""
frequency_penalty: float = 0.0
"""Float that penalizes new tokens based on their frequency in the
generated text so far"""
temperature: float = 1.0
"""Float that controls the randomness of the sampling."""
top_p: float = 1.0
"""Float that controls the cumulative probability of the top tokens to consider."""
top_k: int = -1
"""Integer that controls the number of top tokens to consider."""
use_beam_search: bool = False
"""Whether to use beam search instead of sampling."""
stop: Optional[List[str]] = None
"""List of strings that stop the generation when they are generated."""
ignore_eos: bool = False
"""Whether to ignore the EOS token and continue generating tokens after
the EOS token is generated."""
max_new_tokens: int = 512
"""Maximum number of tokens to generate per output sequence."""
logprobs: Optional[int] = None
"""Number of log probabilities to return per output token."""
dtype: str = "auto"
"""The data type for the model weights and activations."""
download_dir: Optional[str] = None
"""Directory to download and load the weights. (Default to the default
cache dir of huggingface)"""
vllm_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Holds any model parameters valid for `vllm.LLM` call not explicitly specified."""
load_in_low_bit: str = "sym_int4"
"""Load in low bit format for ipex-llm low-bit quantization"""
device: str = "xpu"
"""Device on which the IPEX-LLM vLLM engine runs, e.g. "xpu" for Intel GPUs."""
enforce_eager: bool = True
"""Whether to always run the vLLM engine in eager mode (no graph capture)."""
client: Any #: :meta private:
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that the required python packages exist in the environment."""
try:
# from vllm import LLM as VLLModel
from ipex_llm.vllm.engine import IPEXLLMClass as VLLModel
except ImportError:
raise ImportError(
"Could not import the IPEX-LLM vLLM engine. "
"Please install IPEX-LLM with vLLM support by following "
"the vLLM quickstart in the IPEX-LLM documentation."
)
values["client"] = VLLModel(
model=values["model"],
tensor_parallel_size=values["tensor_parallel_size"],
trust_remote_code=values["trust_remote_code"],
dtype=values["dtype"],
download_dir=values["download_dir"],
load_in_low_bit=values["load_in_low_bit"],
device=values["device"],
enforce_eager=values["enforce_eager"],
**values["vllm_kwargs"],
)
return values
@property
def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling vllm."""
return {
"n": self.n,
"best_of": self.best_of,
"max_tokens": self.max_new_tokens,
"top_k": self.top_k,
"top_p": self.top_p,
"temperature": self.temperature,
"presence_penalty": self.presence_penalty,
"frequency_penalty": self.frequency_penalty,
"stop": self.stop,
"ignore_eos": self.ignore_eos,
"use_beam_search": self.use_beam_search,
"logprobs": self.logprobs,
}
def _generate(
self,
prompts: List[str],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
"""Run the LLM on the given prompt and input."""
from vllm import SamplingParams
# build sampling parameters
params = {**self._default_params, **kwargs, "stop": stop}
sampling_params = SamplingParams(**params)
# call the model
outputs = self.client.generate(prompts, sampling_params)
generations = []
for output in outputs:
text = output.outputs[0].text
generations.append([Generation(text=text)])
return LLMResult(generations=generations)
@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "vllm"
class VLLMOpenAI(BaseOpenAI):
"""vLLM OpenAI-compatible API client"""
@classmethod
def is_lc_serializable(cls) -> bool:
return False
@property
def _invocation_params(self) -> Dict[str, Any]:
"""Get the parameters used to invoke the model."""
params: Dict[str, Any] = {
"model": self.model_name,
**self._default_params,
"logit_bias": None,
}
if not is_openai_v1():
params.update(
{
"api_key": self.openai_api_key,
"api_base": self.openai_api_base,
}
)
return params
@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "vllm-openai"