fix ppl test errors (#10036)
commit 226f398c2a
parent 13e61738c5

3 changed files with 9 additions and 3 deletions
The benchmark README's usage examples, updated to match the flag that `run.py` actually parses (`--precisions`, see the `args.precisions` loop in the last hunk):

````diff
@@ -3,9 +3,9 @@ Perplexity (PPL) is one of the most common metrics for evaluating language model
 ## HOW TO RUN
 ```python
-python run.py --model_path <path/to/model> --low_bit sym_int4 fp4 mixed_fp4 sym_int8 fp8_e5m2 fp8_e4m3 mixed_fp8 --device xpu --dataset path=<dataset_path>,name=<dataset_name>
+python run.py --model_path <path/to/model> --precisions sym_int4 fp4 mixed_fp4 sym_int8 fp8_e5m2 fp8_e4m3 mixed_fp8 --device xpu --dataset path=<dataset_path>,name=<dataset_name>
 ```
 A more specific example to run perplexity on Llama2-7B and wikitext:
 ```python
-python run.py --model_path meta-llama/Llama-2-7b-chat-hf --low_bit float16 sym_int4 --device xpu --dataset path=wikitext,name=wikitext-2-raw-v1
+python run.py --model_path meta-llama/Llama-2-7b-chat-hf --precisions float16 sym_int4 --device xpu --dataset path=wikitext,name=wikitext-2-raw-v1
 ```
 
````
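For context, here is a minimal sketch of the argument definition the rename implies. The actual parser in `run.py` is not part of this diff; `nargs='+'` and the defaults are assumptions inferred from the multi-value usage above and the `for precision in args.precisions:` loop in the last hunk.

```python
# Hypothetical sketch, not from this commit: a --precisions flag that
# accepts one or more values, matching the README commands above.
import argparse

parser = argparse.ArgumentParser(description='PPL benchmark (sketch)')
parser.add_argument('--model_path', required=True)
parser.add_argument('--precisions', nargs='+', default=['sym_int4'])  # assumed
parser.add_argument('--device', default='xpu')                        # assumed
parser.add_argument('--dataset', required=True)

args = parser.parse_args(['--model_path', 'm', '--device', 'xpu',
                          '--dataset', 'path=wikitext,name=wikitext-2-raw-v1',
                          '--precisions', 'float16', 'sym_int4'])
assert args.precisions == ['float16', 'sym_int4']
```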
In the `BigDLPPL` evaluator module, `gc` is now imported for the teardown added in the next hunk:

```diff
@@ -19,6 +19,7 @@ import numpy as np
 import torch
 from tqdm import tqdm
 from transformers import AutoTokenizer
+import gc
 
 from bigdl.llm.transformers import AutoModelForCausalLM
 
```
After the evaluation loop, `BigDLPPL` now synchronizes the device, drops the cached XPU memory, and releases the model before returning, so consecutive precision runs do not accumulate device memory:

```diff
@@ -81,4 +82,9 @@ class BigDLPPL:
                 self.ppl_evaluator(data.numpy()[0, seq_len//2:, :], input_ids_chunks.numpy()[0, seq_len//2:])
             progress_bar.set_description(f"{self.ppl_evaluator}")
 
+        torch.xpu.synchronize()
+        torch.xpu.empty_cache()
+        del self.model
+        gc.collect()
+
         return self.ppl_evaluator.result()
```
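The same teardown, sketched as a self-contained pattern. `PPLRunner` and `finish` are hypothetical names, and the `hasattr` guard assumes an XPU-enabled PyTorch build (in the bigdl-llm stack, `torch.xpu` is provided by intel_extension_for_pytorch).

```python
# Hypothetical sketch of the cleanup the commit adds at the end of the
# BigDLPPL run: flush pending XPU work, return cached device memory, and
# drop the model reference so the next precision starts from a clean slate.
import gc
import torch

class PPLRunner:
    def __init__(self, model):
        self.model = model

    def finish(self):
        if hasattr(torch, 'xpu') and torch.xpu.is_available():
            torch.xpu.synchronize()   # wait for queued kernels to complete
            torch.xpu.empty_cache()   # release cached allocator blocks
        del self.model                # drop the instance's reference to the model
        gc.collect()                  # collect cycles so the weights are freed
```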
In `run.py`, each precision now works on its own copy of the shared model kwargs instead of mutating one dict across loop iterations:

```diff
@@ -44,7 +44,7 @@ def main():
     additional_model_kwargs = parse_kwargs(args.model_kwargs)
     summary = {}
     for precision in args.precisions:
-        model_kwargs = additional_model_kwargs
+        model_kwargs = additional_model_kwargs.copy()
         if precision in ggml_tensor_qtype.keys():
             model_kwargs['load_in_low_bit'] = precision
         else:
```
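Why `.copy()` is needed: plain assignment binds a second name to the same dict, so a key written for one precision (e.g. `load_in_low_bit`) leaked into every later iteration. A minimal repro with illustrative values:

```python
# Demonstrates the aliasing bug fixed above; the kwarg values are
# illustrative, not taken from the benchmark.
shared = {'trust_remote_code': True}

alias = shared                           # old behavior: alias, not a copy
alias['load_in_low_bit'] = 'sym_int4'
assert 'load_in_low_bit' in shared       # the shared kwargs got polluted

shared = {'trust_remote_code': True}
fresh = shared.copy()                    # new behavior: shallow copy per iteration
fresh['load_in_low_bit'] = 'sym_int4'
assert 'load_in_low_bit' not in shared   # shared kwargs stay clean
```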