fix ppl test errors (#10036)

Author: Ovo233 (committed by GitHub), 2024-01-30 16:26:21 +08:00
parent 13e61738c5
commit 226f398c2a
3 changed files with 9 additions and 3 deletions

@@ -3,9 +3,9 @@ Perplexity (PPL) is one of the most common metrics for evaluating language model
 ## HOW TO RUN
 ```python
-python run.py --model_path <path/to/model> --low_bit sym_int4 fp4 mixed_fp4 sym_int8 fp8_e5m2 fp8_e4m3 mixed_fp8 --device xpu --dataset path=<dataset_path>,name=<dataset_name>
+python run.py --model_path <path/to/model> --precisions sym_int4 fp4 mixed_fp4 sym_int8 fp8_e5m2 fp8_e4m3 mixed_fp8 --device xpu --dataset path=<dataset_path>,name=<dataset_name>
 ```
 A more specific example to run perplexity on Llama2-7B and wikitext:
 ```python
-python run.py --model_path meta-llama/Llama-2-7b-chat-hf --low_bit float16 sym_int4 --device xpu --dataset path=wikitext,name=wikitext-2-raw-v1
+python run.py --model_path meta-llama/Llama-2-7b-chat-hf --precisions float16 sym_int4 --device xpu --dataset path=wikitext,name=wikitext-2-raw-v1
 ```
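The README edits above track the CLI rename from `--low_bit` to `--precisions`, which takes one or more precision values in a single run. A rough sketch only, not the actual `run.py` (the defaults and help text here are assumptions), of how such a flag could be declared:

```python
# Hypothetical sketch of a --precisions flag that accepts several values in one
# invocation, matching the usage shown in the README above. Not the real run.py.
import argparse

parser = argparse.ArgumentParser(description="perplexity benchmark (sketch)")
parser.add_argument("--model_path", required=True)
parser.add_argument("--precisions", nargs="+", default=["sym_int4"],
                    help="one or more of sym_int4, fp4, mixed_fp4, sym_int8, "
                         "fp8_e5m2, fp8_e4m3, mixed_fp8, float16")
parser.add_argument("--device", default="xpu")
parser.add_argument("--dataset", help="e.g. path=wikitext,name=wikitext-2-raw-v1")
args = parser.parse_args()

for precision in args.precisions:
    # one evaluation pass per requested precision
    print(f"would evaluate {args.model_path} at {precision} on {args.device}")
```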

@@ -19,6 +19,7 @@ import numpy as np
 import torch
 from tqdm import tqdm
 from transformers import AutoTokenizer
+import gc
 from bigdl.llm.transformers import AutoModelForCausalLM
@@ -81,4 +82,9 @@ class BigDLPPL:
             self.ppl_evaluator(data.numpy()[0, seq_len//2:, :], input_ids_chunks.numpy()[0, seq_len//2:])
             progress_bar.set_description(f"{self.ppl_evaluator}")
+        torch.xpu.synchronize()
+        torch.xpu.empty_cache()
+        del self.model
+        gc.collect()
         return self.ppl_evaluator.result()
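The lines added to the evaluator release XPU memory once a precision finishes, so the next model can be loaded without exhausting device memory. A self-contained sketch of the same cleanup pattern (the `Evaluator`/`finish` names and the availability guard are assumptions, not the actual `BigDLPPL` code; `torch.xpu` is provided by Intel Extension for PyTorch or recent PyTorch builds with XPU support):

```python
# Illustrative cleanup pattern: flush pending XPU work, release cached device
# memory, drop the model reference, then force garbage collection.
import gc
import torch

class Evaluator:
    def __init__(self, model):
        self.model = model

    def finish(self):
        if hasattr(torch, "xpu") and torch.xpu.is_available():
            torch.xpu.synchronize()   # wait for in-flight XPU kernels before freeing memory
            torch.xpu.empty_cache()   # return cached allocator blocks to the device
        del self.model                # drop the evaluator's reference to the model
        gc.collect()                  # reclaim anything still held via reference cycles
```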

@@ -44,7 +44,7 @@ def main():
     additional_model_kwargs = parse_kwargs(args.model_kwargs)
     summary = {}
     for precision in args.precisions:
-        model_kwargs = additional_model_kwargs
+        model_kwargs = additional_model_kwargs.copy()
         if precision in ggml_tensor_qtype.keys():
            model_kwargs['load_in_low_bit'] = precision
         else:
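The one-line `.copy()` change keeps options set for one precision from leaking into later iterations: without it, `model_kwargs` is just another name for the shared `additional_model_kwargs` dict. A standalone demonstration of the difference (the `trust_remote_code` entry is only a placeholder kwarg for illustration):

```python
# Minimal illustration, not the actual run.py, of the aliasing bug fixed by .copy():
# without a copy, every loop iteration writes into the same shared dict.
additional_model_kwargs = {"trust_remote_code": True}

# Buggy: model_kwargs is an alias for the shared dict.
for precision in ["sym_int4", "float16"]:
    model_kwargs = additional_model_kwargs
    model_kwargs["load_in_low_bit"] = precision   # also mutates additional_model_kwargs
print(additional_model_kwargs)  # {'trust_remote_code': True, 'load_in_low_bit': 'float16'}

# Fixed: each iteration gets its own shallow copy.
additional_model_kwargs = {"trust_remote_code": True}
for precision in ["sym_int4", "float16"]:
    model_kwargs = additional_model_kwargs.copy()
    model_kwargs["load_in_low_bit"] = precision   # the shared dict is untouched
print(additional_model_kwargs)  # {'trust_remote_code': True}
```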