LLM: Update ppl tests (#10092)

* update ppl tests

* use load_dataset api

* add exception handling

* add language argument

* address comments
Ovo233 authored 2024-02-06 17:31:48 +08:00; committed by GitHub
parent 3a46b57253
commit 2aaa21c41d
3 changed files with 124 additions and 87 deletions


@@ -1,11 +1,20 @@
 # Perplexity
-Perplexity (PPL) is one of the most common metrics for evaluating language models. This benchmark implementation was from [transformers/perplexity](https://huggingface.co/docs/transformers/perplexity#perplexity-of-fixed-length-models) and [llm_perplexity.py](https://github.com/luo-cheng2021/ov.cpu.llm.experimental/blob/main/llm_perplexity.py)
+Perplexity (PPL) is one of the most common metrics for evaluating language models. This benchmark implementation was from [transformers/perplexity](https://huggingface.co/docs/transformers/perplexity#perplexity-of-fixed-length-models) and [benchmark_patch_llm.py](https://github.com/insuhan/hyper-attn/blob/main/benchmark_patch_llm.py)
 ## HOW TO RUN
-```python
-python run.py --model_path <path/to/model> --precisions sym_int4 fp4 mixed_fp4 sym_int8 fp8_e5m2 fp8_e4m3 mixed_fp8 --device xpu --dataset path=<dataset_path>,name=<dataset_name>
+```bash
+python run.py --model_path <path/to/model> --precisions sym_int4 fp4 mixed_fp4 sym_int8 fp8_e5m2 fp8_e4m3 mixed_fp8 --device xpu --datasets <dataset_names> --dataset_path <path/to/dataset> --language en
 ```
-A more specific example to run perplexity on Llama2-7B and wikitext:
-```python
-python run.py --model_path meta-llama/Llama-2-7b-chat-hf --precisions float16 sym_int4 --device xpu --dataset path=wikitext,name=wikitext-2-raw-v1
+A more specific example to run perplexity on Llama2-7B using the default English datasets:
+```bash
+python run.py --model_path meta-llama/Llama-2-7b-chat-hf --precisions float16 sym_int4 --device xpu --language en
 ```
+> Note: We currently only support the `THUDM/LongBench` [dataset](https://github.com/THUDM/LongBench).
+- If you want to test model perplexity on a few selected datasets from `LongBench`, list them after the `datasets` argument:
+```bash
+--datasets narrativeqa qasper ...
+```
+- The `language` argument only takes effect when `datasets` is `None`. Its choices are `en`, `zh` and `all`, which select all the English datasets, all the Chinese datasets, and all the datasets respectively.
+- If you want to test perplexity on pre-downloaded datasets, pass the local `<path/to/dataset>` via the `dataset_path` argument; see the example after this list.
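For example, to evaluate only a couple of LongBench subsets from a local copy of the data, the two options can be combined. The command below is only an illustrative sketch: `<path/to/LongBench>` stands for a local directory containing one sub-folder per subset, which `run.py` joins with each dataset name.

```bash
# Illustrative only: selected subsets read from a pre-downloaded LongBench directory.
python run.py --model_path meta-llama/Llama-2-7b-chat-hf --precisions sym_int4 --device xpu \
  --datasets narrativeqa qasper --dataset_path <path/to/LongBench>
```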


@@ -14,77 +14,69 @@
 # limitations under the License.
 #
+import intel_extension_for_pytorch as ipex
 import numpy as np
 import torch
+from torch.nn import CrossEntropyLoss
 from tqdm import tqdm
-from transformers import AutoTokenizer
 import gc
-from bigdl.llm.transformers import AutoModelForCausalLM
+from bigdl.llm.transformers import AutoModelForCausalLM, AutoModel

-class PPL:
-    def __init__(self):
-        self.nll = 0
-        self.cnt = 0
-
-    def __call__(self, all_logits, labels):
-        '''
-        all_logits [seq_length, vocab_size]
-        labels [seq_length]
-        '''
-        seq_length = all_logits.shape[0]
-        for i in range(0, seq_length - 1):
-            logits = all_logits[i, :]
-            max_logit = np.amax(logits)
-            sum_exp = np.sum(np.exp(logits - max_logit))
-            # logits at time-step i is for predicting token at time-step (i+1)
-            next_tok = labels[i + 1]
-            log_softmax_of_tok = (logits[next_tok] - max_logit) - np.log(sum_exp)
-            self.nll += -log_softmax_of_tok
-            self.cnt += 1
-        return np.exp(self.nll / self.cnt)
-
-    def result(self):
-        return np.exp(self.nll / self.cnt)
-
-    def __str__(self):
-        return f"PPL: {np.exp(self.nll / self.cnt):.3f}"
-
 class BigDLPPL:
     def __init__(self, model_path, device, **model_kwargs) -> None:
         model_kwargs['trust_remote_code'] = model_kwargs.get('trust_remote_code', True)
         model_kwargs['optimize_model'] = model_kwargs.get('optimize_model', True)
         self.device = device
-        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
-        self.model = AutoModelForCausalLM.from_pretrained(model_path, **model_kwargs)
-        if 'xpu' in device:
-            import intel_extension_for_pytorch as ipex
+        if 'chatglm' in model_path.lower():
+            self.model = AutoModel.from_pretrained(model_path, **model_kwargs)
+        else:
+            self.model = AutoModelForCausalLM.from_pretrained(model_path, **model_kwargs)
         self.model.to(device)
-        self.ppl_evaluator = PPL()

-    def perplexity_hf(self, text):
-        inputs = self.tokenizer('\n\n'.join(text), return_tensors='pt').to(self.device)
-        input_ids = inputs['input_ids']
-        # attention_mask = inputs['attention_mask']
-        progress_bar = tqdm(range(0, input_ids.shape[1], 512))
-        for i0 in progress_bar:
-            input_ids_chunks = input_ids[:, i0:(i0+512)]
-            input_ids_chunks[:, 0] = 1
-            with torch.no_grad():
-                result = self.model.forward(input_ids_chunks, labels = input_ids_chunks, return_dict=True)
-                #print(f"ppl = {torch.exp(result.loss)}")
-                seq_len = result.logits.shape[1]
-                data = result.logits
-                data = data.to('cpu')
-                input_ids_chunks = input_ids_chunks.to('cpu')
-                self.ppl_evaluator(data.numpy()[0, seq_len//2:, :], input_ids_chunks.numpy()[0, seq_len//2:])
-            progress_bar.set_description(f"{self.ppl_evaluator}")
-        torch.xpu.synchronize()
-        torch.xpu.empty_cache()
-        del self.model
-        gc.collect()
-        return self.ppl_evaluator.result()
+    def perplexity_hf(self, encoded_texts):
+        self.model.eval()
+        loss_fct = CrossEntropyLoss(reduction="none")
+        ppls = []
+        try:
+            pbar = tqdm(range(len(encoded_texts)))
+            for bid in pbar:
+                encoded_batch = encoded_texts[bid:bid+1]
+                if type(encoded_batch) == dict:
+                    attn_mask = encoded_batch['attention_mask'] if 'attention_mask' in encoded_batch.keys() else None
+                    encoded_batch = encoded_batch['input_ids']
+                elif type(encoded_batch) == list:
+                    encoded_batch = encoded_batch[0]
+                encoded_batch = encoded_batch.to(self.device)
+                attn_mask = torch.ones_like(encoded_batch)

+                out_logits = self.model(encoded_batch).logits
+                labels = encoded_batch

+                shift_logits = out_logits[..., :-1, :].contiguous()
+                shift_labels = labels[..., 1:].contiguous()
+                shift_attention_mask_batch = attn_mask[..., 1:].contiguous()

+                loss_ = loss_fct(shift_logits.transpose(1, 2), shift_labels).float()
+                perplexity_batch = torch.exp2(
+                    (loss_ * shift_attention_mask_batch).sum(1)
+                    / shift_attention_mask_batch.sum(1)
+                )
+                ppls += perplexity_batch.tolist()
+                pbar.set_description(f"[{bid:<4}/{len(encoded_texts)}] avg_ppls: {np.mean(np.array(ppls)[~np.isnan(np.array(ppls))]):.4f}")
+                del out_logits, encoded_batch, attn_mask, shift_logits, shift_labels, shift_attention_mask_batch, perplexity_batch
+            ppl_mean = np.mean(np.array(ppls)[~np.isnan(np.array(ppls))])
+        finally:
+            torch.xpu.synchronize()
+            torch.xpu.empty_cache()
+            del self.model
+            gc.collect()
+        return ppl_mean
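The core of the new `perplexity_hf` is a masked, per-sequence aggregation of token-level cross-entropy. The snippet below is a self-contained sketch of that computation on toy tensors (batch size, sequence length and vocabulary size are arbitrary, chosen only for illustration):

```python
import torch
from torch.nn import CrossEntropyLoss

# Toy inputs: batch of 1, sequence of 8 tokens, vocabulary of 32000.
logits = torch.randn(1, 8, 32000)
labels = torch.randint(0, 32000, (1, 8))
attn_mask = torch.ones_like(labels)

# Logits at position i predict the token at position i + 1, hence the shift.
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
shift_mask = attn_mask[..., 1:].contiguous()

# Per-token cross-entropy, then a mask-weighted mean per sequence, exponentiated.
loss = CrossEntropyLoss(reduction="none")(shift_logits.transpose(1, 2), shift_labels).float()
perplexity = torch.exp2((loss * shift_mask).sum(1) / shift_mask.sum(1))
print(perplexity)  # one value per sequence in the batch
```

`perplexity_hf` collects one such value per evaluated sequence and reports their mean, skipping NaNs.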


@@ -14,45 +14,81 @@
 # limitations under the License.
 #
-import torch
-from bigdl.llm.ggml.quantize import ggml_tensor_qtype
-from ppl import BigDLPPL
-from datasets import load_dataset
 import argparse
+from tqdm import tqdm

+import torch
+from datasets import concatenate_datasets, load_dataset
+from transformers import AutoTokenizer

+from ppl import BigDLPPL
+from bigdl.llm.ggml.quantize import ggml_tensor_qtype
+import os

-def parse_kwargs(kwstr):
-    kvpair = [item.split('=') for item in kwstr.split(',') if item != ""]
-    return {k:v for k, v in kvpair}
-
-def parse_args():
+def get_arguments():
     parser = argparse.ArgumentParser()
+    parser.add_argument("--seq_len", type=int, default=512)
     parser.add_argument("--model_path", required=True, type=str)
+    parser.add_argument("--datasets", required=False, type=str, default=None, nargs='*')
+    parser.add_argument("--dataset_path", required=False, type=str, default=None)
+    parser.add_argument("--language", required=False, type=str, default="en", choices=['en', 'zh', 'all'])
     parser.add_argument("--precisions", required=False, type=str, default=None, nargs='+')
-    parser.add_argument("--model_kwargs", required=False, type=str, default="")
-    parser.add_argument("--torch_dtype", type=str, default=None)
-    parser.add_argument("--device", type=str, default=None)
-    parser.add_argument("--dataset", type=str, default='path=wikitext,name=wikitext-2-raw-v1')
+    parser.add_argument("--device", type=str, default="xpu")
     return parser.parse_args()

+@torch.no_grad()
 def main():
-    args = parse_args()
-    text = load_dataset(**parse_kwargs(args.dataset), split="test")["text"]
-    additional_model_kwargs = parse_kwargs(args.model_kwargs)
+    args = get_arguments()
+    for arg_name, arg_var in args.__dict__.items():
+        print(f"{arg_name:<16} : {arg_var}")

+    tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)
+    tokenizer.model_max_length = args.seq_len

+    en_datasets = ["narrativeqa", "qasper", "multifieldqa_en", "hotpotqa", "2wikimqa", "musique", "gov_report",
+                   "qmsum", "multi_news", "trec", "triviaqa", "samsum", "passage_count", "passage_retrieval_en"]
+    zh_datasets = ["multifieldqa_zh", "dureader", "vcsum", "lsht", "passage_retrieval_zh"]

+    if args.datasets is None:
+        if args.language == 'en':
+            datasets = en_datasets
+        elif args.language == 'zh':
+            datasets = zh_datasets
+        else:
+            datasets = en_datasets + zh_datasets
+    else:
+        datasets = args.datasets

+    dataset_all = []
+    for dataset_name in datasets:
+        data_ = load_dataset(os.path.join(args.dataset_path, dataset_name), split='test') if args.dataset_path \
+            else load_dataset('THUDM/LongBench', f'{dataset_name}', split='test')
+        dataset_all.append(data_)
+    data = concatenate_datasets(dataset_all)

+    encoded_texts = []
+    pbar = tqdm(data)
+    for i, data_i in enumerate(pbar):
+        encoded_text = tokenizer.encode(data_i['context'], return_tensors='pt', truncation=True)
+        pbar.set_description(f"seq_len: {len(encoded_text[0])}, n_data: {len(encoded_texts)}")
+        if len(encoded_text[0]) < args.seq_len:
+            continue
+        encoded_texts.append(encoded_text)

     summary = {}
     for precision in args.precisions:
-        model_kwargs = additional_model_kwargs.copy()
+        model_kwargs = {}
         if precision in ggml_tensor_qtype.keys():
             model_kwargs['load_in_low_bit'] = precision
         else:
             model_kwargs['torch_dtype'] = getattr(torch, precision)
         print(model_kwargs)
         ppl_evaluator = BigDLPPL(model_path=args.model_path, device=args.device, **model_kwargs)
-        ppl = ppl_evaluator.perplexity_hf(text)
+        ppl = ppl_evaluator.perplexity_hf(encoded_texts)
         summary[precision] = ppl
     print(summary)

-main()
+if __name__ == "__main__":
+    main()
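For reference, the updated pieces can also be driven directly from Python rather than through `run.py`. A minimal sketch, assuming an XPU build of bigdl-llm is installed; the model name follows the README example and the passage text is a placeholder:

```python
from transformers import AutoTokenizer
from ppl import BigDLPPL

model_path = "meta-llama/Llama-2-7b-chat-hf"  # placeholder checkpoint from the README example

# Tokenize evaluation passages the same way run.py does (truncated to the model max length).
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
tokenizer.model_max_length = 512
encoded_texts = [tokenizer.encode("a long evaluation passage ...", return_tensors="pt", truncation=True)]

# One precision from ggml_tensor_qtype, e.g. sym_int4, mapped to load_in_low_bit as in run.py.
evaluator = BigDLPPL(model_path=model_path, device="xpu", load_in_low_bit="sym_int4")
print(evaluator.perplexity_hf(encoded_texts))
```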