From 226f398c2ab2dc19813f015b4c4bd3ae24775b87 Mon Sep 17 00:00:00 2001
From: Ovo233 <76120304+Mingyu-Wei@users.noreply.github.com>
Date: Tue, 30 Jan 2024 16:26:21 +0800
Subject: [PATCH] fix ppl test errors (#10036)

---
 python/llm/dev/benchmark/perplexity/README.md | 4 ++--
 python/llm/dev/benchmark/perplexity/ppl.py    | 6 ++++++
 python/llm/dev/benchmark/perplexity/run.py    | 2 +-
 3 files changed, 9 insertions(+), 3 deletions(-)

diff --git a/python/llm/dev/benchmark/perplexity/README.md b/python/llm/dev/benchmark/perplexity/README.md
index 359feeb4..3919e543 100644
--- a/python/llm/dev/benchmark/perplexity/README.md
+++ b/python/llm/dev/benchmark/perplexity/README.md
@@ -3,9 +3,9 @@ Perplexity (PPL) is one of the most common metrics for evaluating language model
 
 ## HOW TO RUN
 ```python
-python run.py --model_path <path/to/model> --low_bit sym_int4 fp4 mixed_fp4 sym_int8 fp8_e5m2 fp8_e4m3 mixed_fp8 --device xpu --dataset path=<dataset_path>,name=<dataset_name>
+python run.py --model_path <path/to/model> --precisions sym_int4 fp4 mixed_fp4 sym_int8 fp8_e5m2 fp8_e4m3 mixed_fp8 --device xpu --dataset path=<dataset_path>,name=<dataset_name>
 ```
 A more specific example to run perplexity on Llama2-7B and wikitext:
 ```python
-python run.py --model_path meta-llama/Llama-2-7b-chat-hf --low_bit float16 sym_int4 --device xpu --dataset path=wikitext,name=wikitext-2-raw-v1
+python run.py --model_path meta-llama/Llama-2-7b-chat-hf --precisions float16 sym_int4 --device xpu --dataset path=wikitext,name=wikitext-2-raw-v1
 ```
\ No newline at end of file
diff --git a/python/llm/dev/benchmark/perplexity/ppl.py b/python/llm/dev/benchmark/perplexity/ppl.py
index 161ef634..3fb6be40 100644
--- a/python/llm/dev/benchmark/perplexity/ppl.py
+++ b/python/llm/dev/benchmark/perplexity/ppl.py
@@ -19,6 +19,7 @@ import numpy as np
 import torch
 from tqdm import tqdm
 from transformers import AutoTokenizer
+import gc
 
 from bigdl.llm.transformers import AutoModelForCausalLM
 
@@ -81,4 +82,9 @@ class BigDLPPL:
             self.ppl_evaluator(data.numpy()[0, seq_len//2:, :], input_ids_chunks.numpy()[0, seq_len//2:])
             progress_bar.set_description(f"{self.ppl_evaluator}")
 
+        torch.xpu.synchronize()
+        torch.xpu.empty_cache()
+        del self.model
+        gc.collect()
+
         return self.ppl_evaluator.result()
diff --git a/python/llm/dev/benchmark/perplexity/run.py b/python/llm/dev/benchmark/perplexity/run.py
index 320ba85f..c3ad9ff0 100644
--- a/python/llm/dev/benchmark/perplexity/run.py
+++ b/python/llm/dev/benchmark/perplexity/run.py
@@ -44,7 +44,7 @@ def main():
     additional_model_kwargs = parse_kwargs(args.model_kwargs)
     summary = {}
     for precision in args.precisions:
-        model_kwargs = additional_model_kwargs
+        model_kwargs = additional_model_kwargs.copy()
         if precision in ggml_tensor_qtype.keys():
             model_kwargs['load_in_low_bit'] = precision
         else:
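
The ppl.py hunk releases the loaded model between precision runs: synchronize the XPU queue, flush the allocator cache, drop the reference, and force a GC pass. Below is a minimal standalone sketch of that release sequence; the `PPLRunner` class and its names are hypothetical (only the four-step sequence comes from the patch), and it assumes an XPU-enabled PyTorch environment such as one with `intel_extension_for_pytorch` installed.

```python
import gc

import torch
# Assumption: importing intel_extension_for_pytorch registers the
# torch.xpu backend used by this benchmark.
import intel_extension_for_pytorch  # noqa: F401


class PPLRunner:
    """Hypothetical holder mirroring the cleanup added to BigDLPPL."""

    def __init__(self, model):
        self.model = model

    def release(self):
        # Wait for all queued XPU kernels; freeing while work is still
        # in flight could release memory the device is about to use.
        torch.xpu.synchronize()
        # Hand cached allocator blocks back to the device so the next
        # precision's model starts from a clean slate.
        torch.xpu.empty_cache()
        # Drop the last strong reference, then run a GC pass so the
        # weights are actually reclaimed, not merely unreachable.
        del self.model
        gc.collect()
```

Without this, each iteration of the precisions loop would load a fresh model while the previous one still occupied device memory, which is presumably what the cleanup is guarding against.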
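
The run.py hunk fixes a dict-aliasing bug: plain assignment shares one dict object across every iteration of the precisions loop, so the `load_in_low_bit` key set for a quantized precision leaks into later runs (e.g. a following `float16` run). A standalone sketch of the bug and the fix, using hypothetical kwargs and values:

```python
# Buggy pattern: `model_kwargs` is the very same object as `base_kwargs`.
base_kwargs = {"trust_remote_code": True}
for precision in ["sym_int4", "float16"]:
    model_kwargs = base_kwargs            # alias, not a copy
    if precision == "sym_int4":           # stand-in for the qtype check
        model_kwargs["load_in_low_bit"] = precision
# The key set during the sym_int4 iteration is still present afterwards:
assert base_kwargs == {"trust_remote_code": True,
                       "load_in_low_bit": "sym_int4"}

# Patched pattern: a shallow copy per iteration keeps runs independent.
base_kwargs = {"trust_remote_code": True}
for precision in ["sym_int4", "float16"]:
    model_kwargs = base_kwargs.copy()     # fresh dict each iteration
    if precision == "sym_int4":
        model_kwargs["load_in_low_bit"] = precision
assert base_kwargs == {"trust_remote_code": True}
```

A shallow `.copy()` suffices here because the kwargs values are scalars and strings; nested mutable values would call for `copy.deepcopy` instead.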