This commit is contained in:
Xin Qiu 2024-05-08 14:28:05 +08:00 committed by GitHub
parent 5973d6c753
commit dfa3147278
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -33,6 +33,7 @@ import sys
sys.path.append(benchmark_util_path)
from benchmark_util import BenchmarkWrapper
from ipex_llm.utils.common.log4Error import invalidInputError
from ipex_llm.utils.common import invalidInputError
LLAMA_IDS = ['meta-llama/Llama-2-7b-chat-hf','meta-llama/Llama-2-13b-chat-hf',
'meta-llama/Llama-2-70b-chat-hf','decapoda-research/llama-7b-hf',
@ -110,6 +111,8 @@ def run_model(repo_id, test_api, in_out_pairs, local_model_hub=None, warm_up=1,
result = run_speculative_gpu(repo_id, local_model_hub, in_out_pairs, warm_up, num_trials, num_beams, batch_size)
elif test_api == 'pipeline_parallel_gpu':
result = run_pipeline_parallel_gpu(repo_id, local_model_hub, in_out_pairs, warm_up, num_trials, num_beams, low_bit, batch_size, cpu_embedding, fp16=use_fp16_torch_dtype, n_gpu=n_gpu)
else:
invalidInputError(False, "Unknown test_api " + test_api + ", please check your config.yaml.")
for in_out_pair in in_out_pairs:
if result and result[in_out_pair]: