Add vLLM to ipex-llm serving image (#10807)

* add vllm

* done

* doc work

* fix done

* temp

* add docs

* format

* add start-fastchat-service.sh

* fix
Guancheng Fu 2024-04-29 17:25:42 +08:00 committed by GitHub
parent 1f876fd837
commit 2c64754eb0
12 changed files with 819 additions and 171 deletions

View file

@ -6,14 +6,31 @@ ARG https_proxy
# Disable pip's cache behavior
ARG PIP_NO_CACHE_DIR=false
COPY ./entrypoint.sh /opt/entrypoint.sh
# Install Serving Dependencies
RUN cd /llm && \
pip install --pre --upgrade ipex-llm[serving] && \
pip install transformers==4.36.2 gradio==4.19.2 && \
chmod +x /opt/entrypoint.sh
RUN cd /llm &&\
# Installing only ipex-llm[serving] updates the ipex_llm source code without updating
# bigdl-core-xe, which leads to problems
apt-get update && \
apt-get install -y libfabric-dev wrk && \
pip install --pre --upgrade ipex-llm[xpu,serving] && \
pip install transformers==4.37.0 gradio==4.19.2 && \
# Install vLLM-v2 dependencies
cd /llm && \
git clone -b sycl_xpu https://github.com/analytics-zoo/vllm.git && \
cd vllm && \
pip install -r requirements-xpu.txt && \
pip install --no-deps xformers && \
VLLM_BUILD_XPU_OPS=1 pip install --no-build-isolation -v -e . && \
pip install outlines==0.0.34 --no-deps && \
pip install interegular cloudpickle diskcache joblib lark nest-asyncio numba scipy && \
# For Qwen series models support
pip install transformers_stream_generator einops tiktoken
ADD ./offline_inference.py /llm/vllm-examples/
ADD ./payload-1024.lua /llm/vllm-examples/
ADD ./start-vllm-service.sh /llm/vllm-examples/
ADD ./benchmark_throughput.py /llm/vllm-examples/
ADD ./start-fastchat-service.sh /llm/fastchat-examples/
WORKDIR /llm/
ENTRYPOINT [ "/opt/entrypoint.sh" ]

View file

@ -43,4 +43,125 @@ root@arda-arc12:/# sycl-ls
```
After the container has booted, you can enter it with `docker exec`.
To run model-serving using `IPEX-LLM` as backend, you can refer to this [document](https://github.com/intel-analytics/IPEX-LLM/tree/main/python/llm/src/ipex_llm/serving).
Currently, the image provides two serving engines: the FastChat serving engine and the vLLM serving engine.
#### FastChat serving engine
To serve models through FastChat with `IPEX-LLM` as the backend, you can refer to this [quickstart](https://ipex-llm.readthedocs.io/en/latest/doc/LLM/Quickstart/fastchat_quickstart.html#).
For convenience, we have included a file `/llm/fastchat-examples/start-fastchat-service.sh` in the image.
You can modify this script to use FastChat with either `ipex_llm_worker` or `vllm_worker`.
#### vLLM serving engine
To run the vLLM engine with `IPEX-LLM` as the backend, you can refer to this [document](https://github.com/intel-analytics/ipex-llm/blob/main/python/llm/example/GPU/vLLM-Serving/README.md).
We have included multiple example files in `/llm/vllm-examples`:
1. `offline_inference.py`: An offline inference example
2. `benchmark_throughput.py`: A script for benchmarking throughput
3. `payload-1024.lua`: A `wrk` payload for measuring requests per second with a 1k-128 request (a prompt of roughly 1024 tokens and 128 output tokens)
4. `start-vllm-service.sh`: A template for starting the vLLM service
##### Online benchmark through api_server
We can benchmark the api_server to estimate TPS (transactions per second). To do so, you need to start the service first according to the instructions in this [section](https://github.com/intel-analytics/ipex-llm/blob/main/python/llm/example/GPU/vLLM-Serving/README.md#service).
In the container, do the following:
1. Modify `/llm/vllm-examples/payload-1024.lua` so that the "model" attribute matches the model name your service serves. By default, the payload uses a prompt that is roughly 1024 tokens long; you can change it if needed.
2. Start the benchmark with `wrk` using the script below:
```bash
cd /llm/vllm-examples
# You can change -t and -c to control the concurrency.
# By default, we use 12 connections to benchmark the service.
wrk -t12 -c12 -d15m -s payload-1024.lua http://localhost:8000/v1/completions --timeout 1h
```
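As a rough aid for reading the `wrk` results, the `Requests/sec` figure it reports can be converted into an approximate token rate. The sketch below is only an illustration: the function name and defaults are hypothetical, and it assumes every request carries about 1024 prompt tokens and generates the full 128 `max_tokens` set in `payload-1024.lua`. Real completions may stop earlier, so treat the result as an upper bound.

```python
def estimate_token_rates(requests_per_sec: float,
                         prompt_tokens: int = 1024,
                         max_tokens: int = 128) -> dict:
    """Convert a requests-per-second figure into approximate tokens/s.

    Assumes (hypothetically) that each request has `prompt_tokens` input
    tokens and produces exactly `max_tokens` output tokens.
    """
    if requests_per_sec < 0:
        raise ValueError("requests_per_sec must be non-negative")
    return {
        # Tokens generated by the model per second.
        "generated_tokens_per_sec": requests_per_sec * max_tokens,
        # Prompt plus generated tokens processed per second.
        "total_tokens_per_sec": requests_per_sec * (prompt_tokens + max_tokens),
    }


if __name__ == "__main__":
    # Example: wrk reported 2.0 requests per second.
    print(estimate_token_rates(2.0))
```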
##### Offline benchmark through benchmark_throughput.py
We have included the `benchmark_throughput.py` script provided by `vllm` in our image as `/llm/vllm-examples/benchmark_throughput.py`. To use it, first download the test dataset:
```bash
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
```
The full example looks like this:
```bash
cd /llm/vllm-examples
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
export MODEL="YOUR_MODEL"
# --load-in-low-bit accepts one of [sym_int4, fp8, fp16]
python3 /llm/vllm-examples/benchmark_throughput.py \
--backend vllm \
--dataset /llm/vllm-examples/ShareGPT_V3_unfiltered_cleaned_split.json \
--model $MODEL \
--num-prompts 1000 \
--seed 42 \
--trust-remote-code \
--enforce-eager \
--dtype float16 \
--device xpu \
--load-in-low-bit sym_int4 \
--gpu-memory-utilization 0.85
```
> Note: you can adjust `--load-in-low-bit` to use other low-bit quantization formats.
You can also sweep the `--gpu-memory-utilization` rate with the following script to find the value that gives the best performance:
```bash
#!/bin/bash
# Define the log directory
LOG_DIR="YOUR_LOG_DIR"
# Check if the log directory exists, if not, create it
if [ ! -d "$LOG_DIR" ]; then
    mkdir -p "$LOG_DIR"
fi
# Define an array of model paths
MODELS=(
    "YOUR TESTED MODELS"
)
# Define an array of utilization rates
UTIL_RATES=(0.85 0.90 0.95)
# Loop over each model
for MODEL in "${MODELS[@]}"; do
    # Loop over each utilization rate
    for RATE in "${UTIL_RATES[@]}"; do
        # Extract a simple model name from the path for easier identification
        MODEL_NAME=$(basename "$MODEL")
        # Define the log file name based on the model and rate
        LOG_FILE="$LOG_DIR/${MODEL_NAME}_utilization_${RATE}.log"
        # Execute the command and redirect output to the log file
        # Sometimes you might need to set --max-model-len if memory is not enough
        # load-in-low-bit accepts inputs [sym_int4, fp8, fp16]
        python3 /llm/vllm-examples/benchmark_throughput.py \
            --backend vllm \
            --dataset /llm/vllm-examples/ShareGPT_V3_unfiltered_cleaned_split.json \
            --model "$MODEL" \
            --num-prompts 1000 \
            --seed 42 \
            --trust-remote-code \
            --enforce-eager \
            --dtype float16 \
            --load-in-low-bit sym_int4 \
            --device xpu \
            --gpu-memory-utilization "$RATE" &> "$LOG_FILE"
    done
done
# Inform the user that the script has completed its execution
echo "All benchmarks have been executed and logged."
```
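Once the sweep has finished, the logs can be summarized to pick a utilization rate per model. The following is a minimal sketch, not part of the image: it assumes the log names follow the `${MODEL_NAME}_utilization_${RATE}.log` pattern used above and that each log contains the `Throughput: ... requests/s, ... tokens/s` line printed by `benchmark_throughput.py`. The helper names are hypothetical.

```python
import re
from pathlib import Path
from typing import Dict, Optional, Tuple

# Matches the summary line printed at the end of benchmark_throughput.py.
THROUGHPUT_RE = re.compile(
    r"Throughput: ([\d.]+) requests/s, ([\d.]+) tokens/s")
# Matches log names such as "Llama-2-7b-chat-hf_utilization_0.85.log".
LOG_NAME_RE = re.compile(r"^(?P<model>.+)_utilization_(?P<rate>[\d.]+)\.log$")


def parse_throughput(text: str) -> Optional[Tuple[float, float]]:
    """Return (requests/s, tokens/s) from a log, or None if the run failed."""
    match = THROUGHPUT_RE.search(text)
    if match is None:
        return None
    return float(match.group(1)), float(match.group(2))


def best_rates(log_dir: str) -> Dict[str, Tuple[str, float]]:
    """Map each model name to (best utilization rate, tokens/s)."""
    best: Dict[str, Tuple[str, float]] = {}
    for path in sorted(Path(log_dir).glob("*_utilization_*.log")):
        name = LOG_NAME_RE.match(path.name)
        result = parse_throughput(path.read_text(errors="replace"))
        if name is None or result is None:
            # Skip unrelated files and runs without a throughput line.
            continue
        model, rate = name["model"], name["rate"]
        tokens_per_sec = result[1]
        if model not in best or tokens_per_sec > best[model][1]:
            best[model] = (rate, tokens_per_sec)
    return best


if __name__ == "__main__":
    # "YOUR_LOG_DIR" is the same placeholder used in the sweep script.
    for model, (rate, tps) in best_rates("YOUR_LOG_DIR").items():
        print(f"{model}: --gpu-memory-utilization {rate} ({tps:.2f} tokens/s)")
```

Runs that crashed (for example, from running out of memory) produce no throughput line and are simply skipped here.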

View file

@ -0,0 +1,357 @@
"""Benchmark offline inference throughput."""
import argparse
import json
import random
import time
from typing import List, Optional, Tuple
import torch
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          PreTrainedTokenizerBase)
from tqdm import tqdm
def sample_requests(
    dataset_path: str,
    num_requests: int,
    tokenizer: PreTrainedTokenizerBase,
    fixed_output_len: Optional[int],
) -> List[Tuple[str, int, int]]:
    if fixed_output_len is not None and fixed_output_len < 4:
        raise ValueError("output_len too small")
    # Load the dataset.
    with open(dataset_path) as f:
        dataset = json.load(f)
    # Filter out the conversations with less than 2 turns.
    dataset = [data for data in dataset if len(data["conversations"]) >= 2]
    # Only keep the first two turns of each conversation.
    dataset = [(data["conversations"][0]["value"],
                data["conversations"][1]["value"]) for data in dataset]
    # Tokenize the prompts and completions.
    prompts = [prompt for prompt, _ in dataset]
    prompt_token_ids = tokenizer(prompts).input_ids
    completions = [completion for _, completion in dataset]
    completion_token_ids = tokenizer(completions).input_ids
    tokenized_dataset = []
    for i in range(len(dataset)):
        output_len = len(completion_token_ids[i])
        if fixed_output_len is not None:
            output_len = fixed_output_len
        tokenized_dataset.append((prompts[i], prompt_token_ids[i], output_len))
    # Filter out too long sequences.
    filtered_dataset: List[Tuple[str, int, int]] = []
    for prompt, prompt_token_ids, output_len in tokenized_dataset:
        prompt_len = len(prompt_token_ids)
        if prompt_len < 4 or output_len < 4:
            # Prune too short sequences.
            continue
        if prompt_len > 1024 or prompt_len + output_len > 2048:
            # Prune too long sequences.
            continue
        filtered_dataset.append((prompt, prompt_len, output_len))
    # Sample the requests.
    sampled_requests = random.sample(filtered_dataset, num_requests)
    return sampled_requests
def run_vllm(
    requests: List[Tuple[str, int, int]],
    model: str,
    tokenizer: str,
    quantization: Optional[str],
    tensor_parallel_size: int,
    seed: int,
    n: int,
    use_beam_search: bool,
    trust_remote_code: bool,
    dtype: str,
    max_model_len: Optional[int],
    enforce_eager: bool,
    kv_cache_dtype: str,
    device: str,
    enable_prefix_caching: bool,
    gpu_memory_utilization: float = 0.9,
    load_in_low_bit: str = "sym_int4",
) -> float:
    from vllm import SamplingParams
    from ipex_llm.vllm.engine import IPEXLLMClass as LLM
    llm = LLM(model=model,
              tokenizer=tokenizer,
              quantization=quantization,
              tensor_parallel_size=tensor_parallel_size,
              seed=seed,
              trust_remote_code=trust_remote_code,
              dtype=dtype,
              max_model_len=max_model_len,
              gpu_memory_utilization=gpu_memory_utilization,
              enforce_eager=enforce_eager,
              kv_cache_dtype=kv_cache_dtype,
              device=device,
              enable_prefix_caching=enable_prefix_caching,
              load_in_low_bit=load_in_low_bit)
    # Add the requests to the engine.
    for prompt, _, output_len in requests:
        sampling_params = SamplingParams(
            n=n,
            temperature=0.0 if use_beam_search else 1.0,
            top_p=1.0,
            use_beam_search=use_beam_search,
            ignore_eos=True,
            max_tokens=output_len,
        )
        # FIXME(woosuk): Do not use internal method.
        llm._add_request(
            prompt=prompt,
            prompt_token_ids=None,
            sampling_params=sampling_params,
        )
    start = time.perf_counter()
    # FIXME(woosuk): Do not use internal method.
    llm._run_engine(use_tqdm=True)
    end = time.perf_counter()
    return end - start
def run_hf(
    requests: List[Tuple[str, int, int]],
    model: str,
    tokenizer: PreTrainedTokenizerBase,
    n: int,
    use_beam_search: bool,
    max_batch_size: int,
    trust_remote_code: bool,
) -> float:
    assert not use_beam_search
    llm = AutoModelForCausalLM.from_pretrained(
        model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
    if llm.config.model_type == "llama":
        # To enable padding in the HF backend.
        tokenizer.pad_token = tokenizer.eos_token
    llm = llm.cuda()
    pbar = tqdm(total=len(requests))
    start = time.perf_counter()
    batch: List[str] = []
    max_prompt_len = 0
    max_output_len = 0
    for i in range(len(requests)):
        prompt, prompt_len, output_len = requests[i]
        # Add the prompt to the batch.
        batch.append(prompt)
        max_prompt_len = max(max_prompt_len, prompt_len)
        max_output_len = max(max_output_len, output_len)
        if len(batch) < max_batch_size and i != len(requests) - 1:
            # Check if we can add more requests to the batch.
            _, next_prompt_len, next_output_len = requests[i + 1]
            if (max(max_prompt_len, next_prompt_len) +
                    max(max_output_len, next_output_len)) <= 2048:
                # We can add more requests to the batch.
                continue
        # Generate the sequences.
        input_ids = tokenizer(batch, return_tensors="pt",
                              padding=True).input_ids
        llm_outputs = llm.generate(
            input_ids=input_ids.cuda(),
            do_sample=not use_beam_search,
            num_return_sequences=n,
            temperature=1.0,
            top_p=1.0,
            use_cache=True,
            max_new_tokens=max_output_len,
        )
        # Include the decoding time.
        tokenizer.batch_decode(llm_outputs, skip_special_tokens=True)
        pbar.update(len(batch))
        # Clear the batch.
        batch = []
        max_prompt_len = 0
        max_output_len = 0
    end = time.perf_counter()
    return end - start
def run_mii(
    requests: List[Tuple[str, int, int]],
    model: str,
    tensor_parallel_size: int,
    output_len: int,
) -> float:
    from mii import pipeline
    llm = pipeline(model, tensor_parallel=tensor_parallel_size)
    prompts = [prompt for prompt, _, _ in requests]
    start = time.perf_counter()
    llm(prompts, max_new_tokens=output_len)
    end = time.perf_counter()
    return end - start
def main(args: argparse.Namespace):
    print(args)
    random.seed(args.seed)
    # Sample the requests.
    tokenizer = AutoTokenizer.from_pretrained(
        args.tokenizer, trust_remote_code=args.trust_remote_code)
    if args.dataset is None:
        # Synthesize a prompt with the given input length.
        prompt = "hi" * (args.input_len - 1)
        requests = [(prompt, args.input_len, args.output_len)
                    for _ in range(args.num_prompts)]
    else:
        requests = sample_requests(args.dataset, args.num_prompts, tokenizer,
                                   args.output_len)
    if args.backend == "vllm":
        elapsed_time = run_vllm(
            requests, args.model, args.tokenizer, args.quantization,
            args.tensor_parallel_size, args.seed, args.n, args.use_beam_search,
            args.trust_remote_code, args.dtype, args.max_model_len,
            args.enforce_eager, args.kv_cache_dtype, args.device,
            args.enable_prefix_caching, args.gpu_memory_utilization,
            args.load_in_low_bit)
    elif args.backend == "hf":
        assert args.tensor_parallel_size == 1
        elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
                              args.use_beam_search, args.hf_max_batch_size,
                              args.trust_remote_code)
    elif args.backend == "mii":
        elapsed_time = run_mii(requests, args.model, args.tensor_parallel_size,
                               args.output_len)
    else:
        raise ValueError(f"Unknown backend: {args.backend}")
    total_num_tokens = sum(prompt_len + output_len
                           for _, prompt_len, output_len in requests)
    print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
          f"{total_num_tokens / elapsed_time:.2f} tokens/s")
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Benchmark the throughput.")
    parser.add_argument("--backend",
                        type=str,
                        choices=["vllm", "hf", "mii"],
                        default="vllm")
    parser.add_argument("--dataset",
                        type=str,
                        default=None,
                        help="Path to the dataset.")
    parser.add_argument("--input-len",
                        type=int,
                        default=None,
                        help="Input prompt length for each request")
    parser.add_argument("--output-len",
                        type=int,
                        default=None,
                        help="Output length for each request. Overrides the "
                        "output length from the dataset.")
    parser.add_argument("--model", type=str, default="facebook/opt-125m")
    parser.add_argument("--tokenizer", type=str, default=None)
    parser.add_argument('--quantization',
                        '-q',
                        choices=['awq', 'gptq', 'squeezellm', None],
                        default=None)
    parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
    parser.add_argument("--n",
                        type=int,
                        default=1,
                        help="Number of generated sequences per prompt.")
    parser.add_argument("--use-beam-search", action="store_true")
    parser.add_argument("--num-prompts",
                        type=int,
                        default=1000,
                        help="Number of prompts to process.")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--hf-max-batch-size",
                        type=int,
                        default=None,
                        help="Maximum batch size for HF backend.")
    parser.add_argument('--trust-remote-code',
                        action='store_true',
                        help='trust remote code from huggingface')
    parser.add_argument(
        '--max-model-len',
        type=int,
        default=None,
        help='Maximum length of a sequence (including prompt and output). '
        'If None, will be derived from the model.')
    parser.add_argument(
        '--dtype',
        type=str,
        default='auto',
        choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'],
        help='data type for model weights and activations. '
        'The "auto" option will use FP16 precision '
        'for FP32 and FP16 models, and BF16 precision '
        'for BF16 models.')
    parser.add_argument('--gpu-memory-utilization',
                        type=float,
                        default=0.9,
                        help='the fraction of GPU memory to be used for '
                        'the model executor, which can range from 0 to 1. '
                        'If unspecified, will use the default value of 0.9.')
    parser.add_argument("--enforce-eager",
                        action="store_true",
                        help="enforce eager execution")
    parser.add_argument(
        "--kv-cache-dtype",
        type=str,
        choices=["auto", "fp8_e5m2"],
        default="auto",
        help=
        'Data type for kv cache storage. If "auto", will use model data type.')
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        choices=["cuda", "xpu"],
        help='device type for vLLM execution; "cuda" and "xpu" are accepted.')
    parser.add_argument(
        "--enable-prefix-caching",
        action='store_true',
        help="enable automatic prefix caching for vLLM backend.")
    parser.add_argument(
        "--load-in-low-bit",
        type=str,
        choices=["sym_int4", "fp8", "fp16"],
        default="sym_int4",
        help="Low-bit format quantization with IPEX-LLM")
    args = parser.parse_args()
    if args.tokenizer is None:
        args.tokenizer = args.model
    if args.dataset is None:
        assert args.input_len is not None
        assert args.output_len is not None
    else:
        assert args.input_len is None
    if args.backend == "vllm":
        if args.hf_max_batch_size is not None:
            raise ValueError("HF max batch size is only for HF backend.")
    elif args.backend == "hf":
        if args.hf_max_batch_size is None:
            raise ValueError("HF max batch size is required for HF backend.")
        if args.quantization is not None:
            raise ValueError("Quantization is only for vLLM backend.")
    elif args.backend == "mii":
        if args.dtype != "auto":
            raise ValueError("dtype must be auto for MII backend.")
        if args.n != 1:
            raise ValueError("n must be 1 for MII backend.")
        if args.use_beam_search:
            raise ValueError("Beam search is not supported for MII backend.")
        if args.quantization is not None:
            raise ValueError("Quantization is only for vLLM backend.")
        if args.hf_max_batch_size is not None:
            raise ValueError("HF max batch size is only for HF backend.")
        if args.tokenizer != args.model:
            raise ValueError("Tokenizer must be the same as the model for MII "
                             "backend.")
    main(args)

View file

@ -1,139 +0,0 @@
#!/bin/bash
usage() {
echo "Usage: $0 [-m --mode <controller|worker>] [-w --worker <model_worker|vllm_worker>] [--help]"
echo "--help: Print help message."
echo "The following environment variables can be set."
echo "MODEL_PATH (default: empty)."
echo "CONTROLLER_HOST (default: localhost)."
echo "CONTROLLER_PORT (default: 21001)."
echo "WORKER_HOST (default: localhost)."
echo "WORKER_PORT (default: 21002)."
echo "API_HOST (default: localhost)."
echo "API_PORT (default: 8000)."
exit 1
}
# Default values
controller_host="localhost"
controller_port="21001"
worker_host="localhost"
worker_port="21002"
api_host="localhost"
api_port="8000"
model_path=""
mode=""
dispatch_method="shortest_queue" # shortest_queue or lottery
stream_interval=1
worker_type="model_worker"
# We do not have any arguments, just run bash
if [ "$#" == 0 ]; then
echo "[INFO] no command is passed in"
echo "[INFO] enter pass-through mode"
exec /usr/bin/bash -s -- "bash"
else
# Parse command-line options
options=$(getopt -o "m:hw:" --long "mode:,help,worker:" -n "$0" -- "$@")
if [ $? != 0 ]; then
usage
fi
eval set -- "$options"
while true; do
case "$1" in
-m|--mode)
mode="$2"
[[ $mode == "controller" || $mode == "worker" ]] || usage
shift 2
;;
-w|--worker)
worker_type="$2"
[[ $worker_type == "model_worker" || $worker_type == "vllm_worker" ]] || usage
shift 2
;;
-h|--help)
usage
;;
--)
shift
break
;;
*)
usage
;;
esac
done
if [ "$worker_type" == "model_worker" ]; then
worker_type="ipex_llm.serving.model_worker"
elif [ "$worker_type" == "vllm_worker" ]; then
worker_type="ipex_llm.serving.vllm_worker"
fi
if [[ -n $CONTROLLER_HOST ]]; then
controller_host=$CONTROLLER_HOST
fi
if [[ -n $CONTROLLER_PORT ]]; then
controller_port=$CONTROLLER_PORT
fi
if [[ -n $WORKER_HOST ]]; then
worker_host=$WORKER_HOST
fi
if [[ -n $WORKER_PORT ]]; then
worker_port=$WORKER_PORT
fi
if [[ -n $MODEL_PATH ]]; then
model_path=$MODEL_PATH
fi
if [[ -n $API_HOST ]]; then
api_host=$API_HOST
fi
if [[ -n $API_PORT ]]; then
api_port=$API_PORT
fi
if [[ -n $DISPATCH_METHOD ]]; then
dispatch_method=$DISPATCH_METHOD
fi
if [[ -n $STREAM_INTERVAL ]]; then
stream_interval=$STREAM_INTERVAL
fi
controller_address="http://$controller_host:$controller_port"
unset http_proxy
unset https_proxy
if [[ $mode == "controller" ]]; then
api_address="http://$api_host:$api_port"
echo "Controller address: $controller_address"
echo "OpenAI API address: $api_address"
python3 -m fastchat.serve.controller --host $controller_host --port $controller_port --dispatch-method $dispatch_method &
python3 -m fastchat.serve.openai_api_server --host $api_host --port $api_port --controller-address $controller_address
else
worker_address="http://$worker_host:$worker_port"
echo "Worker type: $worker_type"
echo "Worker address: $worker_address"
echo "Controller address: $controller_address"
if [ "$worker_type" == "ipex_llm.serving.model_worker" ]; then
python3 -m "$worker_type" --model-path $model_path --device xpu --host $worker_host --port $worker_port --worker-address $worker_address --controller-address $controller_address --stream-interval $stream_interval
elif [ "$worker_type" == "ipex_llm.serving.vllm_worker" ]; then
python3 -m "$worker_type" --model-path $model_path --device xpu --host $worker_host --port $worker_port --worker-address $worker_address --controller-address $controller_address
fi
fi
fi
exec /usr/bin/bash -s -- "bash"

View file

@ -0,0 +1,61 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file are adapted from
# https://github.com/vllm-project/vllm/blob/v0.2.1.post1/examples/offline_inference.py
# which is licensed under Apache License 2.0
#
# Copyright 2023 The vLLM team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from vllm import SamplingParams
from ipex_llm.vllm.engine import IPEXLLMClass as LLM
# Sample prompts.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
# Create an LLM.
llm = LLM(model="YOUR_MODEL",
          device="xpu",
          dtype="float16",
          enforce_eager=True,
          load_in_low_bit="sym_int4",
          tensor_parallel_size=1)
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

View file

@ -0,0 +1,20 @@
wrk.method = "POST"
wrk.headers["Content-Type"] = "application/json"
wrk.body = [[
{
"model": "llama2",
"prompt": "Once upon a time, there existed a little girl who liked to have adventures. She wanted to go to places and meet new people, and have fun. However, her parents were always telling her to stay close to home, to be careful, and to avoid any danger. But the little girl was stubborn, and she wanted to see what was on the other side of the mountain. So she sneaked out of the house one night, leaving a note for her parents, and set off on her journey. As she climbed the mountain, the little girl felt a sense of excitement and wonder. She had never been this far away from home before, and she couldnt wait to see what she would find on the other side. She climbed higher and higher, her lungs burning from the thin air, until she finally reached the top of the mountain. And there, she found a beautiful meadow filled with wildflowers and a sparkling stream. The little girl danced and played in the meadow, feeling free and alive. She knew she had to return home eventually, but for now, she was content to enjoy her adventure. As the sun began to set, the little girl reluctantly made her way back down the mountain, but she knew that she would never forget her adventure and the joy of discovering something new and exciting. And whenever she felt scared or unsure, she would remember the thrill of climbing the mountain and the beauty of the meadow on the other side, and she would know that she could face any challenge that came her way, with courage and determination. She carried the memories of her journey in her heart, a constant reminder of the strength she possessed. The little girl returned home to her worried parents, who had discovered her note and anxiously awaited her arrival. They scolded her for disobeying their instructions and venturing into the unknown. But as they looked into her sparkling eyes and saw the glow on her face, their anger softened. They realized that their little girl had grown, that she had experienced something extraordinary. 
The little girl shared her tales of the mountain and the meadow with her parents, painting vivid pictures with her words. She spoke of the breathtaking view from the mountaintop, where the world seemed to stretch endlessly before her. She described the delicate petals of the wildflowers, vibrant hues that danced in the gentle breeze. And she recounted the soothing melody of the sparkling stream, its waters reflecting the golden rays of the setting sun. Her parents listened intently, captivated by her story. They realized that their daughter had discovered a part of herself on that journey—a spirit of curiosity and a thirst for exploration. They saw that she had learned valuable lessons about independence, resilience, and the beauty that lies beyond ones comfort zone. From that day forward, the little girls parents encouraged her to pursue her dreams and embrace new experiences. They understood that while there were risks in the world, there were also rewards waiting to be discovered. They supported her as she continued to embark on adventures, always reminding her to stay safe but never stifling her spirit. As the years passed, the little girl grew into a remarkable woman, fearlessly exploring the world and making a difference wherever she went. The lessons she had learned on that fateful journey stayed with her, guiding her through challenges and inspiring her to live life to the fullest. And so, the once timid little girl became a symbol of courage and resilience, a reminder to all who knew her that the greatest joys in life often lie just beyond the mountains we fear to climb. Her story spread far and wide, inspiring others to embrace their own journeys and discover the wonders that awaited them. In the end, the little girls adventure became a timeless tale, passed down through generations, reminding us all that sometimes, the greatest rewards come to those who dare to step into the unknown and follow their hearts. 
With each passing day, the little girls story continued to inspire countless individuals, igniting a spark within their souls and encouraging them to embark on their own extraordinary adventures. The tale of her bravery and determination resonated deeply with people from all walks of life, reminding them of the limitless possibilities that awaited them beyond the boundaries of their comfort zones. People marveled at the little girls unwavering spirit and her unwavering belief in the power of dreams. They saw themselves reflected in her journey, finding solace in the knowledge that they too could overcome their fears and pursue their passions. The little girl's story became a beacon of hope, a testament to the human spirit",
"max_tokens": 128,
"temperature": 0.5,
"n": 1,
"use_beam_search": false
}
]]
logfile = io.open("wrk.log", "w");
response = function(status, header, body)
logfile:write("status:" .. status .. "\n" .. body .. "\n-------------------------------------------------\n");
end

View file

@ -0,0 +1,125 @@
#!/bin/bash
usage() {
echo "Usage: $0 [-w --worker <model_worker|vllm_worker>] [--help]"
echo "--help: Print help message."
echo "The following environment variables can be set."
echo "MODEL_PATH (default: empty)."
echo "LOW_BIT_FORMAT (default: sym_int4)"
echo "CONTROLLER_HOST (default: localhost)."
echo "CONTROLLER_PORT (default: 21001)."
echo "WORKER_HOST (default: localhost)."
echo "WORKER_PORT (default: 21002)."
echo "API_HOST (default: localhost)."
echo "API_PORT (default: 8000)."
exit 1
}
# Default values
controller_host="localhost"
controller_port="21001"
worker_host="localhost"
worker_port="21002"
api_host="localhost"
api_port="8000"
model_path=""
mode=""
dispatch_method="shortest_queue" # shortest_queue or lottery
stream_interval=1
worker_type="model_worker"
low_bit_format="sym_int4"
# Parse command-line options
options=$(getopt -o "hw:" --long "help,worker:" -n "$0" -- "$@")
if [ $? != 0 ]; then
usage
fi
eval set -- "$options"
while true; do
case "$1" in
-w|--worker)
worker_type="$2"
[[ $worker_type == "model_worker" || $worker_type == "vllm_worker" ]] || usage
shift 2
;;
-h|--help)
usage
;;
--)
shift
break
;;
*)
usage
;;
esac
done
if [ "$worker_type" == "model_worker" ]; then
worker_type="ipex_llm.serving.fastchat.ipex_llm_worker"
elif [ "$worker_type" == "vllm_worker" ]; then
worker_type="ipex_llm.serving.fastchat.vllm_worker"
fi
if [[ -n $CONTROLLER_HOST ]]; then
controller_host=$CONTROLLER_HOST
fi
if [[ -n $CONTROLLER_PORT ]]; then
controller_port=$CONTROLLER_PORT
fi
if [[ -n $LOW_BIT_FORMAT ]]; then
low_bit_format=$LOW_BIT_FORMAT
fi
if [[ -n $WORKER_HOST ]]; then
worker_host=$WORKER_HOST
fi
if [[ -n $WORKER_PORT ]]; then
worker_port=$WORKER_PORT
fi
if [[ -n $MODEL_PATH ]]; then
model_path=$MODEL_PATH
fi
if [[ -n $API_HOST ]]; then
api_host=$API_HOST
fi
if [[ -n $API_PORT ]]; then
api_port=$API_PORT
fi
if [[ -n $DISPATCH_METHOD ]]; then
dispatch_method=$DISPATCH_METHOD
fi
if [[ -n $STREAM_INTERVAL ]]; then
stream_interval=$STREAM_INTERVAL
fi
controller_address="http://$controller_host:$controller_port"
echo "Controller address: $controller_address"
python3 -m fastchat.serve.controller --host $controller_host --port $controller_port --dispatch-method $dispatch_method &
worker_address="http://$worker_host:$worker_port"
echo "Worker type: $worker_type"
echo "Worker address: $worker_address"
if [ "$worker_type" == "ipex_llm.serving.fastchat.ipex_llm_worker" ]; then
python3 -m "$worker_type" --model-path $model_path --device xpu --low-bit $low_bit_format --host $worker_host --port $worker_port --worker-address $worker_address --controller-address $controller_address --stream-interval $stream_interval &
elif [ "$worker_type" == "ipex_llm.serving.fastchat.vllm_worker" ]; then
python3 -m "$worker_type" --model-path $model_path --device xpu --load-in-low-bit $low_bit_format --host $worker_host --port $worker_port --worker-address $worker_address --controller-address $controller_address --enforce-eager --gpu-memory-utilization 0.85 &
fi
sleep 10
api_address="http://$api_host:$api_port"
echo "OpenAI API address: $api_address"
python3 -m fastchat.serve.openai_api_server --host $api_host --port $api_port --controller-address $controller_address

View file

@ -0,0 +1,19 @@
#!/bin/bash
model="YOUR_MODEL_PATH"
served_model_name="YOUR_MODEL_NAME"
python -m ipex_llm.vllm.entrypoints.openai.api_server \
--served-model-name $served_model_name \
--port 8000 \
--model $model \
--trust-remote-code \
--gpu-memory-utilization 0.75 \
--device xpu \
--dtype float16 \
--enforce-eager \
--load-in-low-bit sym_int4 \
--max-model-len 4096 \
--max-num-batched-tokens 10240 \
--max-num-seqs 12 \
--tensor-parallel-size 1

View file

@ -83,24 +83,93 @@ To fully utilize the continuous batching feature of the `vLLM`, you can send req
For vLLM, you can start the service using the following command:
```bash
#!/bin/bash
model="YOUR_MODEL_PATH"
served_model_name="YOUR_MODEL_NAME"
# You may need to adjust the values of
# --max-model-len, --max-num-batched-tokens, --max-num-seqs
# to get the best performance
python -m ipex_llm.vllm.entrypoints.openai.api_server \
--model /MODEL_PATH/Llama-2-7b-chat-hf/ --port 8000 \
--device xpu --dtype float16 \
--served-model-name $served_model_name \
--port 8000 \
--model $model \
--trust-remote-code \
--gpu-memory-utilization 0.75 \
--device xpu \
--dtype float16 \
--enforce-eager \
--load-in-low-bit sym_int4 \
--max-num-batched-tokens 4096
--max-model-len 4096 \
--max-num-batched-tokens 10240 \
--max-num-seqs 12 \
--tensor-parallel-size 1
```
You can tune the service using these four arguments:
1. `--gpu-memory-utilization`: The fraction of GPU memory to be used for the model executor, which can range from 0 to 1. For example, a value of 0.5 means 50% GPU memory utilization. If unspecified, the default value of 0.9 is used.
2. `--max-model-len`: Model context length. If unspecified, it is derived automatically from the model config.
3. `--max-num-batched-tokens`: Maximum number of batched tokens per iteration.
4. `--max-num-seqs`: Maximum number of sequences per iteration. Default: 256
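How these arguments constrain each other can be sanity-checked before launching the service. The sketch below is a hypothetical helper, not part of `ipex-llm` or `vLLM`. It assumes the checks that upstream vLLM's scheduler configuration is understood to apply (the batched-token budget should be at least the model length and at least the number of sequences); verify these against the vLLM version you actually run.

```python
def check_serving_args(gpu_memory_utilization, max_model_len,
                       max_num_batched_tokens, max_num_seqs):
    """Return a list of problems found; an empty list means none were detected.

    Assumption: the two token-budget checks mirror upstream vLLM behavior
    and may differ in your vLLM build.
    """
    problems = []
    # The fraction of GPU memory must lie in (0, 1].
    if not 0 < gpu_memory_utilization <= 1:
        problems.append("--gpu-memory-utilization must be in (0, 1]")
    # One iteration must be able to hold a full-length prompt.
    if max_num_batched_tokens < max_model_len:
        problems.append("--max-num-batched-tokens is smaller than --max-model-len")
    # Each scheduled sequence needs at least one token per iteration.
    if max_num_batched_tokens < max_num_seqs:
        problems.append("--max-num-batched-tokens is smaller than --max-num-seqs")
    return problems


# Values taken from the start-up command above.
print(check_serving_args(0.75, 4096, 10240, 12))
```

An empty list means the combination passes these basic checks; it says nothing about whether the settings fit in device memory.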
Then you can access the api server as follows:
After the service has started successfully, you can send a test request using curl. Here, `YOUR_MODEL_NAME` should be set to the value of `$served_model_name` in your start-up script.
```bash
curl http://localhost:8000/v1/completions \
-H "Content-Type: application/json" \
-d '{
"model": "/MODEL_PATH/Llama-2-7b-chat-hf/",
curl http://localhost:8000/v1/completions \
-H "Content-Type: application/json" \
-d '{
"model": "YOUR_MODEL_NAME",
"prompt": "San Francisco is a",
"max_tokens": 128,
"temperature": 0
}' &
```
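The same request can be issued from Python. The following is a minimal sketch using only the standard library; the function names `build_completion_request` and `complete` are illustrative, and the base URL and model name are placeholders that must match your start-up script.

```python
import json
import urllib.request


def build_completion_request(base_url, model_name, prompt,
                             max_tokens=128, temperature=0):
    """Build a POST request for the OpenAI-compatible /v1/completions endpoint."""
    payload = {
        "model": model_name,
        "prompt": prompt,
        "max_tokens": max_tokens,
        "temperature": temperature,
    }
    return urllib.request.Request(
        url=f"{base_url}/v1/completions",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def complete(base_url, model_name, prompt, **kwargs):
    """Send the request and return the decoded JSON response."""
    request = build_completion_request(base_url, model_name, prompt, **kwargs)
    with urllib.request.urlopen(request) as response:
        return json.loads(response.read().decode("utf-8"))


# Usage (requires a running service):
#   result = complete("http://localhost:8000", "YOUR_MODEL_NAME", "San Francisco is a")
#   print(result["choices"][0]["text"])
```

Issuing several such calls concurrently, for example from a thread pool, is one way to exercise continuous batching.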
#### Tensor parallel
> Note: We recommend using Docker for tensor parallel deployment.
Tensor parallelism across multiple XPU cards is also supported. To enable it, you need to install `libfabric-dev` in your environment. On Ubuntu, you can install it with:
```bash
sudo apt-get install libfabric-dev
```
To deploy your model across multiple cards, simply set `--tensor-parallel-size` to the number of cards you want to use.
For instance, if you have two Arc A770 cards in your environment, set this value to 2. Some oneCCL environment variables also need to be set, as shown in the following example:
```bash
#!/bin/bash
model="YOUR_MODEL_PATH"
served_model_name="YOUR_MODEL_NAME"
# Environment variables required by oneCCL
export CCL_WORKER_COUNT=2
export FI_PROVIDER=shm
export CCL_ATL_TRANSPORT=ofi
export CCL_ZE_IPC_EXCHANGE=sockets
export CCL_ATL_SHM=1
# You may need to adjust the values of
# --max-model-len, --max-num-batched-tokens, --max-num-seqs
# to get the best performance
python -m ipex_llm.vllm.entrypoints.openai.api_server \
--served-model-name $served_model_name \
--port 8000 \
--model $model \
--trust-remote-code \
--gpu-memory-utilization 0.75 \
--device xpu \
--dtype float16 \
--enforce-eager \
--load-in-low-bit sym_int4 \
--max-model-len 4096 \
--max-num-batched-tokens 10240 \
--max-num-seqs 12 \
--tensor-parallel-size 2
```
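If the service is launched from Python rather than from a shell script (for example, in a test harness), the same setup can be sketched as follows. `build_tensor_parallel_launch` is a hypothetical helper; the environment variables and flags are copied from the script above, and only the tensor-parallel size is parameterized.

```python
import os

# oneCCL-related environment variables, copied from the shell example above.
CCL_ENV = {
    "CCL_WORKER_COUNT": "2",
    "FI_PROVIDER": "shm",
    "CCL_ATL_TRANSPORT": "ofi",
    "CCL_ZE_IPC_EXCHANGE": "sockets",
    "CCL_ATL_SHM": "1",
}


def build_tensor_parallel_launch(model_path, served_model_name,
                                 tensor_parallel_size=2):
    """Return (argv, env) for starting the API server; nothing is launched here."""
    env = dict(os.environ)
    env.update(CCL_ENV)
    argv = [
        "python", "-m", "ipex_llm.vllm.entrypoints.openai.api_server",
        "--served-model-name", served_model_name,
        "--port", "8000",
        "--model", model_path,
        "--trust-remote-code",
        "--gpu-memory-utilization", "0.75",
        "--device", "xpu",
        "--dtype", "float16",
        "--enforce-eager",
        "--load-in-low-bit", "sym_int4",
        "--max-model-len", "4096",
        "--max-num-batched-tokens", "10240",
        "--max-num-seqs", "12",
        "--tensor-parallel-size", str(tensor_parallel_size),
    ]
    return argv, env


# Usage (starts the server; requires the environment described above):
#   import subprocess
#   argv, env = build_tensor_parallel_launch("YOUR_MODEL_PATH", "YOUR_MODEL_NAME")
#   subprocess.run(argv, env=env, check=True)
```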

View file

@ -49,7 +49,8 @@ llm = LLM(model="YOUR_MODEL",
device="xpu",
dtype="float16",
enforce_eager=True,
load_in_low_bit="sym_int4")
load_in_low_bit="sym_int4",
tensor_parallel_size=1)
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)

View file

@ -291,16 +291,6 @@ if __name__ == "__main__":
help="Trust remote code (e.g., from HuggingFace) when "
"downloading the model and tokenizer.",
)
parser.add_argument(
"--gpu_memory_utilization",
type=float,
default=0.9,
help="The ratio (between 0 and 1) of GPU memory to"
"reserve for the model weights, activations, and KV cache. Higher"
"values will increase the KV cache size and thus improve the model's"
"throughput. However, if the value is too high, it may cause out-of-"
"memory (OOM) errors.",
)
parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args()

View file

@ -135,6 +135,13 @@ def is_linear_module(module):
in_features = module.input_size_per_partition
elif isinstance(module, ColumnParallelLinear) and tp_size >= 2:
out_features = module.output_size_per_partition
else:
# Also check for Linear module
if isinstance(module, nn.Linear) or is_awq:
in_features = module.in_features
out_features = module.out_features
mp_group = None
result = True
else:
result = False
elif is_gptq_linear(module):