fix ppl test errors (#10036)
This commit is contained in:
parent
13e61738c5
commit
226f398c2a
3 changed files with 9 additions and 3 deletions
|
|
@ -3,9 +3,9 @@ Perplexity (PPL) is one of the most common metrics for evaluating language model
|
||||||
|
|
||||||
## HOW TO RUN
|
## HOW TO RUN
|
||||||
```python
|
```python
|
||||||
python run.py --model_path <path/to/model> --low_bit sym_int4 fp4 mixed_fp4 sym_int8 fp8_e5m2 fp8_e4m3 mixed_fp8 --device xpu --dataset path=<dataset_path>,name=<dataset_name>
|
python run.py --model_path <path/to/model> --precisions sym_int4 fp4 mixed_fp4 sym_int8 fp8_e5m2 fp8_e4m3 mixed_fp8 --device xpu --dataset path=<dataset_path>,name=<dataset_name>
|
||||||
```
|
```
|
||||||
A more specific example to run perplexity on Llama2-7B and wikitext:
|
A more specific example to run perplexity on Llama2-7B and wikitext:
|
||||||
```python
|
```python
|
||||||
python run.py --model_path meta-llama/Llama-2-7b-chat-hf --low_bit float16 sym_int4 --device xpu --dataset path=wikitext,name=wikitext-2-raw-v1
|
python run.py --model_path meta-llama/Llama-2-7b-chat-hf --precisions float16 sym_int4 --device xpu --dataset path=wikitext,name=wikitext-2-raw-v1
|
||||||
```
|
```
|
||||||
|
|
@ -19,6 +19,7 @@ import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from transformers import AutoTokenizer
|
from transformers import AutoTokenizer
|
||||||
|
import gc
|
||||||
|
|
||||||
from bigdl.llm.transformers import AutoModelForCausalLM
|
from bigdl.llm.transformers import AutoModelForCausalLM
|
||||||
|
|
||||||
|
|
@ -81,4 +82,9 @@ class BigDLPPL:
|
||||||
self.ppl_evaluator(data.numpy()[0, seq_len//2:, :], input_ids_chunks.numpy()[0, seq_len//2:])
|
self.ppl_evaluator(data.numpy()[0, seq_len//2:, :], input_ids_chunks.numpy()[0, seq_len//2:])
|
||||||
progress_bar.set_description(f"{self.ppl_evaluator}")
|
progress_bar.set_description(f"{self.ppl_evaluator}")
|
||||||
|
|
||||||
|
torch.xpu.synchronize()
|
||||||
|
torch.xpu.empty_cache()
|
||||||
|
del self.model
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
return self.ppl_evaluator.result()
|
return self.ppl_evaluator.result()
|
||||||
|
|
|
||||||
|
|
@ -44,7 +44,7 @@ def main():
|
||||||
additional_model_kwargs = parse_kwargs(args.model_kwargs)
|
additional_model_kwargs = parse_kwargs(args.model_kwargs)
|
||||||
summary = {}
|
summary = {}
|
||||||
for precision in args.precisions:
|
for precision in args.precisions:
|
||||||
model_kwargs = additional_model_kwargs
|
model_kwargs = additional_model_kwargs.copy()
|
||||||
if precision in ggml_tensor_qtype.keys():
|
if precision in ggml_tensor_qtype.keys():
|
||||||
model_kwargs['load_in_low_bit'] = precision
|
model_kwargs['load_in_low_bit'] = precision
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue