fix benchmark script (#11243)

Guancheng Fu 2024-06-06 17:44:19 +08:00 committed by GitHub
parent 8aabb5bac7
commit eeffeeb2e2

@@ -76,6 +76,7 @@ def run_vllm(
     enable_prefix_caching: bool,
     gpu_memory_utilization: float = 0.9,
     load_in_low_bit: str = "sym_int4",
+    max_num_batched_tokens: int = 5000,
 ) -> float:
     from vllm import SamplingParams
     from ipex_llm.vllm.xpu.engine import IPEXLLMClass as LLM
@@ -92,9 +93,30 @@ def run_vllm(
               kv_cache_dtype=kv_cache_dtype,
               device=device,
               enable_prefix_caching=enable_prefix_caching,
-              load_in_low_bit=load_in_low_bit)
+              load_in_low_bit=load_in_low_bit,
+              max_num_batched_tokens=max_num_batched_tokens,)

     # Add the requests to the engine.
+    warm_prompt = "hi " * (1024 - 1)
+    warm_requests = [(warm_prompt, 1024, 1024)
+                     for _ in range(8)]
+    for prompt, _, output_len in warm_requests:
+        sampling_params = SamplingParams(
+            n=n,
+            temperature=0.0 if use_beam_search else 1.0,
+            top_p=1.0,
+            use_beam_search=use_beam_search,
+            ignore_eos=True,
+            max_tokens=output_len,
+        )
+        llm._add_request(
+            prompt=prompt,
+            prompt_token_ids=None,
+            sampling_params=sampling_params,
+        )
+    llm._run_engine(use_tqdm=True)
     for prompt, _, output_len in requests:
         sampling_params = SamplingParams(
             n=n,
@@ -216,7 +238,7 @@ def main(args: argparse.Namespace):
             args.tensor_parallel_size, args.seed, args.n, args.use_beam_search,
             args.trust_remote_code, args.dtype, args.max_model_len,
             args.enforce_eager, args.kv_cache_dtype, args.device,
-            args.enable_prefix_caching, args.gpu_memory_utilization, args.load_in_low_bit)
+            args.enable_prefix_caching, args.gpu_memory_utilization, args.load_in_low_bit, args.max_num_batched_tokens)
     elif args.backend == "hf":
         assert args.tensor_parallel_size == 1
         elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
@@ -323,6 +345,12 @@ if __name__ == "__main__":
                         choices=["sym_int4", "fp8", "fp16"],
                         default="sym_int4",
                         help="Low-bit format quantization with IPEX-LLM")
+    parser.add_argument('--max-num-batched-tokens',
+                        type=int,
+                        default=5000,
+                        help='maximum number of batched tokens per iteration')
     args = parser.parse_args()
     if args.tokenizer is None:
         args.tokenizer = args.model
@@ -354,4 +382,5 @@ if __name__ == "__main__":
         if args.tokenizer != args.model:
             raise ValueError("Tokenizer must be the same as the model for MII "
                              "backend.")
     main(args)
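
For reference, the two changes above (the new max_num_batched_tokens knob and the warm-up pass that runs before the timed requests) can also be exercised through the engine's public API. The snippet below is a minimal sketch, not part of this commit: the model path and device value are placeholders, and it uses llm.generate() rather than the script's private _add_request()/_run_engine() helpers.

    # Minimal sketch under stated assumptions: placeholder model path and
    # device; shows the new max_num_batched_tokens setting plus a warm-up pass.
    from vllm import SamplingParams
    from ipex_llm.vllm.xpu.engine import IPEXLLMClass as LLM

    llm = LLM(model="meta-llama/Llama-2-7b-chat-hf",  # placeholder model
              device="xpu",                           # placeholder device
              load_in_low_bit="sym_int4",
              max_num_batched_tokens=5000)            # knob added by this commit

    params = SamplingParams(temperature=0.0, ignore_eos=True, max_tokens=128)

    # Warm-up: a few long prompts so one-time setup cost is paid before timing,
    # mirroring the "hi " * (1024 - 1) warm-up requests added to the script.
    llm.generate(["hi " * 1023] * 8, params)

    # Timed requests would follow here; each output carries the generated text.
    outputs = llm.generate(["Summarize the benefits of KV-cache quantization."], params)
    print(outputs[0].outputs[0].text)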