fix benchmark script (#11243)
parent 8aabb5bac7
commit eeffeeb2e2
1 changed file with 32 additions and 3 deletions
@@ -76,6 +76,7 @@ def run_vllm(
     enable_prefix_caching: bool,
     gpu_memory_utilization: float = 0.9,
     load_in_low_bit: str = "sym_int4",
+    max_num_batched_tokens: int = 5000,
 ) -> float:
     from vllm import SamplingParams
     from ipex_llm.vllm.xpu.engine import IPEXLLMClass as LLM
@@ -92,9 +93,30 @@ def run_vllm(
               kv_cache_dtype=kv_cache_dtype,
               device=device,
               enable_prefix_caching=enable_prefix_caching,
-              load_in_low_bit=load_in_low_bit)
+              load_in_low_bit=load_in_low_bit,
+              max_num_batched_tokens=max_num_batched_tokens,)
+
 
     # Add the requests to the engine.
+    warm_prompt = "hi " * (1024 - 1)
+    warm_requests = [(warm_prompt, 1024, 1024)
+                     for _ in range(8)]
+    for prompt, _, output_len in warm_requests:
+        sampling_params = SamplingParams(
+            n=n,
+            temperature=0.0 if use_beam_search else 1.0,
+            top_p=1.0,
+            use_beam_search=use_beam_search,
+            ignore_eos=True,
+            max_tokens=output_len,
+        )
+        llm._add_request(
+            prompt=prompt,
+            prompt_token_ids=None,
+            sampling_params=sampling_params,
+        )
+    llm._run_engine(use_tqdm=True)
+
     for prompt, _, output_len in requests:
         sampling_params = SamplingParams(
             n=n,
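Note: the warm-up block added above issues eight roughly-1024-token prompts, each generating 1024 tokens, and drains them with llm._run_engine before the measured loop, so one-time costs (kernel compilation, cache and memory-pool allocation) do not inflate the timed run. A minimal sketch of the same warm-up-then-measure pattern, written against a generic generate(prompt, max_tokens) callable rather than the vLLM internals used in the diff; the names below are illustrative, not part of the commit:

    import time

    def benchmark(generate, requests, warmup_prompt="hi " * 1023, warmup_rounds=8):
        # Warm-up: issue a few fixed-size requests first so one-time setup
        # costs happen outside the timed section.
        for _ in range(warmup_rounds):
            generate(warmup_prompt, 1024)

        # Measured run: only this loop contributes to the reported throughput.
        start = time.perf_counter()
        total_tokens = 0
        for prompt, _, output_len in requests:
            generate(prompt, output_len)
            total_tokens += output_len
        elapsed = time.perf_counter() - start
        return total_tokens / elapsed  # generated tokens per second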
@@ -216,7 +238,7 @@ def main(args: argparse.Namespace):
             args.tensor_parallel_size, args.seed, args.n, args.use_beam_search,
             args.trust_remote_code, args.dtype, args.max_model_len,
             args.enforce_eager, args.kv_cache_dtype, args.device,
-            args.enable_prefix_caching, args.gpu_memory_utilization, args.load_in_low_bit)
+            args.enable_prefix_caching, args.gpu_memory_utilization, args.load_in_low_bit, args.max_num_batched_tokens)
     elif args.backend == "hf":
         assert args.tensor_parallel_size == 1
         elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
@@ -323,6 +345,12 @@ if __name__ == "__main__":
                         choices=["sym_int4", "fp8", "fp16"],
                         default="sym_int4",
                         help="Low-bit format quantization with IPEX-LLM")
+
+    parser.add_argument('--max-num-batched-tokens',
+                        type=int,
+                        default=5000,
+                        help='maximum number of batched tokens per iteration')
+
     args = parser.parse_args()
     if args.tokenizer is None:
         args.tokenizer = args.model
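Note: together with the call-site change at line 238 above, the new option follows the usual argparse flow: declare --max-num-batched-tokens, then forward args.max_num_batched_tokens into run_vllm. A self-contained sketch of that flow, with a stub standing in for run_vllm; the stub, its print, and the --num-prompts flag here are illustrative, not taken from the commit:

    import argparse

    def run_vllm_stub(num_prompts: int, max_num_batched_tokens: int = 5000) -> float:
        # Stand-in for run_vllm: just report the forwarded scheduler setting.
        print(f"would schedule up to {max_num_batched_tokens} batched tokens "
              f"per iteration for {num_prompts} prompts")
        return 0.0

    if __name__ == "__main__":
        parser = argparse.ArgumentParser()
        parser.add_argument('--num-prompts', type=int, default=1000)
        parser.add_argument('--max-num-batched-tokens',
                            type=int,
                            default=5000,
                            help='maximum number of batched tokens per iteration')
        args = parser.parse_args()
        run_vllm_stub(args.num_prompts, args.max_num_batched_tokens)

For example, running the sketch with --max-num-batched-tokens 4096 prints the larger token budget.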
@@ -354,4 +382,5 @@ if __name__ == "__main__":
     if args.tokenizer != args.model:
         raise ValueError("Tokenizer must be the same as the model for MII "
                          "backend.")
-    main(args)
+    main(args)
+