feat: add mixed_precision argument on ppl longbench evaluation (#11837)

* feat: add mixed_precision argument on ppl longbench evaluation

* fix: delete two spaces

---------

Co-authored-by: Jinhe Tang <jin.tang1337@gmail.com>
This commit is contained in:
Chu,Youcheng 2024-08-19 10:00:44 +08:00 committed by GitHub
parent 580c94d0e2
commit 46a1cbfa64
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -37,6 +37,7 @@ def get_arguments():
parser.add_argument("--dataset_path", required=False, type=str, default=None) parser.add_argument("--dataset_path", required=False, type=str, default=None)
parser.add_argument("--language", required=False, type=str, default="en", choices=['en', 'zh', 'all']) parser.add_argument("--language", required=False, type=str, default="en", choices=['en', 'zh', 'all'])
parser.add_argument("--precisions", required=False, type=str, default=None, nargs='+') parser.add_argument("--precisions", required=False, type=str, default=None, nargs='+')
parser.add_argument("--mixed_precision", action="store_true")
parser.add_argument("--device", type=str, default="xpu") parser.add_argument("--device", type=str, default="xpu")
parser.add_argument("--output_path", default=None) parser.add_argument("--output_path", default=None)
return parser.parse_args() return parser.parse_args()
@ -95,11 +96,11 @@ def main():
log_dir = f"{output_path}/{model_name}/{args.device}/{precision}/{args.language}" log_dir = f"{output_path}/{model_name}/{args.device}/{precision}/{args.language}"
os.makedirs(log_dir, exist_ok=True) os.makedirs(log_dir, exist_ok=True)
results = {} results = {}
ppl_evaluator = BigDLPPL(model_path=args.model_path, device=args.device, **model_kwargs) ppl_evaluator = BigDLPPL(model_path=args.model_path, device=args.device, mixed_precision=args.mixed_precision, **model_kwargs)
ppl = ppl_evaluator.perplexity_hf(encoded_texts) ppl = ppl_evaluator.perplexity_hf(encoded_texts)
summary[precision] = ppl summary[precision] = ppl
results['results'] = ppl results['results'] = ppl
results['config'] = {"model": model_name, "precision": precision, "device": args.device, "seq_len": args.seq_len, "language": args.language} results['config'] = {"model": model_name, "precision": precision, "mixed_precision": args.mixed_precision, "device": args.device, "seq_len": args.seq_len, "language": args.language }
dumped = json.dumps(results, indent=2) dumped = json.dumps(results, indent=2)
print(dumped) print(dumped)