add langchain vllm interface (#11121)
* done
* fix
* fix
* add vllm
* add langchain vllm examples
* add docs
* temp
This commit is contained in:
parent 63e95698eb
commit fabc395d0d

5 changed files with 371 additions and 18 deletions
@@ -82,7 +82,7 @@ If the service have booted successfully, you should see the output similar to th

vLLM supports utilizing multiple cards through tensor parallelism.

-You can refer to this [documentation](https://ipex-llm.readthedocs.io/en/latest/doc/LLM/Quickstart/vLLM_quickstart.html#about-tensor-paralle) on how to utilize the `tensor-parallel` feature and start the service.
+You can refer to this [documentation](https://ipex-llm.readthedocs.io/en/latest/doc/LLM/Quickstart/vLLM_quickstart.html#about-tensor-parallel) on how to utilize the `tensor-parallel` feature and start the service.

#### Verify

After the service has been booted successfully, you can send a test request using `curl`. Here, `YOUR_MODEL` should be set equal to `served_model_name` in your booting script, e.g. `Qwen1.5`.
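For illustration (not part of this diff), a minimal test request against vLLM's OpenAI-compatible `/v1/completions` endpoint might look like the following sketch; the host and port are assumptions based on the default serving setup, so adjust them to match your booting script:

```bash
# Hedged example request; replace host/port and YOUR_MODEL to match your service.
curl http://localhost:8000/v1/completions \
  -H "Content-Type: application/json" \
  -d '{
        "model": "YOUR_MODEL",
        "prompt": "What is AI?",
        "max_tokens": 64
      }'
```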
@@ -5,15 +5,7 @@ The examples in this folder shows how to use [LangChain](https://www.langchain.c

### 1. Install ipex-llm

Follow the instructions in [GPU Install Guide](https://ipex-llm.readthedocs.io/en/latest/doc/LLM/Overview/install_gpu.html) to install ipex-llm.

### 2. Install Required Dependencies for langchain examples.

```bash
pip install langchain==0.0.184
pip install -U chromadb==0.3.25
pip install -U pandas==2.0.3
```

-### 3. Configures OneAPI environment variables for Linux
+### 2. Configures OneAPI environment variables for Linux

> [!NOTE]
> Skip this step if you are running on Windows.
@@ -24,9 +16,9 @@ This is a required step on Linux for APT or offline installed oneAPI. Skip this

source /opt/intel/oneapi/setvars.sh
```

-### 4. Runtime Configurations
+### 3. Runtime Configurations

For optimal performance, it is recommended to set several environment variables. Please check out the suggestions based on your device.

-#### 4.1 Configurations for Linux
+#### 3.1 Configurations for Linux

<details>

<summary>For Intel Arc™ A-Series Graphics and Intel Data Center GPU Flex Series</summary>
@@ -63,7 +55,7 @@ export BIGDL_LLM_XMX_DISABLED=1

</details>

-#### 4.2 Configurations for Windows
+#### 3.2 Configurations for Windows

<details>

<summary>For Intel iGPU</summary>
@@ -88,9 +80,18 @@ set SYCL_CACHE_PERSISTENT=1

> [!NOTE]
> For the first time that each model runs on Intel iGPU/Intel Arc™ A300-Series or Pro A60, it may take several minutes to compile.

-### 5. Run the examples
+### 4. Run the examples

-#### 5.1. Streaming Chat
+#### 4.1. Streaming Chat

Install dependencies:

```bash
pip install langchain==0.0.184
pip install -U pandas==2.0.3
```

Then execute:

```bash
python chat.py -m MODEL_PATH -q QUESTION
@@ -99,7 +100,16 @@ arguments info:

- `-m MODEL_PATH`: **required**, path to the model
- `-q QUESTION`: question to ask. Default is `What is AI?`.

-#### 5.2. RAG (Retrieval Augmented Generation)
+#### 4.2. RAG (Retrieval Augmented Generation)

Install dependencies:
```bash
pip install langchain==0.0.184
pip install -U chromadb==0.3.25
pip install -U pandas==2.0.3
```

Then execute:

```bash
python rag.py -m <path_to_model> [-q QUESTION] [-i INPUT_PATH]
@@ -110,16 +120,65 @@ arguments info:

- `-i INPUT_PATH`: path to the input doc.

-#### 5.2. Low Bit
+#### 4.3. Low Bit

The low_bit example ([low_bit.py](./low_bit.py)) showcases how to use LangChain with a low-bit optimized model.
With `save_low_bit`, we save the weights of the low-bit model into the target folder.
> Note: `save_low_bit` only saves the weights of the model.
> Users could copy the tokenizer model into the target folder or specify `tokenizer_id` during initialization.
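For illustration (not part of this diff), the save/reload flow described above can be sketched with ipex-llm's `AutoModelForCausalLM`; this assumes the `save_low_bit`/`load_low_bit` API and is not necessarily how `low_bit.py` itself is written:

```python
# Hedged sketch: save low-bit weights once, then reload them later.
from ipex_llm.transformers import AutoModelForCausalLM
from transformers import AutoTokenizer

model_path = "YOUR_MODEL_PATH"    # hypothetical placeholder
target_path = "YOUR_TARGET_PATH"  # hypothetical placeholder

# Load with low-bit optimization and save only the weights to the target folder.
model = AutoModelForCausalLM.from_pretrained(model_path, load_in_4bit=True, trust_remote_code=True)
model.save_low_bit(target_path)

# save_low_bit does not save the tokenizer; copy it alongside (or pass tokenizer_id later).
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
tokenizer.save_pretrained(target_path)

# Later: reload the low-bit weights directly from the target folder.
model = AutoModelForCausalLM.load_low_bit(target_path, trust_remote_code=True)
```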

Install dependencies:
```bash
pip install langchain==0.0.184
pip install -U pandas==2.0.3
```
Then execute:

```bash
python low_bit.py -m <path_to_model> -t <path_to_target> [-q <your question>]
```
**Runtime Arguments Explained**:
- `-m MODEL_PATH`: **Required**, the path to the model
- `-t TARGET_PATH`: **Required**, the path to save the low_bit model
- `-q QUESTION`: the question

#### 4.4 vLLM

The vLLM example ([vllm.py](./vllm.py)) showcases how to use LangChain with the IPEX-LLM integrated vLLM engine.

Install dependencies:
```bash
pip install "langchain<0.2"
```

You should also install the IPEX-LLM integrated vLLM by following the instructions listed [here](https://ipex-llm.readthedocs.io/en/latest/doc/LLM/Quickstart/vLLM_quickstart.html#install-vllm).
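For orientation, here is a minimal sketch of the `VLLM` interface added by this commit, mirroring the parameters used in [vllm.py](./vllm.py); the model path and prompt are placeholders:

```python
# Minimal sketch of the LangChain VLLM interface added by this commit.
from ipex_llm.langchain.vllm.vllm import VLLM

llm = VLLM(
    model="YOUR_MODEL_PATH",     # path to a HuggingFace Transformers model
    trust_remote_code=True,
    max_new_tokens=128,
    load_in_low_bit="sym_int4",  # ipex-llm low-bit format
    device="xpu",
    tensor_parallel_size=1,
)
print(llm.invoke("What is AI?"))
```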

**Runtime Arguments Explained**:
- `-m MODEL_PATH`: **Required**, the path to the model
- `-q QUESTION`: the question
- `-t MAX_TOKENS`: max tokens to generate, default 128
- `-p TENSOR_PARALLEL_SIZE`: number of cards to use for generation (the tensor parallel size)
- `-l LOAD_IN_LOW_BIT`: low bit format for quantization

##### Single card

The following command shows how to run the example on a single card:

```bash
python ./vllm.py -m YOUR_MODEL_PATH -q "What is AI?" -t 128 -p 1 -l sym_int4
```

##### Multiple cards

To use the `-p TENSOR_PARALLEL_SIZE` option, you will need to use our docker image `intelanalytics/ipex-llm-serving-xpu:latest`. For how to use the image, check this [guide](https://ipex-llm.readthedocs.io/en/latest/doc/LLM/DockerGuides/vllm_docker_quickstart.html#multi-card-serving).

The following command shows how to run the example on two cards:

```bash
export CCL_WORKER_COUNT=2
export FI_PROVIDER=shm
export CCL_ATL_TRANSPORT=ofi
export CCL_ZE_IPC_EXCHANGE=sockets
export CCL_ATL_SHM=1
python ./vllm.py -m YOUR_MODEL_PATH -q "What is AI?" -t 128 -p 2 -l sym_int4
```
45 python/llm/example/GPU/LangChain/vllm.py Normal file
@@ -0,0 +1,45 @@
from ipex_llm.langchain.vllm.vllm import VLLM
from langchain.chains import LLMChain
from langchain_core.prompts import PromptTemplate
import argparse


def main(args):
    llm = VLLM(
        model=args.model_path,
        trust_remote_code=True,  # mandatory for hf models
        max_new_tokens=args.max_tokens,  # honor the -t/--max-tokens argument instead of hard-coding 128
        top_k=10,
        top_p=0.95,
        temperature=0.8,
        max_model_len=2048,
        enforce_eager=True,
        load_in_low_bit=args.load_in_low_bit,
        device="xpu",
        tensor_parallel_size=args.tensor_parallel_size,
    )

    print(llm.invoke(args.question))

    # Wrap the LLM in a simple prompt template + chain.
    template = """Question: {question}

Answer: Let's think step by step."""
    prompt = PromptTemplate.from_template(template)

    llm_chain = LLMChain(prompt=prompt, llm=llm)

    print(llm_chain.invoke("Who was the US president in the year the first Pokemon game was released?"))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Langchain integrated vLLM example')
    parser.add_argument('-m', '--model-path', type=str, required=True,
                        help='the path to transformers model')
    parser.add_argument('-q', '--question', type=str, default='What is the capital of France?', help='question you want to ask.')
    parser.add_argument('-t', '--max-tokens', type=int, default=128, help='max tokens to generate')
    parser.add_argument('-p', '--tensor-parallel-size', type=int, default=1, help="vLLM tensor parallel size")
    parser.add_argument('-l', '--load-in-low-bit', type=str, default='sym_int4', help="low bit format")
    args = parser.parse_args()

    main(args)
20 python/llm/src/ipex_llm/langchain/vllm/__init__.py Normal file
@@ -0,0 +1,20 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# This makes sure Python is aware there is more than one sub-package within bigdl,
# physically located elsewhere.
# Otherwise there would be a module-not-found error in a non-pip setting, as Python would
# only search the first bigdl package and end up finding only one sub-package.
229 python/llm/src/ipex_llm/langchain/vllm/vllm.py Normal file
@@ -0,0 +1,229 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This file is adapted from
# https://github.com/hwchase17/langchain/blob/master/langchain/llms/llamacpp.py

# The MIT License

# Copyright (c) Harrison Chase

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.

from typing import Any, Dict, List, Optional

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import BaseLLM
from langchain_core.outputs import Generation, LLMResult
from langchain_core.pydantic_v1 import Field, root_validator

from langchain_community.llms.openai import BaseOpenAI
from langchain_community.utils.openai import is_openai_v1


class VLLM(BaseLLM):
    """VLLM language model."""

    model: str = ""
    """The name or path of a HuggingFace Transformers model."""

    tensor_parallel_size: Optional[int] = 1
    """The number of GPUs to use for distributed execution with tensor parallelism."""

    trust_remote_code: Optional[bool] = False
    """Trust remote code (e.g., from HuggingFace) when downloading the model
    and tokenizer."""

    n: int = 1
    """Number of output sequences to return for the given prompt."""

    best_of: Optional[int] = None
    """Number of output sequences that are generated from the prompt."""

    presence_penalty: float = 0.0
    """Float that penalizes new tokens based on whether they appear in the
    generated text so far"""

    frequency_penalty: float = 0.0
    """Float that penalizes new tokens based on their frequency in the
    generated text so far"""

    temperature: float = 1.0
    """Float that controls the randomness of the sampling."""

    top_p: float = 1.0
    """Float that controls the cumulative probability of the top tokens to consider."""

    top_k: int = -1
    """Integer that controls the number of top tokens to consider."""

    use_beam_search: bool = False
    """Whether to use beam search instead of sampling."""

    stop: Optional[List[str]] = None
    """List of strings that stop the generation when they are generated."""

    ignore_eos: bool = False
    """Whether to ignore the EOS token and continue generating tokens after
    the EOS token is generated."""

    max_new_tokens: int = 512
    """Maximum number of tokens to generate per output sequence."""

    logprobs: Optional[int] = None
    """Number of log probabilities to return per output token."""

    dtype: str = "auto"
    """The data type for the model weights and activations."""

    download_dir: Optional[str] = None
    """Directory to download and load the weights. (Default to the default
    cache dir of huggingface)"""

    vllm_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Holds any model parameters valid for `vllm.LLM` call not explicitly specified."""

    load_in_low_bit: str = "sym_int4"
    """Load in low bit format for ipex-llm low-bit quantization"""

    device: str = "xpu"
    """Device on which the IPEX-LLM vLLM engine runs."""

    enforce_eager: bool = True
    """Whether to enforce eager execution instead of graph capture."""

    client: Any  #: :meta private:

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that python package exists in environment."""
        try:
            # from vllm import LLM as VLLModel
            from ipex_llm.vllm.engine import IPEXLLMClass as VLLModel
        except ImportError:
            raise ImportError(
                "Could not import the IPEX-LLM integrated vLLM package. "
                "Please install it following the vLLM quickstart instructions."
            )

        values["client"] = VLLModel(
            model=values["model"],
            tensor_parallel_size=values["tensor_parallel_size"],
            trust_remote_code=values["trust_remote_code"],
            dtype=values["dtype"],
            download_dir=values["download_dir"],
            load_in_low_bit=values["load_in_low_bit"],
            device=values["device"],
            enforce_eager=values["enforce_eager"],
            **values["vllm_kwargs"],
        )

        return values

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling vllm."""
        return {
            "n": self.n,
            "best_of": self.best_of,
            "max_tokens": self.max_new_tokens,
            "top_k": self.top_k,
            "top_p": self.top_p,
            "temperature": self.temperature,
            "presence_penalty": self.presence_penalty,
            "frequency_penalty": self.frequency_penalty,
            "stop": self.stop,
            "ignore_eos": self.ignore_eos,
            "use_beam_search": self.use_beam_search,
            "logprobs": self.logprobs,
        }

    def _generate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Run the LLM on the given prompt and input."""

        from vllm import SamplingParams

        # build sampling parameters
        params = {**self._default_params, **kwargs, "stop": stop}
        sampling_params = SamplingParams(**params)
        # call the model
        outputs = self.client.generate(prompts, sampling_params)

        generations = []
        for output in outputs:
            text = output.outputs[0].text
            generations.append([Generation(text=text)])

        return LLMResult(generations=generations)

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "vllm"


class VLLMOpenAI(BaseOpenAI):
    """vLLM OpenAI-compatible API client"""

    @classmethod
    def is_lc_serializable(cls) -> bool:
        return False

    @property
    def _invocation_params(self) -> Dict[str, Any]:
        """Get the parameters used to invoke the model."""

        params: Dict[str, Any] = {
            "model": self.model_name,
            **self._default_params,
            "logit_bias": None,
        }
        if not is_openai_v1():
            params.update(
                {
                    "api_key": self.openai_api_key,
                    "api_base": self.openai_api_base,
                }
            )

        return params

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "vllm-openai"