fix ppl test errors (#10036)

This commit is contained in:
Ovo233 2024-01-30 16:26:21 +08:00 committed by GitHub
parent 13e61738c5
commit 226f398c2a
3 changed files with 9 additions and 3 deletions

View file

@ -3,9 +3,9 @@ Perplexity (PPL) is one of the most common metrics for evaluating language model
## HOW TO RUN
```python
python run.py --model_path <path/to/model> --low_bit sym_int4 fp4 mixed_fp4 sym_int8 fp8_e5m2 fp8_e4m3 mixed_fp8 --device xpu --dataset path=<dataset_path>,name=<dataset_name>
python run.py --model_path <path/to/model> --precisions sym_int4 fp4 mixed_fp4 sym_int8 fp8_e5m2 fp8_e4m3 mixed_fp8 --device xpu --dataset path=<dataset_path>,name=<dataset_name>
```
A more specific example to run perplexity on Llama2-7B and wikitext:
```python
python run.py --model_path meta-llama/Llama-2-7b-chat-hf --low_bit float16 sym_int4 --device xpu --dataset path=wikitext,name=wikitext-2-raw-v1
python run.py --model_path meta-llama/Llama-2-7b-chat-hf --precisions float16 sym_int4 --device xpu --dataset path=wikitext,name=wikitext-2-raw-v1
```

View file

@ -19,6 +19,7 @@ import numpy as np
import torch
from tqdm import tqdm
from transformers import AutoTokenizer
import gc
from bigdl.llm.transformers import AutoModelForCausalLM
@ -81,4 +82,9 @@ class BigDLPPL:
self.ppl_evaluator(data.numpy()[0, seq_len//2:, :], input_ids_chunks.numpy()[0, seq_len//2:])
progress_bar.set_description(f"{self.ppl_evaluator}")
torch.xpu.synchronize()
torch.xpu.empty_cache()
del self.model
gc.collect()
return self.ppl_evaluator.result()

View file

@ -44,7 +44,7 @@ def main():
additional_model_kwargs = parse_kwargs(args.model_kwargs)
summary = {}
for precision in args.precisions:
model_kwargs = additional_model_kwargs
model_kwargs = additional_model_kwargs.copy()
if precision in ggml_tensor_qtype.keys():
model_kwargs['load_in_low_bit'] = precision
else: