* add ppl benchmark * add license * add readme * add dataset argument * add dataset usage * fixed low bit args * correct result * fix terminal display * fix ppl update * enable fp16 fp32 bf16 * format the desc * fix model_kwargs * add more readme
58 lines
No EOL
2.1 KiB
Python
58 lines
No EOL
2.1 KiB
Python
#
|
|
# Copyright 2016 The BigDL Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#
|
|
|
|
import torch
|
|
from bigdl.llm.ggml.quantize import ggml_tensor_qtype
|
|
from ppl import BigDLPPL
|
|
from datasets import load_dataset
|
|
import argparse
|
|
|
|
|
|
def parse_kwargs(kwstr):
|
|
kvpair = [item.split('=') for item in kwstr.split(',') if item != ""]
|
|
return {k:v for k, v in kvpair}
|
|
|
|
|
|
def parse_args():
|
|
parser = argparse.ArgumentParser()
|
|
parser.add_argument("--model_path", required=True, type=str)
|
|
parser.add_argument("--precisions", required=False, type=str, default=None, nargs='+')
|
|
parser.add_argument("--model_kwargs", required=False, type=str, default="")
|
|
parser.add_argument("--torch_dtype", type=str, default=None)
|
|
parser.add_argument("--device", type=str, default=None)
|
|
parser.add_argument("--dataset", type=str, default='path=wikitext,name=wikitext-2-raw-v1')
|
|
|
|
return parser.parse_args()
|
|
|
|
|
|
def main():
|
|
args = parse_args()
|
|
text = load_dataset(**parse_kwargs(args.dataset), split="test")["text"]
|
|
additional_model_kwargs = parse_kwargs(args.model_kwargs)
|
|
summary = {}
|
|
for precision in args.precisions:
|
|
model_kwargs = additional_model_kwargs
|
|
if precision in ggml_tensor_qtype.keys():
|
|
model_kwargs['load_in_low_bit'] = precision
|
|
else:
|
|
model_kwargs['torch_dtype'] = getattr(torch, precision)
|
|
print(model_kwargs)
|
|
ppl_evaluator = BigDLPPL(model_path=args.model_path, device=args.device, **model_kwargs)
|
|
ppl = ppl_evaluator.perplexity_hf(text)
|
|
summary[precision] = ppl
|
|
print(summary)
|
|
|
|
main() |