From 99255fe36e3835d8f50f87438693af64f6d8dd74 Mon Sep 17 00:00:00 2001 From: ZehuaCao <47251317+Romanticoseu@users.noreply.github.com> Date: Mon, 13 May 2024 13:57:19 +0800 Subject: [PATCH] fix ppl (#10996) --- python/llm/dev/benchmark/perplexity/ppl.py | 5 +++-- python/llm/dev/benchmark/perplexity/run.py | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/python/llm/dev/benchmark/perplexity/ppl.py b/python/llm/dev/benchmark/perplexity/ppl.py index 1b71d9fe..d340297e 100644 --- a/python/llm/dev/benchmark/perplexity/ppl.py +++ b/python/llm/dev/benchmark/perplexity/ppl.py @@ -74,8 +74,9 @@ class BigDLPPL: ppl_mean = np.mean(np.array(ppls)[~np.isnan(np.array(ppls))]) finally: - torch.xpu.synchronize() - torch.xpu.empty_cache() + if self.device == "xpu": + torch.xpu.synchronize() + torch.xpu.empty_cache() del self.model gc.collect() diff --git a/python/llm/dev/benchmark/perplexity/run.py b/python/llm/dev/benchmark/perplexity/run.py index d548e984..881d82e0 100644 --- a/python/llm/dev/benchmark/perplexity/run.py +++ b/python/llm/dev/benchmark/perplexity/run.py @@ -100,7 +100,7 @@ def main(): dumped = json.dumps(results, indent=2) print(dumped) - if args.output_path: + if output_path: with open(f"{log_dir}/result.json", "w") as f: f.write(dumped)