From a8c866c32b7155fc7b1292595210d85f5e094bae Mon Sep 17 00:00:00 2001
From: "Chen, Zhentao"
Date: Thu, 18 Jan 2024 17:54:28 +0800
Subject: [PATCH] add ppl benchmark (#9914)

* add ppl benchmark
* add license
* add readme
* add dataset argument
* add dataset usage
* fixed low bit args
* correct result
* fix terminal display
* fix ppl update
* enable fp16 fp32 bf16
* format the desc
* fix model_kwargs
* add more readme
---
 python/llm/dev/benchmark/perplexity/README.md | 11 +++
 python/llm/dev/benchmark/perplexity/ppl.py    | 87 +++++++++++++++++++
 python/llm/dev/benchmark/perplexity/run.py    | 63 ++++++++++++++
 3 files changed, 161 insertions(+)
 create mode 100644 python/llm/dev/benchmark/perplexity/README.md
 create mode 100644 python/llm/dev/benchmark/perplexity/ppl.py
 create mode 100644 python/llm/dev/benchmark/perplexity/run.py

diff --git a/python/llm/dev/benchmark/perplexity/README.md b/python/llm/dev/benchmark/perplexity/README.md
new file mode 100644
index 00000000..359feeb4
--- /dev/null
+++ b/python/llm/dev/benchmark/perplexity/README.md
@@ -0,0 +1,11 @@
+# Perplexity
+Perplexity (PPL) is one of the most common metrics for evaluating language models. This benchmark implementation is adapted from [transformers/perplexity](https://huggingface.co/docs/transformers/perplexity#perplexity-of-fixed-length-models) and [llm_perplexity.py](https://github.com/luo-cheng2021/ov.cpu.llm.experimental/blob/main/llm_perplexity.py).
+
+## HOW TO RUN
+```bash
+python run.py --model_path <path to model> --precisions sym_int4 fp4 mixed_fp4 sym_int8 fp8_e5m2 fp8_e4m3 mixed_fp8 --device xpu --dataset path=<dataset path>,name=<dataset name>
+```
+A more specific example, evaluating Llama2-7B on wikitext:
+```bash
+python run.py --model_path meta-llama/Llama-2-7b-chat-hf --precisions float16 sym_int4 --device xpu --dataset path=wikitext,name=wikitext-2-raw-v1
+```
\ No newline at end of file
diff --git a/python/llm/dev/benchmark/perplexity/ppl.py b/python/llm/dev/benchmark/perplexity/ppl.py
new file mode 100644
index 00000000..161ef634
--- /dev/null
+++ b/python/llm/dev/benchmark/perplexity/ppl.py
@@ -0,0 +1,87 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import numpy as np
+import torch
+from tqdm import tqdm
+from transformers import AutoTokenizer
+
+from bigdl.llm.transformers import AutoModelForCausalLM
+
+
+class PPL:
+    def __init__(self):
+        self.nll = 0
+        self.cnt = 0
+
+    def __call__(self, all_logits, labels):
+        '''
+        all_logits [seq_length, vocab_size]
+        labels     [seq_length]
+        '''
+        seq_length = all_logits.shape[0]
+        for i in range(0, seq_length - 1):
+            logits = all_logits[i, :]
+            # numerically stable log-softmax of the target token
+            max_logit = np.amax(logits)
+            sum_exp = np.sum(np.exp(logits - max_logit))
+            # logits at time-step i predict the token at time-step (i+1)
+            next_tok = labels[i + 1]
+            log_softmax_of_tok = (logits[next_tok] - max_logit) - np.log(sum_exp)
+            self.nll += -log_softmax_of_tok
+            self.cnt += 1
+        return np.exp(self.nll / self.cnt)
+
+    def result(self):
+        return np.exp(self.nll / self.cnt)
+
+    def __str__(self):
+        return f"PPL: {np.exp(self.nll / self.cnt):.3f}"
+
+
+class BigDLPPL:
+    def __init__(self, model_path, device, **model_kwargs) -> None:
+        model_kwargs['trust_remote_code'] = model_kwargs.get('trust_remote_code', True)
+        model_kwargs['optimize_model'] = model_kwargs.get('optimize_model', True)
+        self.device = device
+        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+        self.model = AutoModelForCausalLM.from_pretrained(model_path, **model_kwargs)
+        if 'xpu' in device:
+            # importing ipex registers the 'xpu' device with PyTorch
+            import intel_extension_for_pytorch as ipex  # noqa: F401
+        self.model.to(device)
+        self.ppl_evaluator = PPL()
+
+    def perplexity_hf(self, text):
+        inputs = self.tokenizer('\n\n'.join(text), return_tensors='pt').to(self.device)
+        input_ids = inputs['input_ids']
+        # score the text in non-overlapping chunks of 512 tokens
+        progress_bar = tqdm(range(0, input_ids.shape[1], 512))
+
+        for i0 in progress_bar:
+            input_ids_chunks = input_ids[:, i0:(i0 + 512)]
+            # force the first token of each chunk to the BOS id (1 for Llama-family tokenizers)
+            input_ids_chunks[:, 0] = 1
+            with torch.no_grad():
+                result = self.model(input_ids_chunks, labels=input_ids_chunks, return_dict=True)
+            seq_len = result.logits.shape[1]
+            data = result.logits.to('cpu')
+            input_ids_chunks = input_ids_chunks.to('cpu')
+            # score only the second half of each chunk, so every scored token has context
+            self.ppl_evaluator(data.numpy()[0, seq_len//2:, :], input_ids_chunks.numpy()[0, seq_len//2:])
+            progress_bar.set_description(f"{self.ppl_evaluator}")
+
+        return self.ppl_evaluator.result()
diff --git a/python/llm/dev/benchmark/perplexity/run.py b/python/llm/dev/benchmark/perplexity/run.py
new file mode 100644
index 00000000..320ba85f
--- /dev/null
+++ b/python/llm/dev/benchmark/perplexity/run.py
@@ -0,0 +1,63 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import torch
+from bigdl.llm.ggml.quantize import ggml_tensor_qtype
+from ppl import BigDLPPL
+from datasets import load_dataset
+import argparse
+
+
+def parse_kwargs(kwstr):
+    # parse "key1=value1,key2=value2" into a dict
+    kvpair = [item.split('=') for item in kwstr.split(',') if item != ""]
+    return dict(kvpair)
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model_path", required=True, type=str)
+    # each precision is either a bigdl-llm low-bit type (e.g. sym_int4)
+    # or a torch dtype name (e.g. float16)
+    parser.add_argument("--precisions", required=True, type=str, nargs='+')
+    parser.add_argument("--model_kwargs", required=False, type=str, default="")
+    parser.add_argument("--device", type=str, default=None)
+    parser.add_argument("--dataset", type=str, default='path=wikitext,name=wikitext-2-raw-v1')
+
+    return parser.parse_args()
+
+
+def main():
+    args = parse_args()
+    text = load_dataset(**parse_kwargs(args.dataset), split="test")["text"]
+    additional_model_kwargs = parse_kwargs(args.model_kwargs)
+    summary = {}
+    for precision in args.precisions:
+        # copy so that kwargs set for one precision do not leak into the next
+        model_kwargs = dict(additional_model_kwargs)
+        if precision in ggml_tensor_qtype.keys():
+            model_kwargs['load_in_low_bit'] = precision
+        else:
+            model_kwargs['torch_dtype'] = getattr(torch, precision)
+        print(model_kwargs)
+        ppl_evaluator = BigDLPPL(model_path=args.model_path, device=args.device, **model_kwargs)
+        ppl = ppl_evaluator.perplexity_hf(text)
+        summary[precision] = ppl
+    print(summary)
+
+
+if __name__ == '__main__':
+    main()
\ No newline at end of file
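For reference, the `PPL` accumulator in `ppl.py` reports `exp(nll / cnt)`, the exponential of the mean per-token negative log-likelihood, which is the standard fixed-length perplexity definition. A minimal sanity check of that equivalence against `torch.nn.functional.cross_entropy` (a standalone sketch, not part of the patch; it assumes it runs next to `ppl.py` so the import resolves, and uses random logits rather than a real model):

```python
import numpy as np
import torch
import torch.nn.functional as F

from ppl import PPL  # the accumulator added by this patch

# toy next-token logits: 16 time steps over a 32-token vocabulary
rng = np.random.default_rng(0)
logits = rng.standard_normal((16, 32)).astype(np.float32)
labels = rng.integers(0, 32, size=16)

evaluator = PPL()
ppl_stream = evaluator(logits, labels)  # exp of mean NLL over 15 predictions

# reference: logits at step i are scored against the token at step i + 1
ce = F.cross_entropy(torch.from_numpy(logits[:-1]), torch.from_numpy(labels[1:]))
ppl_ref = torch.exp(ce).item()

assert np.isclose(ppl_stream, ppl_ref, rtol=1e-4)
print(f"streaming PPL {ppl_stream:.4f} == exp(mean CE) {ppl_ref:.4f}")
```

Because the accumulator only keeps the running `nll` and `cnt` scalars, `run.py` can stream logits chunk by chunk and still report the same perplexity as a single pass over the whole text.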