From 511cbcf77383178615214f799b4b552556c05de2 Mon Sep 17 00:00:00 2001
From: Cengguang Zhang
Date: Tue, 16 Jan 2024 19:14:26 +0800
Subject: [PATCH] LLM: add Ceval benchmark test. (#9872)

* init ceval benchmark test.
* upload dataset.
* add other tests.
* add qwen evaluator.
* fix qwen evaluator style.
* fix qwen evaluator style.
* update qwen evaluator.
* add llama evaluator.
* update eval
* fix typo.
* fix
* fix typo.
* fix llama evaluator.
* fix bug.
* fix style.
* delete dataset.
* fix style.
* fix style.
* add README.md and fix typo.
* fix comments.
* remove run scripts
---
 python/llm/dev/benchmark/ceval/README.md     |  32 ++
 python/llm/dev/benchmark/ceval/eval.py       | 310 ++++++++++++++++++
 .../benchmark/ceval/evaluators/evaluator.py  |  31 ++
 .../dev/benchmark/ceval/evaluators/llama.py  | 236 +++++++++++++
 .../dev/benchmark/ceval/evaluators/qwen.py   | 154 +++++++++
 python/llm/dev/benchmark/ceval/run.sh        |   7 +
 6 files changed, 770 insertions(+)
 create mode 100644 python/llm/dev/benchmark/ceval/README.md
 create mode 100644 python/llm/dev/benchmark/ceval/eval.py
 create mode 100644 python/llm/dev/benchmark/ceval/evaluators/evaluator.py
 create mode 100644 python/llm/dev/benchmark/ceval/evaluators/llama.py
 create mode 100644 python/llm/dev/benchmark/ceval/evaluators/qwen.py
 create mode 100644 python/llm/dev/benchmark/ceval/run.sh

diff --git a/python/llm/dev/benchmark/ceval/README.md b/python/llm/dev/benchmark/ceval/README.md
new file mode 100644
index 00000000..608c170c
--- /dev/null
+++ b/python/llm/dev/benchmark/ceval/README.md
@@ -0,0 +1,32 @@
+## C-Eval Benchmark Test
+
+The C-Eval benchmark test allows users to evaluate models on the [C-Eval](https://cevalbenchmark.com) dataset, a multi-level, multi-discipline Chinese evaluation suite for foundation models. It consists of 13,948 multiple-choice questions spanning 52 diverse disciplines and four difficulty levels. Please check the [paper](https://arxiv.org/abs/2305.08322) and the [GitHub repo](https://github.com/hkust-nlp/ceval) for more information.
+
+### Download dataset
+Please download and unzip the dataset for evaluation.
+```shell
+wget https://huggingface.co/datasets/ceval/ceval-exam/resolve/main/ceval-exam.zip
+mkdir data
+mv ceval-exam.zip data
+cd data; unzip ceval-exam.zip
+```
+
+### Run
+You can run the evaluation with the following command.
+```shell
+bash run.sh
+```
++ `run.sh`
+```shell
+python eval.py \
+    --model_family llama \
+    --model_path "path to model" \
+    --eval_type validation \
+    --device xpu \
+    --eval_data_path data \
+    --qtype sym_int4
+```
+
+> **Note**
+>
+> `eval_type` supports two modes: `validation` runs on the validation dataset and outputs evaluation scores, while `test` runs on the test dataset and outputs a `submission.json` file that can be submitted to https://cevalbenchmark.com to obtain the evaluation score.
diff --git a/python/llm/dev/benchmark/ceval/eval.py b/python/llm/dev/benchmark/ceval/eval.py
new file mode 100644
index 00000000..ceff8b20
--- /dev/null
+++ b/python/llm/dev/benchmark/ceval/eval.py
@@ -0,0 +1,310 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import argparse
+import pandas as pd
+import torch
+import json
+from tqdm import tqdm
+
+from bigdl.llm.utils.common.log4Error import invalidInputError
+from evaluators.qwen import QwenEvaluator
+from evaluators.llama import LlamaEvaluator
+
+
+TASK_NAME_MAPPING = {
+    "computer_network": ["Computer Network", "\u8ba1\u7b97\u673a\u7f51\u7edc", "STEM"],
+    "operating_system": ["Operating System", "\u64cd\u4f5c\u7cfb\u7edf", "STEM"],
+    "computer_architecture": [
+        "Computer Architecture",
+        "\u8ba1\u7b97\u673a\u7ec4\u6210",
+        "STEM",
+    ],
+    "college_programming": ["College Programming", "\u5927\u5b66\u7f16\u7a0b", "STEM"],
+    "college_physics": ["College Physics", "\u5927\u5b66\u7269\u7406", "STEM"],
+    "college_chemistry": ["College Chemistry", "\u5927\u5b66\u5316\u5b66", "STEM"],
+    "advanced_mathematics": [
+        "Advanced Mathematics",
+        "\u9ad8\u7b49\u6570\u5b66",
+        "STEM",
+    ],
+    "probability_and_statistics": [
+        "Probability and Statistics",
+        "\u6982\u7387\u7edf\u8ba1",
+        "STEM",
+    ],
+    "discrete_mathematics": [
+        "Discrete Mathematics",
+        "\u79bb\u6563\u6570\u5b66",
+        "STEM",
+    ],
+    "electrical_engineer": [
+        "Electrical Engineer",
+        "\u6ce8\u518c\u7535\u6c14\u5de5\u7a0b\u5e08",
+        "STEM",
+    ],
+    "metrology_engineer": [
+        "Metrology Engineer",
+        "\u6ce8\u518c\u8ba1\u91cf\u5e08",
+        "STEM",
+    ],
+    "high_school_mathematics": [
+        "High School Mathematics",
+        "\u9ad8\u4e2d\u6570\u5b66",
+        "STEM",
+    ],
+    "high_school_physics": ["High School Physics", "\u9ad8\u4e2d\u7269\u7406", "STEM"],
+    "high_school_chemistry": [
+        "High School Chemistry",
+        "\u9ad8\u4e2d\u5316\u5b66",
+        "STEM",
+    ],
+    "high_school_biology": ["High School Biology", "\u9ad8\u4e2d\u751f\u7269", "STEM"],
+    "middle_school_mathematics": [
+        "Middle School Mathematics",
+        "\u521d\u4e2d\u6570\u5b66",
+        "STEM",
+    ],
+    "middle_school_biology": [
+        "Middle School Biology",
+        "\u521d\u4e2d\u751f\u7269",
+        "STEM",
+    ],
+    "middle_school_physics": [
+        "Middle School Physics",
+        "\u521d\u4e2d\u7269\u7406",
+        "STEM",
+    ],
+    "middle_school_chemistry": [
+        "Middle School Chemistry",
+        "\u521d\u4e2d\u5316\u5b66",
+        "STEM",
+    ],
+    "veterinary_medicine": ["Veterinary Medicine", "\u517d\u533b\u5b66", "STEM"],
+    "college_economics": [
+        "College Economics",
+        "\u5927\u5b66\u7ecf\u6d4e\u5b66",
+        "Social Science",
+    ],
+    "business_administration": [
+        "Business Administration",
+        "\u5de5\u5546\u7ba1\u7406",
+        "Social Science",
+    ],
+    "marxism": [
+        "Marxism",
+        "\u9a6c\u514b\u601d\u4e3b\u4e49\u57fa\u672c\u539f\u7406",
+        "Social Science",
+    ],
+    "mao_zedong_thought": [
+        "Mao Zedong Thought",
+        "\u6bdb\u6cfd\u4e1c\u601d\u60f3\u548c\u4e2d\u56fd\u7279\u8272\u793e\u4f1a\u4e3b\u4e49\u7406\u8bba\u4f53\u7cfb\u6982\u8bba",
+        "Social Science",
+    ],
+    "education_science": ["Education Science", "\u6559\u80b2\u5b66", "Social Science"],
+    "teacher_qualification": [
+        "Teacher Qualification",
+        "\u6559\u5e08\u8d44\u683c",
+        "Social Science",
+    ],
+    "high_school_politics": [
+        "High School Politics",
+        "\u9ad8\u4e2d\u653f\u6cbb",
+        "Social Science",
+    ],
+    "high_school_geography": [
+        "High School Geography",
+        "\u9ad8\u4e2d\u5730\u7406",
+        "Social Science",
+    ],
+    "middle_school_politics": [
+        "Middle School Politics",
+        "\u521d\u4e2d\u653f\u6cbb",
+        "Social Science",
+    ],
+    "middle_school_geography": [
+        "Middle School Geography",
+        "\u521d\u4e2d\u5730\u7406",
+        "Social Science",
+    ],
+    "modern_chinese_history": [
+        "Modern Chinese History",
+        "\u8fd1\u4ee3\u53f2\u7eb2\u8981",
+        "Humanities",
+    ],
+    "ideological_and_moral_cultivation": [
+        "Ideological and Moral Cultivation",
+        "\u601d\u60f3\u9053\u5fb7\u4fee\u517b\u4e0e\u6cd5\u5f8b\u57fa\u7840",
+        "Humanities",
+    ],
+    "logic": ["Logic", "\u903b\u8f91\u5b66", "Humanities"],
+    "law": ["Law", "\u6cd5\u5b66", "Humanities"],
+    "chinese_language_and_literature": [
+        "Chinese Language and Literature",
+        "\u4e2d\u56fd\u8bed\u8a00\u6587\u5b66",
+        "Humanities",
+    ],
+    "art_studies": ["Art Studies", "\u827a\u672f\u5b66", "Humanities"],
+    "professional_tour_guide": [
+        "Professional Tour Guide",
+        "\u5bfc\u6e38\u8d44\u683c",
+        "Humanities",
+    ],
+    "legal_professional": [
+        "Legal Professional",
+        "\u6cd5\u5f8b\u804c\u4e1a\u8d44\u683c",
+        "Humanities",
+    ],
+    "high_school_chinese": [
+        "High School Chinese",
+        "\u9ad8\u4e2d\u8bed\u6587",
+        "Humanities",
+    ],
+    "high_school_history": [
+        "High School History",
+        "\u9ad8\u4e2d\u5386\u53f2",
+        "Humanities",
+    ],
+    "middle_school_history": [
+        "Middle School History",
+        "\u521d\u4e2d\u5386\u53f2",
+        "Humanities",
+    ],
+    "civil_servant": ["Civil Servant", "\u516c\u52a1\u5458", "Other"],
+    "sports_science": ["Sports Science", "\u4f53\u80b2\u5b66", "Other"],
+    "plant_protection": ["Plant Protection", "\u690d\u7269\u4fdd\u62a4", "Other"],
+    "basic_medicine": ["Basic Medicine", "\u57fa\u7840\u533b\u5b66", "Other"],
+    "clinical_medicine": ["Clinical Medicine", "\u4e34\u5e8a\u533b\u5b66", "Other"],
+    "urban_and_rural_planner": [
+        "Urban and Rural Planner",
+        "\u6ce8\u518c\u57ce\u4e61\u89c4\u5212\u5e08",
+        "Other",
+    ],
+    "accountant": ["Accountant", "\u6ce8\u518c\u4f1a\u8ba1\u5e08", "Other"],
+    "fire_engineer": [
+        "Fire Engineer",
+        "\u6ce8\u518c\u6d88\u9632\u5de5\u7a0b\u5e08",
+        "Other",
+    ],
+    "environmental_impact_assessment_engineer": [
+        "Environmental Impact Assessment Engineer",
+        "\u73af\u5883\u5f71\u54cd\u8bc4\u4ef7\u5de5\u7a0b\u5e08",
+        "Other",
+    ],
+    "tax_accountant": ["Tax Accountant", "\u7a0e\u52a1\u5e08", "Other"],
+    "physician": ["Physician", "\u533b\u5e08\u8d44\u683c", "Other"],
+}
+hard_list = [
+    "advanced_mathematics",
+    "discrete_mathematics",
+    "probability_and_statistics",
+    "college_physics",
+    "college_chemistry",
+    "high_school_mathematics",
+    "high_school_physics",
+    "high_school_chemistry",
+]
+choices = ["A", "B", "C", "D"]
+
+
+def cal_ceval(res):
+    acc_sum_dict = dict()
+    acc_norm_sum_dict = dict()
+    cnt_dict = dict()
+    acc_sum = 0.0
+    cnt = 0
+    hard_cnt = 0
+    hard_acc_sum = 0.0
+    for tt in res.keys():
+        name = tt.split("-")[-1]
+        acc_sum += float(res[tt])
+        cnt += 1
+        class_ = TASK_NAME_MAPPING[name][2]
+        if class_ not in acc_sum_dict:
+            acc_sum_dict[class_] = 0.0
+            acc_norm_sum_dict[class_] = 0.0
+            cnt_dict[class_] = 0.0
+        if name in hard_list:
+            hard_cnt += 1
+            hard_acc_sum += float(res[tt])
+        acc_sum_dict[class_] += float(res[tt])
+        cnt_dict[class_] += 1
+    print("\n\n\n")
+    for k in ["STEM", "Social Science", "Humanities", "Other"]:
+        if k in cnt_dict:
+            print("%s acc: %.2f " % (k, acc_sum_dict[k] / cnt_dict[k]))
+    if hard_cnt > 0:
+        print("Hard acc:%.2f " % (hard_acc_sum / hard_cnt))
+    print("AVERAGE acc:%.2f " % (acc_sum / cnt))
+
+
+def main(args, evaluator):
+    if args.eval_type == "validation":
+        result = {}
+        for subject_name in tqdm(TASK_NAME_MAPPING.keys()):
+            val_file_path = os.path.join(
+                args.eval_data_path, "val", f"{subject_name}_val.csv"
+            )
+            val_df = pd.read_csv(val_file_path)
+            score, _ = evaluator.eval_subject(subject_name, val_df, args.eval_type)
+            result[subject_name] = score
+        cal_ceval(result)
+    elif args.eval_type == "test":
+        all_answers = {}
+        for subject_name in tqdm(TASK_NAME_MAPPING.keys()):
+            test_file_path = os.path.join(
+                args.eval_data_path, "test", f"{subject_name}_test.csv"
+            )
+            test_df = pd.read_csv(test_file_path)
+            _, answers = evaluator.eval_subject(subject_name, test_df, args.eval_type)
+            all_answers[subject_name] = answers
+        json.dump(all_answers, open('submission.json','w'), ensure_ascii=False, indent=4)
+    else:
+        invalidInputError(False,
+                          "Invalid eval_type, please use validation or test.")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model_family", type=str, default="llama")
+    parser.add_argument("--model_path", type=str, default="meta-llama/Llama-2-7b-chat-hf")
+    parser.add_argument("--eval_type", type=str, default="validation")
+    parser.add_argument("--device", type=str, default="xpu")
+    parser.add_argument("--eval_data_path", type=str, default="data")
+    parser.add_argument("--qtype", type=str, default="sym_int4")
+
+    args = parser.parse_args()
+
+    if args.model_family == "llama":
+        evaluator = LlamaEvaluator(
+            choices=choices,
+            model_path=args.model_path,
+            device=args.device,
+            qtype=args.qtype
+        )
+    elif args.model_family == "qwen":
+        evaluator = QwenEvaluator(
+            choices=choices,
+            model_path=args.model_path,
+            device=args.device,
+            qtype=args.qtype
+        )
+    else:
+        invalidInputError(
+            False,
+            "Invalid model_family, currently support llama and qwen only.")
+    main(args, evaluator=evaluator)
diff --git a/python/llm/dev/benchmark/ceval/evaluators/evaluator.py b/python/llm/dev/benchmark/ceval/evaluators/evaluator.py
new file mode 100644
index 00000000..d2d4abf3
--- /dev/null
+++ b/python/llm/dev/benchmark/ceval/evaluators/evaluator.py
@@ -0,0 +1,31 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+class Evaluator:
+    def __init__(self, choices, model_path, device, qtype):
+        self.choices = choices
+        self.model_path = model_path
+        self.device = device
+        self.qtype = qtype
+
+    def format_example(self, line, **kwargs):
+        pass
+
+    def eval_subject(self, subject_name, test_df, eval_type, **kwargs):
+        pass
+
+    def extract_answer(self, response, row, **kwargs):
+        pass
diff --git a/python/llm/dev/benchmark/ceval/evaluators/llama.py b/python/llm/dev/benchmark/ceval/evaluators/llama.py
new file mode 100644
index 00000000..ba1dfc3e
--- /dev/null
+++ b/python/llm/dev/benchmark/ceval/evaluators/llama.py
@@ -0,0 +1,236 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# refer to https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/blob/main/scripts/ceval/llama_evaluator.py
+
+import re
+import random
+from tqdm import tqdm
+import numpy as np
+import torch
+from transformers import LlamaTokenizer, GenerationConfig
+
+from bigdl.llm.transformers import AutoModelForCausalLM
+from evaluators.evaluator import Evaluator
+
+
+DEFAULT_SYSTEM_PROMPT = """You are a helpful assistant. 你是一个乐于助人的助手。"""
+
+
+class LlamaEvaluator(Evaluator):
+    def __init__(self, choices, model_path="meta-llama/Llama-2-7b-chat-hf", device="xpu", qtype="sym_int4"):
+        super(LlamaEvaluator, self).__init__(choices, model_path, device, qtype)
+        self.tokenizer = LlamaTokenizer.from_pretrained(
+            self.model_path,
+            trust_remote_code=True
+        )
+        self.model = AutoModelForCausalLM.from_pretrained(
+            self.model_path,
+            load_in_low_bit=self.qtype,
+            optimize_model=True,
+            use_cache=True,
+            trust_remote_code=True
+        ).eval().to(self.device)
+        self.generation_config = GenerationConfig(
+            temperature=0.2,
+            top_k=40,
+            top_p=0.9,
+            do_sample=True,
+            num_beams=1,
+            repetition_penalty=1.1,
+            max_new_tokens=20
+        )
+        self.sA_id = self.tokenizer.encode("A", add_special_tokens=False)[0]
+        self.sB_id = self.tokenizer.encode("B", add_special_tokens=False)[0]
+        self.sC_id = self.tokenizer.encode("C", add_special_tokens=False)[0]
+        self.sD_id = self.tokenizer.encode("D", add_special_tokens=False)[0]
+        self.A_id = self.tokenizer.encode(":A")[-1]
+        self.B_id = self.tokenizer.encode(":B")[-1]
+        self.C_id = self.tokenizer.encode(":C")[-1]
+        self.D_id = self.tokenizer.encode(":D")[-1]
+        self.k = -1  # number of few-shot examples for the few-shot prompt builders; -1 means use all rows of dev_df
+
+
+    @torch.no_grad()
+    def eval_subject(self, subject_name,
+                     test_df,
+                     eval_type="validation",
+                     dev_df=None,
+                     few_shot=False,
+                     cot=False,
+                     with_prompt=False,
+                     constrained_decoding=False):
+        all_answers = {}
+        if constrained_decoding is True:
+            self.generation_config.output_scores = True
+            self.generation_config.return_dict_in_generate = True
+            self.generation_config.max_new_tokens = 1
+            self.generation_config.top_p = 1.0
+            self.generation_config.top_k = 0
+
+        correct_num = 0
+        if few_shot:
+            if with_prompt:
+                history = self.generate_alpaca2_few_shot_prompt(subject_name, dev_df, cot=cot)
+            else:
+                history = self.generate_llama2_few_shot_prompt(subject_name, dev_df, cot=cot)
+        else:
+            history = ''
+        answers = ['NA'] * len(test_df) if eval_type == "test" else list(test_df['answer'])
+        for row_index, row in tqdm(test_df.iterrows(), total=len(test_df)):
+            question = self.format_example(row, include_answer=False, cot=cot, with_prompt=with_prompt)
+            instruction = question
+            if with_prompt:
+                prompt_template = (
+                    "[INST] <<SYS>>\n"
+                    "{system_prompt}\n"
+                    "<</SYS>>\n\n"
+                    "{instruction} [/INST]"
+                )
+
+                instruction = prompt_template.format_map({'instruction': instruction, 'system_prompt': DEFAULT_SYSTEM_PROMPT})
+            instruction = history + instruction
+            inputs = self.tokenizer(instruction, return_tensors="pt")
+            generation_output = self.model.generate(
+                input_ids = inputs["input_ids"].to(self.device),
+                attention_mask = inputs['attention_mask'].to(self.device),
+                eos_token_id=self.tokenizer.eos_token_id,
+                pad_token_id=self.tokenizer.pad_token_id,
+                generation_config = self.generation_config
+            )
+
+            _ , length = inputs.input_ids.shape
+            if constrained_decoding is True:
+                logits = generation_output.scores[0][0]
+
+                logits = logits.float().cpu().detach()
+                choices1_logits = logits[[self.sA_id, self.sB_id, self.sC_id, self.sD_id]]
+                choices2_logits = logits[[self.A_id, self.B_id, self.C_id, self.D_id]]
+                choicesAll_logits = (choices1_logits + choices2_logits).numpy()
+                assert not (np.any(np.isinf(choicesAll_logits)) or np.any(np.isnan(choicesAll_logits)))
+                ans = {0: "A", 1: "B", 2: "C", 3: "D"}[np.argmax(choicesAll_logits)]
+                response = self.tokenizer.decode([logits.argmax(-1).item()])
+            else:
+                response = self.tokenizer.decode(generation_output[0, length:], skip_special_tokens=True)
+                ans, _ = self.extract_answer(response, row)
+            if ans == answers[row_index]:
+                correct_num += 1
+
+            all_answers[str(row_index)] = ans
+
+        correct_ratio = 100*correct_num/len(answers)
+
+        return correct_ratio, all_answers
+
+
+    def format_example(self, line, include_answer=True, cot=False, with_prompt=False):
+        example = line['question']
+        for choice in self.choices:
+            example += f'\n{choice}. {line[f"{choice}"]}'
+        if include_answer:
+            if cot:
+                example += "\n答案:让我们一步一步思考,\n" + \
+                    line["explanation"] + f"\n所以答案是{line['answer']}。\n\n"
+            else:
+                example += '\n答案:' + line["answer"] + '\n\n'
+        else:
+            if with_prompt is False:
+                if cot:
+                    example += "\n答案:让我们一步一步思考,\n1."
+                else:
+                    example += '\n答案:'
+            else:
+                if cot:
+                    example += "\n答案是什么?让我们一步一步思考,\n1."
+                else:
+                    example += '\n答案:'
+        return example
+
+
+    def generate_llama2_few_shot_prompt(self, subject, dev_df, cot=False):
+        prompt = f"以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n"
+        k = self.k
+        if self.k == -1:
+            k = dev_df.shape[0]
+        for i in range(k):
+            prompt += self.format_example(
+                dev_df.iloc[i, :],
+                include_answer=True,
+                cot=cot
+            )
+        return prompt
+
+
+    def generate_alpaca2_few_shot_prompt(self, subject, dev_df, cot=False):
+        prompt = f"以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n"
+        prompt_template = (
+            "[INST] <<SYS>>\n"
+            "{system_prompt}\n"
+            "<</SYS>>\n\n"
+            "{instruction} [/INST]好的,我会结合{subject}相关知识回答"
+        )
+
+        prompt = prompt_template.format_map({'instruction': prompt, 'system_prompt': DEFAULT_SYSTEM_PROMPT, 'subject': subject})
+        k = self.k
+        if self.k == -1:
+            k = dev_df.shape[0]
+        for i in range(k):
+            line = dev_df.iloc[i, :]
+            q = line['question']
+            for choice in self.choices:
+                q += f'\n{choice}. {line[f"{choice}"]}'
+
+            a = line['answer']
+            prompt += "[INST] "+q+"\n答案:[/INST]"+a+"\n"
+        return prompt
+
+
+    def extract_answer(self, response, row):
+        m = re.findall(r'所以答案是(.+?)。', response, re.M)
+        if len(m) > 0 and m[-1] in self.choices:
+            return m[-1], True
+        answer_patterns = [
+            r'([ABCD])是正确的',
+            r'选项([ABCD])正确',
+            r'答案为([ABCD])',
+            r'答案是([ABCD])',
+            r'答案([ABCD])',
+            r'选择([ABCD])',
+            r'答案:([ABCD])',
+            r'选择答案([ABCD])'
+        ]
+        # RE extraction
+        for answer_pattern in answer_patterns:
+            m = re.search(answer_pattern, response, re.M)
+            if m:
+                answer = m.group(1)
+                return answer, False
+        # only containing one choice-character
+        m = re.findall(r'[ABCD]', response, re.M)
+        if len(m) >= 1:
+            answer = m[0]
+            return answer, False
+        # only containing one choice-context
+        choices_dict = {}
+        pattern = ""
+        for c in self.choices:
+            choices_dict[str(row[f'{c}'])] = c
+            pattern += re.escape(str(row[f'{c}']))+"|"
+        pattern = pattern[:-1]
+        m = re.findall(pattern, response, re.M)
+        print("w/ escape:", repr(pattern), response, (len(m)>=1))
+        if len(m) >= 1:
+            answer = choices_dict[m[0]]
+            return answer, False
+        return random.choice('ABCD'), False
diff --git a/python/llm/dev/benchmark/ceval/evaluators/qwen.py b/python/llm/dev/benchmark/ceval/evaluators/qwen.py
new file mode 100644
index 00000000..561bb6da
--- /dev/null
+++ b/python/llm/dev/benchmark/ceval/evaluators/qwen.py
@@ -0,0 +1,154 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# refer to https://github.com/QwenLM/Qwen/blob/main/eval/evaluate_chat_ceval.py
+
+import re
+from tqdm import tqdm
+import torch
+from thefuzz import process
+from transformers import AutoTokenizer
+from transformers.generation import GenerationConfig
+
+from bigdl.llm.transformers import AutoModelForCausalLM
+from evaluators.evaluator import Evaluator
+
+
+class QwenEvaluator(Evaluator):
+    def __init__(self, choices, model_path="Qwen/Qwen-7B-Chat", device="xpu", qtype="sym_int4"):
+        super(QwenEvaluator, self).__init__(choices, model_path, device, qtype)
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            self.model_path,
+            trust_remote_code=True
+        )
+        self.model = AutoModelForCausalLM.from_pretrained(
+            self.model_path,
+            load_in_low_bit=self.qtype,
+            optimize_model=True,
+            use_cache=True,
+            trust_remote_code=True
+        ).eval().to(self.device)
+        self.model.generation_config = GenerationConfig.from_pretrained(
+            self.model_path,
+            trust_remote_code=True
+        )
+        self.model.generation_config.do_sample = False  # use greedy decoding
+        self.model.generation_config.repetition_penalty = 1.0  # disable repetition penalty
+
+
+    def process_before_extraction(self, gen, question, choice_dict):
+
+        question_split = question.rstrip("。").split("。")[-1].split("_")
+
+        if len(question_split[0].strip()) > 4:
+            gen = gen.replace(question_split[0], "答案是")
+        if len(question_split[-1].strip()) > 4:
+            gen = gen.replace(question_split[-1], "")
+
+        for key, val in sorted(choice_dict.items(), key=lambda x: len(x[1]), reverse=True):
+            gen = gen.replace(val.rstrip("。"), key)
+        return gen
+
+
+    def count_substr(self, gen, pattern):
+        return len(re.findall(pattern, gen))
+
+
+    def extract_choice(self, gen, prompt, choice_list):
+        res = re.search(
+            r"(?:(?:选|选择|选定)[::]?\s*|(?:(?:答案|选项)(?![^ABCD]{0,10}?(?:不|非)[^ABCD]{0,10}?(?:是|选|为|:|:|】))[^ABCD]{0,10}?(?:是|选|为|:|:|】))[^ABCD]{0,10}?)(A|B|C|D)(?:选项)?(?:\)|。|\.|,|,|.|、|A|B|C|D|$|:|:|\)|))",
+            gen,
+        )
+
+        if res is None:
+            res = re.search(
+                r"(A|B|C|D)(?:选?项)?(?![^ABCD]{0,4}?(?:不|非)[^ABCD]{0,4}?(?:正确|对[的,。:]|符合))[^ABCD]{0,4}?(?:正确|对[的,。:]|符合)",
+                gen,
+            )
+
+        if res is None:
+            res = re.search(r"^[\((]?(A|B|C|D)(?:。|\)|)|\.|,|,|.|:|:|$)", gen)
+
+        if res is None:
+            res = re.search(r"(?