From 46a1cbfa64b8015bca6b8a4bd0c0242b95f7c32b Mon Sep 17 00:00:00 2001 From: "Chu,Youcheng" <70999398+cranechu0131@users.noreply.github.com> Date: Mon, 19 Aug 2024 10:00:44 +0800 Subject: [PATCH] feat: add mixed_precision argument on ppl longbench evaluation (#11837) * feat: add mixed_precision argument on ppl longbench evaluation * fix: delete two spaces --------- Co-authored-by: Jinhe Tang --- python/llm/dev/benchmark/perplexity/run_longbench.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/llm/dev/benchmark/perplexity/run_longbench.py b/python/llm/dev/benchmark/perplexity/run_longbench.py index 92b4999a..c250d35f 100644 --- a/python/llm/dev/benchmark/perplexity/run_longbench.py +++ b/python/llm/dev/benchmark/perplexity/run_longbench.py @@ -37,6 +37,7 @@ def get_arguments(): parser.add_argument("--dataset_path", required=False, type=str, default=None) parser.add_argument("--language", required=False, type=str, default="en", choices=['en', 'zh', 'all']) parser.add_argument("--precisions", required=False, type=str, default=None, nargs='+') + parser.add_argument("--mixed_precision", action="store_true") parser.add_argument("--device", type=str, default="xpu") parser.add_argument("--output_path", default=None) return parser.parse_args() @@ -95,11 +96,11 @@ def main(): log_dir = f"{output_path}/{model_name}/{args.device}/{precision}/{args.language}" os.makedirs(log_dir, exist_ok=True) results = {} - ppl_evaluator = BigDLPPL(model_path=args.model_path, device=args.device, **model_kwargs) + ppl_evaluator = BigDLPPL(model_path=args.model_path, device=args.device, mixed_precision=args.mixed_precision, **model_kwargs) ppl = ppl_evaluator.perplexity_hf(encoded_texts) summary[precision] = ppl results['results'] = ppl - results['config'] = {"model": model_name, "precision": precision, "device": args.device, "seq_len": args.seq_len, "language": args.language} + results['config'] = {"model": model_name, "precision": precision, "mixed_precision": args.mixed_precision, "device": args.device, "seq_len": args.seq_len, "language": args.language } dumped = json.dumps(results, indent=2) print(dumped)