LLM: add Ceval benchmark test. (#9872)
* init ceval benchmark test.
* upload dataset.
* add other tests.
* add qwen evaluator.
* fix qwen evaluator style.
* fix qwen evaluator style.
* update qwen evaluator.
* add llama evaluator.
* update eval
* fix typo.
* fix
* fix typo.
* fix llama evaluator.
* fix bug.
* fix style.
* delete dataset.
* fix style.
* fix style.
* add README.md and fix typo.
* fix comments.
* remove run scripts
This commit is contained in:
parent
b909c5c9c2
commit
511cbcf773
6 changed files with 770 additions and 0 deletions
32
python/llm/dev/benchmark/ceval/README.md
Normal file
@@ -0,0 +1,32 @@
## C-Eval Benchmark Test

The C-Eval benchmark test allows users to evaluate models on the [C-Eval](https://cevalbenchmark.com) dataset, a multi-level, multi-discipline Chinese evaluation suite for foundation models. It consists of 13,948 multiple-choice questions spanning 52 diverse disciplines and four difficulty levels. Please check the [paper](https://arxiv.org/abs/2305.08322) and the [GitHub repo](https://github.com/hkust-nlp/ceval) for more information.

### Download dataset
Please download and unzip the dataset for evaluation.
```shell
wget https://huggingface.co/datasets/ceval/ceval-exam/resolve/main/ceval-exam.zip
mkdir data
mv ceval-exam.zip data
cd data; unzip ceval-exam.zip
```

### Run
You can run the evaluation with the following command.
```shell
bash run.sh
```
+ `run.sh`
```shell
python eval.py \
    --model_family llama \
    --model_path "path to model" \
    --eval_type validation \
    --device xpu \
    --eval_data_path data \
    --qtype sym_int4
```

> **Note**
>
> There are two types of evaluation selected by `eval_type`: `validation` runs on the validation dataset and prints the evaluation scores, while `test` runs on the test dataset and writes a `submission.json` file that can be submitted to https://cevalbenchmark.com to obtain the evaluation score.
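For example, to produce `submission.json` for the leaderboard you can point the same script at the test split; a minimal sketch of such a `run.sh` variant (the model path is a placeholder) might look like:
```shell
python eval.py \
    --model_family llama \
    --model_path "path to model" \
    --eval_type test \
    --device xpu \
    --eval_data_path data \
    --qtype sym_int4
```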
310
python/llm/dev/benchmark/ceval/eval.py
Normal file
@@ -0,0 +1,310 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import os
import argparse
import pandas as pd
import torch
import json
from tqdm import tqdm

from bigdl.llm.utils.common.log4Error import invalidInputError
from evaluators.qwen import QwenEvaluator
from evaluators.llama import LlamaEvaluator


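# Maps each C-Eval subject id to [English name, Chinese name, category]; the
# category (STEM / Social Science / Humanities / Other) is what cal_ceval below
# aggregates over.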
TASK_NAME_MAPPING = {
    "computer_network": ["Computer Network", "\u8ba1\u7b97\u673a\u7f51\u7edc", "STEM"],
    "operating_system": ["Operating System", "\u64cd\u4f5c\u7cfb\u7edf", "STEM"],
    "computer_architecture": ["Computer Architecture", "\u8ba1\u7b97\u673a\u7ec4\u6210", "STEM"],
    "college_programming": ["College Programming", "\u5927\u5b66\u7f16\u7a0b", "STEM"],
    "college_physics": ["College Physics", "\u5927\u5b66\u7269\u7406", "STEM"],
    "college_chemistry": ["College Chemistry", "\u5927\u5b66\u5316\u5b66", "STEM"],
    "advanced_mathematics": ["Advanced Mathematics", "\u9ad8\u7b49\u6570\u5b66", "STEM"],
    "probability_and_statistics": ["Probability and Statistics", "\u6982\u7387\u7edf\u8ba1", "STEM"],
    "discrete_mathematics": ["Discrete Mathematics", "\u79bb\u6563\u6570\u5b66", "STEM"],
    "electrical_engineer": ["Electrical Engineer", "\u6ce8\u518c\u7535\u6c14\u5de5\u7a0b\u5e08", "STEM"],
    "metrology_engineer": ["Metrology Engineer", "\u6ce8\u518c\u8ba1\u91cf\u5e08", "STEM"],
    "high_school_mathematics": ["High School Mathematics", "\u9ad8\u4e2d\u6570\u5b66", "STEM"],
    "high_school_physics": ["High School Physics", "\u9ad8\u4e2d\u7269\u7406", "STEM"],
    "high_school_chemistry": ["High School Chemistry", "\u9ad8\u4e2d\u5316\u5b66", "STEM"],
    "high_school_biology": ["High School Biology", "\u9ad8\u4e2d\u751f\u7269", "STEM"],
    "middle_school_mathematics": ["Middle School Mathematics", "\u521d\u4e2d\u6570\u5b66", "STEM"],
    "middle_school_biology": ["Middle School Biology", "\u521d\u4e2d\u751f\u7269", "STEM"],
    "middle_school_physics": ["Middle School Physics", "\u521d\u4e2d\u7269\u7406", "STEM"],
    "middle_school_chemistry": ["Middle School Chemistry", "\u521d\u4e2d\u5316\u5b66", "STEM"],
    "veterinary_medicine": ["Veterinary Medicine", "\u517d\u533b\u5b66", "STEM"],
    "college_economics": ["College Economics", "\u5927\u5b66\u7ecf\u6d4e\u5b66", "Social Science"],
    "business_administration": ["Business Administration", "\u5de5\u5546\u7ba1\u7406", "Social Science"],
    "marxism": ["Marxism", "\u9a6c\u514b\u601d\u4e3b\u4e49\u57fa\u672c\u539f\u7406", "Social Science"],
    "mao_zedong_thought": ["Mao Zedong Thought", "\u6bdb\u6cfd\u4e1c\u601d\u60f3\u548c\u4e2d\u56fd\u7279\u8272\u793e\u4f1a\u4e3b\u4e49\u7406\u8bba\u4f53\u7cfb\u6982\u8bba", "Social Science"],
    "education_science": ["Education Science", "\u6559\u80b2\u5b66", "Social Science"],
    "teacher_qualification": ["Teacher Qualification", "\u6559\u5e08\u8d44\u683c", "Social Science"],
    "high_school_politics": ["High School Politics", "\u9ad8\u4e2d\u653f\u6cbb", "Social Science"],
    "high_school_geography": ["High School Geography", "\u9ad8\u4e2d\u5730\u7406", "Social Science"],
    "middle_school_politics": ["Middle School Politics", "\u521d\u4e2d\u653f\u6cbb", "Social Science"],
    "middle_school_geography": ["Middle School Geography", "\u521d\u4e2d\u5730\u7406", "Social Science"],
    "modern_chinese_history": ["Modern Chinese History", "\u8fd1\u4ee3\u53f2\u7eb2\u8981", "Humanities"],
    "ideological_and_moral_cultivation": ["Ideological and Moral Cultivation", "\u601d\u60f3\u9053\u5fb7\u4fee\u517b\u4e0e\u6cd5\u5f8b\u57fa\u7840", "Humanities"],
    "logic": ["Logic", "\u903b\u8f91\u5b66", "Humanities"],
    "law": ["Law", "\u6cd5\u5b66", "Humanities"],
    "chinese_language_and_literature": ["Chinese Language and Literature", "\u4e2d\u56fd\u8bed\u8a00\u6587\u5b66", "Humanities"],
    "art_studies": ["Art Studies", "\u827a\u672f\u5b66", "Humanities"],
    "professional_tour_guide": ["Professional Tour Guide", "\u5bfc\u6e38\u8d44\u683c", "Humanities"],
    "legal_professional": ["Legal Professional", "\u6cd5\u5f8b\u804c\u4e1a\u8d44\u683c", "Humanities"],
    "high_school_chinese": ["High School Chinese", "\u9ad8\u4e2d\u8bed\u6587", "Humanities"],
    "high_school_history": ["High School History", "\u9ad8\u4e2d\u5386\u53f2", "Humanities"],
    "middle_school_history": ["Middle School History", "\u521d\u4e2d\u5386\u53f2", "Humanities"],
    "civil_servant": ["Civil Servant", "\u516c\u52a1\u5458", "Other"],
    "sports_science": ["Sports Science", "\u4f53\u80b2\u5b66", "Other"],
    "plant_protection": ["Plant Protection", "\u690d\u7269\u4fdd\u62a4", "Other"],
    "basic_medicine": ["Basic Medicine", "\u57fa\u7840\u533b\u5b66", "Other"],
    "clinical_medicine": ["Clinical Medicine", "\u4e34\u5e8a\u533b\u5b66", "Other"],
    "urban_and_rural_planner": ["Urban and Rural Planner", "\u6ce8\u518c\u57ce\u4e61\u89c4\u5212\u5e08", "Other"],
    "accountant": ["Accountant", "\u6ce8\u518c\u4f1a\u8ba1\u5e08", "Other"],
    "fire_engineer": ["Fire Engineer", "\u6ce8\u518c\u6d88\u9632\u5de5\u7a0b\u5e08", "Other"],
    "environmental_impact_assessment_engineer": ["Environmental Impact Assessment Engineer", "\u73af\u5883\u5f71\u54cd\u8bc4\u4ef7\u5de5\u7a0b\u5e08", "Other"],
    "tax_accountant": ["Tax Accountant", "\u7a0e\u52a1\u5e08", "Other"],
    "physician": ["Physician", "\u533b\u5e08\u8d44\u683c", "Other"],
}
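# Subjects that make up the C-Eval "Hard" subset reported separately by cal_ceval.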
hard_list = [
    "advanced_mathematics",
    "discrete_mathematics",
    "probability_and_statistics",
    "college_physics",
    "college_chemistry",
    "high_school_mathematics",
    "high_school_physics",
    "high_school_chemistry",
]
choices = ["A", "B", "C", "D"]


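# Aggregate the per-subject accuracies in `res` into per-category averages
# (STEM, Social Science, Humanities, Other), the Hard-subset average and the
# overall average, and print them.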
def cal_ceval(res):
    acc_sum_dict = dict()
    acc_norm_sum_dict = dict()
    cnt_dict = dict()
    acc_sum = 0.0
    cnt = 0
    hard_cnt = 0
    hard_acc_sum = 0.0
    for tt in res.keys():
        name = tt.split("-")[-1]
        acc_sum += float(res[tt])
        cnt += 1
        class_ = TASK_NAME_MAPPING[name][2]
        if class_ not in acc_sum_dict:
            acc_sum_dict[class_] = 0.0
            acc_norm_sum_dict[class_] = 0.0
            cnt_dict[class_] = 0.0
        if name in hard_list:
            hard_cnt += 1
            hard_acc_sum += float(res[tt])
        acc_sum_dict[class_] += float(res[tt])
        cnt_dict[class_] += 1
    print("\n\n\n")
    for k in ["STEM", "Social Science", "Humanities", "Other"]:
        if k in cnt_dict:
            print("%s acc: %.2f " % (k, acc_sum_dict[k] / cnt_dict[k]))
    if hard_cnt > 0:
        print("Hard acc:%.2f " % (hard_acc_sum / hard_cnt))
    print("AVERAGE acc:%.2f " % (acc_sum / cnt))


def main(args, evaluator):
    if args.eval_type == "validation":
        result = {}
        for subject_name in tqdm(TASK_NAME_MAPPING.keys()):
            val_file_path = os.path.join(
                args.eval_data_path, "val", f"{subject_name}_val.csv"
            )
            val_df = pd.read_csv(val_file_path)
            score, _ = evaluator.eval_subject(subject_name, val_df, args.eval_type)
            result[subject_name] = score
        cal_ceval(result)
    elif args.eval_type == "test":
        all_answers = {}
        for subject_name in tqdm(TASK_NAME_MAPPING.keys()):
            test_file_path = os.path.join(
                args.eval_data_path, "test", f"{subject_name}_test.csv"
            )
            test_df = pd.read_csv(test_file_path)
            _, answers = evaluator.eval_subject(subject_name, test_df, args.eval_type)
            all_answers[subject_name] = answers
        json.dump(all_answers, open('submission.json', 'w'), ensure_ascii=False, indent=4)
    else:
        invalidInputError(False,
                          "Invalid eval_type, please use validation or test.")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_family", type=str, default="llama")
    parser.add_argument("--model_path", type=str, default="meta-llama/Llama-2-7b-chat-hf")
    parser.add_argument("--eval_type", type=str, default="validation")
    parser.add_argument("--device", type=str, default="xpu")
    parser.add_argument("--eval_data_path", type=str, default="data")
    parser.add_argument("--qtype", type=str, default="sym_int4")

    args = parser.parse_args()

    if args.model_family == "llama":
        evaluator = LlamaEvaluator(
            choices=choices,
            model_path=args.model_path,
            device=args.device,
            qtype=args.qtype
        )
    elif args.model_family == "qwen":
        evaluator = QwenEvaluator(
            choices=choices,
            model_path=args.model_path,
            device=args.device,
            qtype=args.qtype
        )
    else:
        invalidInputError(
            False,
            "Invalid model_family, currently support llama and qwen only.")
    main(args, evaluator=evaluator)
31
python/llm/dev/benchmark/ceval/evaluators/evaluator.py
Normal file
@@ -0,0 +1,31 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

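# Common interface for the model-specific evaluators: subclasses such as
# LlamaEvaluator and QwenEvaluator implement prompt formatting, per-subject
# evaluation and answer extraction.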
class Evaluator:
    def __init__(self, choices, model_path, device, qtype):
        self.choices = choices
        self.model_path = model_path
        self.device = device
        self.qtype = qtype

    def format_example(self, line, **kwargs):
        pass

    def eval_subject(self, subject_name, test_df, eval_type, **kwargs):
        pass

    def extract_answer(self, response, row, **kwargs):
        pass
236
python/llm/dev/benchmark/ceval/evaluators/llama.py
Normal file
@@ -0,0 +1,236 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# refer to https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/blob/main/scripts/ceval/llama_evaluator.py

import re
import random
from tqdm import tqdm
import numpy as np
import torch
from transformers import LlamaTokenizer, GenerationConfig

from bigdl.llm.transformers import AutoModelForCausalLM
from evaluators.evaluator import Evaluator


DEFAULT_SYSTEM_PROMPT = """You are a helpful assistant. 你是一个乐于助人的助手。"""


class LlamaEvaluator(Evaluator):
    def __init__(self, choices, model_path="meta-llama/Llama-2-7b-chat-hf", device="xpu", qtype="sym_int4", k=-1):
        super(LlamaEvaluator, self).__init__(choices, model_path, device, qtype)
        # number of few-shot examples; -1 means "use the whole dev split"
        self.k = k
        self.tokenizer = LlamaTokenizer.from_pretrained(
            self.model_path,
            trust_remote_code=True
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_path,
            load_in_low_bit=self.qtype,
            optimize_model=True,
            use_cache=True,
            trust_remote_code=True
        ).eval().to(self.device)
        self.generation_config = GenerationConfig(
            temperature=0.2,
            top_k=40,
            top_p=0.9,
            do_sample=True,
            num_beams=1,
            repetition_penalty=1.1,
            max_new_tokens=20
        )
        self.sA_id = self.tokenizer.encode("A", add_special_tokens=False)[0]
        self.sB_id = self.tokenizer.encode("B", add_special_tokens=False)[0]
        self.sC_id = self.tokenizer.encode("C", add_special_tokens=False)[0]
        self.sD_id = self.tokenizer.encode("D", add_special_tokens=False)[0]
        self.A_id = self.tokenizer.encode(":A")[-1]
        self.B_id = self.tokenizer.encode(":B")[-1]
        self.C_id = self.tokenizer.encode(":C")[-1]
        self.D_id = self.tokenizer.encode(":D")[-1]

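    # Evaluate one subject: build an (optionally few-shot) prompt for each row and
    # generate with the low-bit model; with constrained_decoding the answer is read
    # from the first-token logits over the A/B/C/D tokens, otherwise it is parsed
    # from the generated text via extract_answer.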
    @torch.no_grad()
    def eval_subject(self, subject_name,
                     test_df,
                     eval_type="validation",
                     dev_df=None,
                     few_shot=False,
                     cot=False,
                     with_prompt=False,
                     constrained_decoding=False):
        all_answers = {}
        if constrained_decoding is True:
            self.generation_config.output_scores = True
            self.generation_config.return_dict_in_generate = True
            self.generation_config.max_new_tokens = 1
            self.generation_config.top_p = 1.0
            self.generation_config.top_k = 0

        correct_num = 0
        if few_shot:
            if with_prompt:
                history = self.generate_alpaca2_few_shot_prompt(subject_name, dev_df, cot=cot)
            else:
                history = self.generate_llama2_few_shot_prompt(subject_name, dev_df, cot=cot)
        else:
            history = ''
        answers = ['NA'] * len(test_df) if eval_type == "test" else list(test_df['answer'])
        for row_index, row in tqdm(test_df.iterrows(), total=len(test_df)):
            question = self.format_example(row, include_answer=False, cot=cot, with_prompt=with_prompt)
            instruction = question
            if with_prompt:
                prompt_template = (
                    "[INST] <<SYS>>\n"
                    "{system_prompt}\n"
                    "<</SYS>>\n\n"
                    "{instruction} [/INST]"
                )

                instruction = prompt_template.format_map({'instruction': instruction, 'system_prompt': DEFAULT_SYSTEM_PROMPT})
            instruction = history + instruction
            inputs = self.tokenizer(instruction, return_tensors="pt")
            generation_output = self.model.generate(
                input_ids=inputs["input_ids"].to(self.device),
                attention_mask=inputs['attention_mask'].to(self.device),
                eos_token_id=self.tokenizer.eos_token_id,
                pad_token_id=self.tokenizer.pad_token_id,
                generation_config=self.generation_config
            )

            _, length = inputs.input_ids.shape
            if constrained_decoding is True:
                logits = generation_output.scores[0][0]

                logits = logits.float().cpu().detach()
                choices1_logits = logits[[self.sA_id, self.sB_id, self.sC_id, self.sD_id]]
                choices2_logits = logits[[self.A_id, self.B_id, self.C_id, self.D_id]]
                choicesAll_logits = (choices1_logits + choices2_logits).numpy()
                assert not (np.any(np.isinf(choicesAll_logits)) or np.any(np.isnan(choicesAll_logits)))
                ans = {0: "A", 1: "B", 2: "C", 3: "D"}[np.argmax(choicesAll_logits)]
                response = self.tokenizer.decode([logits.argmax(-1).item()])
            else:
                response = self.tokenizer.decode(generation_output[0, length:], skip_special_tokens=True)
                ans, _ = self.extract_answer(response, row)
            if ans == answers[row_index]:
                correct_num += 1

            all_answers[str(row_index)] = ans

        correct_ratio = 100 * correct_num / len(answers)

        return correct_ratio, all_answers

    def format_example(self, line, include_answer=True, cot=False, with_prompt=False):
        example = line['question']
        for choice in self.choices:
            example += f'\n{choice}. {line[f"{choice}"]}'
        if include_answer:
            if cot:
                example += "\n答案:让我们一步一步思考,\n" + \
                           line["explanation"] + f"\n所以答案是{line['answer']}。\n\n"
            else:
                example += '\n答案:' + line["answer"] + '\n\n'
        else:
            if with_prompt is False:
                if cot:
                    example += "\n答案:让我们一步一步思考,\n1."
                else:
                    example += '\n答案:'
            else:
                if cot:
                    example += "\n答案是什么?让我们一步一步思考,\n1."
                else:
                    example += '\n答案:'
        return example

    def generate_llama2_few_shot_prompt(self, subject, dev_df, cot=False):
        prompt = f"以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n"
        k = self.k
        if self.k == -1:
            k = dev_df.shape[0]
        for i in range(k):
            prompt += self.format_example(
                dev_df.iloc[i, :],
                include_answer=True,
                cot=cot
            )
        return prompt

    def generate_alpaca2_few_shot_prompt(self, subject, dev_df, cot=False):
        prompt = f"以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n"
        prompt_template = (
            "[INST] <<SYS>>\n"
            "{system_prompt}\n"
            "<</SYS>>\n\n"
            "{instruction} [/INST]好的,我会结合{subject}相关知识回答"
        )

        prompt = prompt_template.format_map({'instruction': prompt, 'system_prompt': DEFAULT_SYSTEM_PROMPT, 'subject': subject})
        k = self.k
        if self.k == -1:
            k = dev_df.shape[0]
        for i in range(k):
            line = dev_df.iloc[i, :]
            q = line['question']
            for choice in self.choices:
                q += f'\n{choice}. {line[f"{choice}"]}'

            a = line['answer']
            prompt += "[INST] " + q + "\n答案:[/INST]" + a + "\n"
        return prompt

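    # Parse the free-form reply into one of A/B/C/D, trying progressively looser
    # patterns: an explicit "所以答案是X" statement, common answer phrasings, any
    # bare choice letter, then a literal match of one choice's text; if nothing
    # matches, fall back to a random choice.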
    def extract_answer(self, response, row):
        m = re.findall(r'所以答案是(.+?)。', response, re.M)
        if len(m) > 0 and m[-1] in self.choices:
            return m[-1], True
        answer_patterns = [
            r'([ABCD])是正确的',
            r'选项([ABCD])正确',
            r'答案为([ABCD])',
            r'答案是([ABCD])',
            r'答案([ABCD])',
            r'选择([ABCD])',
            r'答案:([ABCD])',
            r'选择答案([ABCD])'
        ]
        # RE extraction
        for answer_pattern in answer_patterns:
            m = re.search(answer_pattern, response, re.M)
            if m:
                answer = m.group(1)
                return answer, False
        # only containing one choice-character
        m = re.findall(r'[ABCD]', response, re.M)
        if len(m) >= 1:
            answer = m[0]
            return answer, False
        # only containing one choice-context
        choices_dict = {}
        pattern = ""
        for c in self.choices:
            choices_dict[str(row[f'{c}'])] = c
            pattern += re.escape(str(row[f'{c}'])) + "|"
        pattern = pattern[:-1]
        m = re.findall(pattern, response, re.M)
        print("w/ escape:", repr(pattern), response, (len(m) >= 1))
        if len(m) >= 1:
            answer = choices_dict[m[0]]
            return answer, False
        return random.choice('ABCD'), False
154
python/llm/dev/benchmark/ceval/evaluators/qwen.py
Normal file
@@ -0,0 +1,154 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# refer to https://github.com/QwenLM/Qwen/blob/main/eval/evaluate_chat_ceval.py

import re
from tqdm import tqdm
import torch
from thefuzz import process
from transformers import AutoTokenizer
from transformers.generation import GenerationConfig

from bigdl.llm.transformers import AutoModelForCausalLM
from evaluators.evaluator import Evaluator


class QwenEvaluator(Evaluator):
    def __init__(self, choices, model_path="Qwen/Qwen-7B-Chat", device="xpu", qtype="sym_int4"):
        super(QwenEvaluator, self).__init__(choices, model_path, device, qtype)
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_path,
            trust_remote_code=True
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_path,
            load_in_low_bit=self.qtype,
            optimize_model=True,
            use_cache=True,
            trust_remote_code=True
        ).eval().to(self.device)
        self.model.generation_config = GenerationConfig.from_pretrained(
            self.model_path,
            trust_remote_code=True
        )
        self.model.generation_config.do_sample = False  # use greedy decoding
        self.model.generation_config.repetition_penalty = 1.0  # disable repetition penalty

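    # Normalize the generated text before choice extraction: drop echoed question
    # fragments and replace each choice's full text with its letter.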
    def process_before_extraction(self, gen, question, choice_dict):
        question_split = question.rstrip("。").split("。")[-1].split("_")

        if len(question_split[0].strip()) > 4:
            gen = gen.replace(question_split[0], "答案是")
        if len(question_split[-1].strip()) > 4:
            gen = gen.replace(question_split[-1], "")

        for key, val in sorted(choice_dict.items(), key=lambda x: len(x[1]), reverse=True):
            gen = gen.replace(val.rstrip("。"), key)
        return gen

    def count_substr(self, gen, pattern):
        return len(re.findall(pattern, gen))

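    # Pick A/B/C/D out of the normalized generation by trying several regexes in
    # turn (explicit "答案是X"-style statements, "X正确"-style statements, a leading
    # letter, then any standalone letter); if none match, fuzzy-match the text
    # against the choice contents with thefuzz.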
    def extract_choice(self, gen, prompt, choice_list):
        res = re.search(
            r"(?:(?:选|选择|选定)[::]?\s*|(?:(?:答案|选项)(?![^ABCD]{0,10}?(?:不|非)[^ABCD]{0,10}?(?:是|选|为|:|:|】))[^ABCD]{0,10}?(?:是|选|为|:|:|】))[^ABCD]{0,10}?)(A|B|C|D)(?:选项)?(?:\)|。|\.|,|,|.|、|A|B|C|D|$|:|:|\)|))",
            gen,
        )

        if res is None:
            res = re.search(
                r"(A|B|C|D)(?:选?项)?(?![^ABCD]{0,4}?(?:不|非)[^ABCD]{0,4}?(?:正确|对[的,。:]|符合))[^ABCD]{0,4}?(?:正确|对[的,。:]|符合)",
                gen,
            )

        if res is None:
            res = re.search(r"^[\((]?(A|B|C|D)(?:。|\)|)|\.|,|,|.|:|:|$)", gen)

        if res is None:
            res = re.search(r"(?<![a-zA-Z])(A|B|C|D)(?![a-zA-Z=])", gen)

        if res is None:
            return self.choices[choice_list.index(process.extractOne(gen, choice_list)[0])]
        return res.group(1)

    def format_example(self, line):
        example = line["question"] + "\n\n"
        for choice in self.choices:
            example += f'{choice}. {line[f"{choice}"]}\n'
        return example

    def extract_answer(self, response, row):
        prompt = row["question"]
        gen = self.process_before_extraction(
            response, prompt, {choice: row[choice] for choice in self.choices}
        )
        if not isinstance(prompt, str):
            prompt = prompt[0]
        pred = self.extract_choice(gen, prompt, [row[choice] for choice in self.choices])
        return pred

    @torch.no_grad()
    def eval_subject(
        self,
        subject_name,
        test_df,
        eval_type="validation"  # "test", "validation"
    ):
        if eval_type == "validation":
            responses = []
            result = []
            score = []
            for _, row in tqdm(test_df.iterrows(), total=len(test_df)):
                question = self.format_example(row)

                response, _ = self.model.chat(
                    self.tokenizer,
                    question,
                    history=None,
                )
                pred = self.extract_answer(response, row)
                if "answer" in row:
                    correct = 1 if pred == row["answer"] else 0
                    score.append(correct)
                responses.append(response)
                result.append(pred)

            if score:
                correct_ratio = 100 * sum(score) / len(score)
            else:
                correct_ratio = 0

            return correct_ratio, None
        elif eval_type == "test":
            answers = {}
            for i, row in tqdm(test_df.iterrows(), total=len(test_df)):
                question = self.format_example(row)
                response, _ = self.model.chat(
                    self.tokenizer,
                    question,
                    history=None,
                )
                pred = self.extract_answer(response, row)
                answers[str(i)] = pred
            return None, answers
7
python/llm/dev/benchmark/ceval/run.sh
Normal file
@@ -0,0 +1,7 @@
python eval.py \
    --model_family llama \
    --model_path "path to model" \
    --eval_type validation \
    --device xpu \
    --eval_data_path data \
    --qtype sym_int4