ipex-llm/docker/llm/serving/xpu/docker/benchmark_vllm_throughput.py
Guancheng Fu 69c8d36f16
Switching from vLLM v0.3.3 to vLLM 0.5.4 (#12042)
* Enable single card sync engine

* enable ipex-llm optimizations for vllm

* enable optimizations for lm_head

* Fix chatglm multi-reference problem

* Remove duplicate layer

* LLM: Update vLLM to v0.5.4 (#11746)

* Enable single card sync engine

* enable ipex-llm optimizations for vllm

* enable optimizations for lm_head

* Fix chatglm multi-reference problem

* update 0.5.4 api_server

* add dockerfile

* fix

* fix

* refine

* fix

---------

Co-authored-by: gc-fu <guancheng.fu@intel.com>

* Add vllm-0.5.4 Dockerfile (#11838)

* Update BIGDL_LLM_SDP_IGNORE_MASK in start-vllm-service.sh (#11957)

* Fix vLLM not convert issues (#11817) (#11918)

* Fix not convert issues

* refine

Co-authored-by: Guancheng Fu <110874468+gc-fu@users.noreply.github.com>

* Fix glm4-9b-chat nan error on vllm 0.5.4 (#11969)

* init

* update mlp forward

* fix minicpm error in vllm 0.5.4

* fix dependabot alerts (#12008)

* Update 0.5.4 dockerfile (#12021)

* Add vllm awq loading logic (#11987)

* [ADD] Add vllm awq loading logic

* [FIX] fix the module.linear_method path

* [FIX] fix quant_config path error

* Enable Qwen padding mlp to 256 to support batch_forward (#12030)

* Enable padding mlp

* padding to 256

* update style

* Install 27191 runtime in 0.5.4 docker image (#12040)

* fix rebase error

* fix rebase error

* vLLM: format for 0.5.4 rebase (#12043)

* format

* Update model_convert.py

* Fix serving docker related modifications (#12046)

* Fix undesired modifications (#12048)

* fix

* Refine offline_inference arguments

---------

Co-authored-by: Xiangyu Tian <109123695+xiangyuT@users.noreply.github.com>
Co-authored-by: Jun Wang <thoughts.times@gmail.com>
Co-authored-by: Wang, Jian4 <61138589+hzjane@users.noreply.github.com>
Co-authored-by: liu-shaojun <johnssalyn@outlook.com>
Co-authored-by: Shaojun Liu <61072813+liu-shaojun@users.noreply.github.com>
2024-09-10 15:37:43 +08:00

392 lines
14 KiB
Python

"""Benchmark offline inference throughput."""
import argparse
import json
import random
import time
from typing import List, Optional, Tuple
import torch
from transformers import (AutoModelForCausalLM, AutoTokenizer,
PreTrainedTokenizerBase)
from tqdm import tqdm
def sample_requests(
dataset_path: str,
num_requests: int,
tokenizer: PreTrainedTokenizerBase,
fixed_output_len: Optional[int],
) -> List[Tuple[str, int, int]]:
if fixed_output_len is not None and fixed_output_len < 4:
raise ValueError("output_len too small")
# Load the dataset.
with open(dataset_path) as f:
dataset = json.load(f)
# Filter out the conversations with less than 2 turns.
dataset = [data for data in dataset if len(data["conversations"]) >= 2]
# Only keep the first two turns of each conversation.
dataset = [(data["conversations"][0]["value"],
data["conversations"][1]["value"]) for data in dataset]
# Tokenize the prompts and completions.
prompts = [prompt for prompt, _ in dataset]
prompt_token_ids = tokenizer(prompts).input_ids
completions = [completion for _, completion in dataset]
completion_token_ids = tokenizer(completions).input_ids
tokenized_dataset = []
for i in range(len(dataset)):
output_len = len(completion_token_ids[i])
if fixed_output_len is not None:
output_len = fixed_output_len
tokenized_dataset.append((prompts[i], prompt_token_ids[i], output_len))
# Filter out too long sequences.
filtered_dataset: List[Tuple[str, int, int]] = []
for prompt, prompt_token_ids, output_len in tokenized_dataset:
prompt_len = len(prompt_token_ids)
if prompt_len < 4 or output_len < 4:
# Prune too short sequences.
continue
if prompt_len > 1024 or prompt_len + output_len > 2048:
# Prune too long sequences.
continue
filtered_dataset.append((prompt, prompt_len, output_len))
# Sample the requests.
sampled_requests = random.sample(filtered_dataset, num_requests)
return sampled_requests
def run_vllm(
requests: List[Tuple[str, int, int]],
model: str,
tokenizer: str,
quantization: Optional[str],
tensor_parallel_size: int,
seed: int,
n: int,
use_beam_search: bool,
trust_remote_code: bool,
dtype: str,
max_model_len: Optional[int],
enforce_eager: bool,
kv_cache_dtype: str,
device: str,
enable_prefix_caching: bool,
gpu_memory_utilization: float = 0.9,
load_in_low_bit: str = "sym_int4",
max_num_batched_tokens: int = 5000,
max_num_seqs: int = 256,
) -> float:
from vllm import SamplingParams
from ipex_llm.vllm.xpu.engine import IPEXLLMClass as LLM
llm = LLM(model=model,
tokenizer=tokenizer,
quantization=quantization,
tensor_parallel_size=tensor_parallel_size,
seed=seed,
trust_remote_code=trust_remote_code,
dtype=dtype,
max_model_len=max_model_len,
gpu_memory_utilization=gpu_memory_utilization,
enforce_eager=enforce_eager,
kv_cache_dtype=kv_cache_dtype,
device=device,
enable_prefix_caching=enable_prefix_caching,
load_in_low_bit=load_in_low_bit,
max_num_batched_tokens=max_num_batched_tokens,
max_num_seqs=max_num_seqs,)
# Add the requests to the engine.
warm_prompt = "hi " * (1024 - 1)
warm_requests = [(warm_prompt, 1024, 1024)
for _ in range(8)]
prompts: List[str] = []
sampling_params: List[SamplingParams] = []
for prompt, _, output_len in warm_requests:
prompts.append(prompt)
sampling_params.append(
SamplingParams(
n=n,
temperature=0.0 if use_beam_search else 1.0,
top_p=1.0,
use_beam_search=use_beam_search,
ignore_eos=True,
max_tokens=output_len,
))
llm.generate(prompts, sampling_params, use_tqdm=True)
prompts: List[str] = []
sampling_params: List[SamplingParams] = []
for prompt, _, output_len in requests:
prompts.append(prompt)
sampling_params.append(
SamplingParams(
n=n,
temperature=0.0 if use_beam_search else 1.0,
top_p=1.0,
use_beam_search=use_beam_search,
ignore_eos=True,
max_tokens=output_len,
))
start = time.perf_counter()
llm.generate(prompts, sampling_params, use_tqdm=True)
end = time.perf_counter()
return end - start
def run_hf(
requests: List[Tuple[str, int, int]],
model: str,
tokenizer: PreTrainedTokenizerBase,
n: int,
use_beam_search: bool,
max_batch_size: int,
trust_remote_code: bool,
) -> float:
assert not use_beam_search
llm = AutoModelForCausalLM.from_pretrained(
model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
if llm.config.model_type == "llama":
# To enable padding in the HF backend.
tokenizer.pad_token = tokenizer.eos_token
llm = llm.cuda()
pbar = tqdm(total=len(requests))
start = time.perf_counter()
batch: List[str] = []
max_prompt_len = 0
max_output_len = 0
for i in range(len(requests)):
prompt, prompt_len, output_len = requests[i]
# Add the prompt to the batch.
batch.append(prompt)
max_prompt_len = max(max_prompt_len, prompt_len)
max_output_len = max(max_output_len, output_len)
if len(batch) < max_batch_size and i != len(requests) - 1:
# Check if we can add more requests to the batch.
_, next_prompt_len, next_output_len = requests[i + 1]
if (max(max_prompt_len, next_prompt_len) +
max(max_output_len, next_output_len)) <= 2048:
# We can add more requests to the batch.
continue
# Generate the sequences.
input_ids = tokenizer(batch, return_tensors="pt",
padding=True).input_ids
llm_outputs = llm.generate(
input_ids=input_ids.cuda(),
do_sample=not use_beam_search,
num_return_sequences=n,
temperature=1.0,
top_p=1.0,
use_cache=True,
max_new_tokens=max_output_len,
)
# Include the decoding time.
tokenizer.batch_decode(llm_outputs, skip_special_tokens=True)
pbar.update(len(batch))
# Clear the batch.
batch = []
max_prompt_len = 0
max_output_len = 0
end = time.perf_counter()
return end - start
def run_mii(
requests: List[Tuple[str, int, int]],
model: str,
tensor_parallel_size: int,
output_len: int,
) -> float:
from mii import pipeline
llm = pipeline(model, tensor_parallel=tensor_parallel_size)
prompts = [prompt for prompt, _, _ in requests]
start = time.perf_counter()
llm(prompts, max_new_tokens=output_len)
end = time.perf_counter()
return end - start
def main(args: argparse.Namespace):
print(args)
random.seed(args.seed)
# Sample the requests.
tokenizer = AutoTokenizer.from_pretrained(
args.tokenizer, trust_remote_code=args.trust_remote_code)
if args.dataset is None:
# Synthesize a prompt with the given input length.
prompt = "hi" * (args.input_len - 1)
requests = [(prompt, args.input_len, args.output_len)
for _ in range(args.num_prompts)]
else:
requests = sample_requests(args.dataset, args.num_prompts, tokenizer,
args.output_len)
if args.backend == "vllm":
elapsed_time = run_vllm(
requests, args.model, args.tokenizer, args.quantization,
args.tensor_parallel_size, args.seed, args.n, args.use_beam_search,
args.trust_remote_code, args.dtype, args.max_model_len,
args.enforce_eager, args.kv_cache_dtype, args.device,
args.enable_prefix_caching, args.gpu_memory_utilization, args.load_in_low_bit,
args.max_num_batched_tokens,args.max_num_seqs)
elif args.backend == "hf":
assert args.tensor_parallel_size == 1
elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
args.use_beam_search, args.hf_max_batch_size,
args.trust_remote_code)
elif args.backend == "mii":
elapsed_time = run_mii(requests, args.model, args.tensor_parallel_size,
args.output_len)
else:
raise ValueError(f"Unknown backend: {args.backend}")
total_num_tokens = sum(prompt_len + output_len
for _, prompt_len, output_len in requests)
print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
f"{total_num_tokens / elapsed_time:.2f} tokens/s")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Benchmark the throughput.")
parser.add_argument("--backend",
type=str,
choices=["vllm", "hf", "mii"],
default="vllm")
parser.add_argument("--dataset",
type=str,
default=None,
help="Path to the dataset.")
parser.add_argument("--input-len",
type=int,
default=None,
help="Input prompt length for each request")
parser.add_argument("--output-len",
type=int,
default=None,
help="Output length for each request. Overrides the "
"output length from the dataset.")
parser.add_argument("--model", type=str, default="facebook/opt-125m")
parser.add_argument("--tokenizer", type=str, default=None)
parser.add_argument('--quantization',
'-q',
choices=['awq', 'gptq', 'squeezellm', None],
default=None)
parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
parser.add_argument("--n",
type=int,
default=1,
help="Number of generated sequences per prompt.")
parser.add_argument("--use-beam-search", action="store_true")
parser.add_argument("--num-prompts",
type=int,
default=1000,
help="Number of prompts to process.")
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--hf-max-batch-size",
type=int,
default=None,
help="Maximum batch size for HF backend.")
parser.add_argument('--trust-remote-code',
action='store_true',
help='trust remote code from huggingface')
parser.add_argument(
'--max-model-len',
type=int,
default=None,
help='Maximum length of a sequence (including prompt and output). '
'If None, will be derived from the model.')
parser.add_argument(
'--dtype',
type=str,
default='auto',
choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'],
help='data type for model weights and activations. '
'The "auto" option will use FP16 precision '
'for FP32 and FP16 models, and BF16 precision '
'for BF16 models.')
parser.add_argument('--gpu-memory-utilization',
type=float,
default=0.9,
help='the fraction of GPU memory to be used for '
'the model executor, which can range from 0 to 1.'
'If unspecified, will use the default value of 0.9.')
parser.add_argument("--enforce-eager",
action="store_true",
help="enforce eager execution")
parser.add_argument(
"--kv-cache-dtype",
type=str,
choices=["auto", "fp8_e5m2"],
default="auto",
help=
'Data type for kv cache storage. If "auto", will use model data type.')
parser.add_argument(
"--device",
type=str,
default="cuda",
choices=["cuda", "xpu"],
help='device type for vLLM execution, supporting CUDA only currently.')
parser.add_argument(
"--enable-prefix-caching",
action='store_true',
help="enable automatic prefix caching for vLLM backend.")
parser.add_argument(
"--load-in-low-bit",
type=str,
choices=["sym_int4", "fp8", "fp8_e4m3", "fp16", "fp6"],
default="sym_int4",
help="Low-bit format quantization with IPEX-LLM")
parser.add_argument('--max-num-batched-tokens',
type=int,
default=4096,
help='maximum number of batched tokens per iteration')
parser.add_argument('--max-num-seqs',
type=int,
default=256,
help='Maximum number of sequences per iteration.')
args = parser.parse_args()
if args.tokenizer is None:
args.tokenizer = args.model
if args.dataset is None:
assert args.input_len is not None
assert args.output_len is not None
else:
assert args.input_len is None
if args.backend == "vllm":
if args.hf_max_batch_size is not None:
raise ValueError("HF max batch size is only for HF backend.")
elif args.backend == "hf":
if args.hf_max_batch_size is None:
raise ValueError("HF max batch size is required for HF backend.")
if args.quantization is not None:
raise ValueError("Quantization is only for vLLM backend.")
elif args.backend == "mii":
if args.dtype != "auto":
raise ValueError("dtype must be auto for MII backend.")
if args.n != 1:
raise ValueError("n must be 1 for MII backend.")
if args.use_beam_search:
raise ValueError("Beam search is not supported for MII backend.")
if args.quantization is not None:
raise ValueError("Quantization is only for vLLM backend.")
if args.hf_max_batch_size is not None:
raise ValueError("HF max batch size is only for HF backend.")
if args.tokenizer != args.model:
raise ValueError("Tokenizer must be the same as the model for MII "
"backend.")
main(args)