Add ChatGLM C-Eval Evaluator (#10095)
* Add ChatGLM ceval evaluator
* Modify ChatGLM Evaluator Reference
parent 5e9710cec4
commit 3832eb0ce0
2 changed files with 232 additions and 4 deletions
@@ -24,6 +24,7 @@ from tqdm import tqdm
 from bigdl.llm.utils.common.log4Error import invalidInputError
 from evaluators.qwen import QwenEvaluator
 from evaluators.llama import LlamaEvaluator
+from evaluators.chatglm import ChatGLMEvaluator


 TASK_NAME_MAPPING = {
@@ -280,7 +281,6 @@ def main(args, evaluator):

 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("--model_family", type=str, default="llama")
     parser.add_argument("--model_path", type=str, default="meta-llama/Llama-2-7b-chat-hf")
     parser.add_argument("--eval_type", type=str, default="validation")
     parser.add_argument("--device", type=str, default="xpu")
@@ -289,22 +289,39 @@ if __name__ == "__main__":
     args = parser.parse_args()

-    if args.model_family == "llama":
+    # decide the model family
+    model_families = ['llama', 'qwen', 'chatglm']
+
+    model_family = None
+    for family in model_families:
+        if family in args.model_path.lower():
+            model_family = family
+
+    assert model_family is not None, f"Model {args.model_path}'s model family is not implemented"
+
+    if model_family == "llama":
         evaluator = LlamaEvaluator(
             choices=choices,
             model_path=args.model_path,
             device=args.device,
             qtype=args.qtype
         )
-    elif args.model_family == "qwen":
+    elif model_family == "qwen":
         evaluator = QwenEvaluator(
             choices=choices,
             model_path=args.model_path,
             device=args.device,
             qtype=args.qtype
         )
+    elif model_family == "chatglm":
+        evaluator = ChatGLMEvaluator(
+            choices=choices,
+            model_path=args.model_path,
+            device=args.device,
+            qtype=args.qtype
+        )
     else:
         invalidInputError(
             False,
-            "Invalid model_family, currently support llama and qwen only.")
+            "Invalid model_family, currently support llama, qwen, and chatglm only.")

     main(args, evaluator=evaluator)
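Note: the model family is now inferred from --model_path rather than taken from the removed --model_family flag. A minimal standalone sketch of the detection loop (the example paths are illustrative, not from the commit):

# illustration of the substring-based detection added above;
# the model paths below are hypothetical examples
model_families = ['llama', 'qwen', 'chatglm']

def detect_family(model_path):
    detected = None
    for family in model_families:
        if family in model_path.lower():
            detected = family
    return detected

print(detect_family("meta-llama/Llama-2-7b-chat-hf"))  # llama
print(detect_family("THUDM/chatglm2-6b"))              # chatglm
print(detect_family("gpt2"))                           # None -> the assert fires

Because the loop keeps the last match, a path containing several family names resolves to the one listed last in model_families.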
python/llm/dev/benchmark/ceval/evaluators/chatglm.py (new file, 211 lines)
@@ -0,0 +1,211 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# refer to https://github.com/THUDM/ChatGLM2-6B/blob/main/evaluation/evaluate_ceval.py

import re
import torch
from tqdm import tqdm
from thefuzz import process
from transformers import AutoTokenizer

from evaluators.evaluator import Evaluator
from bigdl.llm.transformers import AutoModelForCausalLM
from transformers.generation.utils import LogitsProcessorList
from transformers.generation.logits_process import LogitsProcessor


class InvalidScoreLogitsProcessor(LogitsProcessor):
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
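        # ChatGLM can occasionally emit NaN/Inf logits; when that happens, zero
        # the scores and force a fixed token (id 5) so decoding stays valid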
        if torch.isnan(scores).any() or torch.isinf(scores).any():
            scores.zero_()
            scores[..., 5] = 5e4
        return scores


class ChatGLMEvaluator(Evaluator):
    def __init__(self, choices, model_path="THUDM/chatglm-6b", device="xpu", qtype="sym_int4"):
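        # load the tokenizer and a low-bit (default sym_int4) ChatGLM model
        # through bigdl-llm's transformers-style API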
        super(ChatGLMEvaluator, self).__init__(choices, model_path, device, qtype)
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_path,
            trust_remote_code=True
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_path,
            load_in_low_bit=self.qtype,
            optimize_model=True,
            use_cache=True,
            trust_remote_code=True
        ).eval().to(self.device)

    def generate_few_shot_prompt(self, subject, dev_df, cot=False):
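        # build k few-shot (question, answer) pairs from the dev split;
        # eval_subject passes them to model.chat(...) as conversation history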
        message = []
        k = self.k
        if self.k == -1:
            k = dev_df.shape[0]
        message.append(self.format_example(dev_df.iloc[0, :], cot=cot, add_prompt=f"以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n"))
        for i in range(1, k):
            message.append(self.format_example(dev_df.iloc[i, :], cot=cot))
        return message

    def format_example(self, line, include_answer=True, cot=False, add_prompt=''):
        example = add_prompt + line['question']
        for choice in self.choices:
            example += f'\n{choice}. {line[f"{choice}"]}'
        example += '\n答案:'
        if include_answer:
            if cot:
                ans = "让我们一步一步思考,\n" + line["explanation"] + f"\n所以答案是{line['answer']}。"
            else:
                ans = line["answer"]
            m = (example, ans)
            return m
        return example

    def extract_cot_answer(self, line, gen_ans):
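        # tries, in order: the CoT tail "所以答案是X。", the regex answer
        # patterns below, a lone A/B/C/D character, and finally the literal
        # choice text; returns ('-', False) when nothing matches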
        m = re.findall(r'所以答案是(.+?)。', gen_ans, re.M)
        if len(m) > 0 and m[-1] in self.choices:
            return m[-1], True
        answer_patterns = [
            r'([ABCD])是正确的',
            r'选项([ABCD])正确',
            r'答案为([ABCD])',
            r'答案是([ABCD])',
            r'答案([ABCD])',
            r'选择([ABCD])',
            r'答案:([ABCD])',
            r'选择答案([ABCD])'
        ]
        # RE extraction
        for answer_pattern in answer_patterns:
            m = re.search(answer_pattern, gen_ans, re.M)
            if m:
                answer = m.group(1)
                return answer, False
        # only containing one choice-character
        m = re.findall(r'[ABCD]', gen_ans, re.M)
        if len(m) == 1:
            answer = m[0]
            return answer, False
        answer_word_counter = 0
        # only containing one choice-context
        for c in self.choices:
            if str(line[f'{c}']) in gen_ans:
                answer = c
                answer_word_counter += 1
        if answer_word_counter == 1:
            return answer, False
        return '-', False

    def build_prompt(self, text):
        return "[Round {}]\n\n问:{}\n\n答:".format(1, text)

    def generate_dist(self, model, tokenizer, query, history, max_length=2048,
                      do_sample=False, logits_processor=None):
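        # two-pass scoring, following the referenced ChatGLM2 C-Eval script:
        # 1) generate a free-form answer, then 2) append the extraction prompt
        #    and take the argmax over the next-token logits of the choice tokens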
        if history is None:
            history = []
        if logits_processor is None:
            logits_processor = LogitsProcessorList()
        logits_processor.append(InvalidScoreLogitsProcessor())

        if not history:
            prompt = query
        else:
            prompt = ""
            for i, (old_query, response) in enumerate(history):
                prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response)
            prompt += "[Round {}]\n问:{}\n答:".format(len(history), query)

        # first round prompt
        inputs = tokenizer([prompt], padding=True, return_tensors="pt",
                           truncation=True, max_length=max_length).to(model.device)

        # first round generation
        outputs = model.generate(**inputs, do_sample=do_sample, max_new_tokens=512)

        # organize intermediate_outputs
        intermediate_outputs = []
        for idx in range(len(outputs)):
            output = outputs.tolist()[idx][len(inputs["input_ids"][idx]):]
            response = tokenizer.decode(output)
            intermediate_outputs.append(response)

        # prepare second round prompt
        extraction_prompt = '综上所述,ABCD中正确的选项是:'
        answer_texts = [query + intermediate + "\n" + extraction_prompt for intermediate in intermediate_outputs]
        input_tokens = [self.build_prompt(answer_text) for answer_text in answer_texts]
        inputs = tokenizer(input_tokens, padding=True, return_tensors="pt", truncation=True, max_length=2048).to(model.device)

        # second round generation
        outputs = model(**inputs, return_last_logit=True)

        logits = outputs.logits[:, -1]
        choice_tokens = [tokenizer.encode(choice, add_special_tokens=False)[0] for choice in self.choices]
        logits = logits[:, choice_tokens]
        preds = logits.argmax(dim=-1)

        return self.choices[preds]

    @torch.no_grad()
    def eval_subject(
        self,
        subject_name,
        test_df,
        eval_type="validation",  # "test", "validation"
        dev_df=None,
        few_shot=False,
        cot=False,
    ):
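        # validation: score predictions against labels and return accuracy;
        # test: collect per-question predictions (no labels available)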
        if eval_type == "validation":
            correct_num = 0

            if few_shot:
                history = self.generate_few_shot_prompt(subject_name, dev_df, cot=cot)
            else:
                history = []

            answers = list(test_df['answer'])

            for row_index, row in tqdm(test_df.iterrows(), total=len(test_df)):
                question = self.format_example(row, include_answer=False, cot=cot)

                if few_shot:
                    response, _ = self.model.chat(self.tokenizer, question, do_sample=False, history=history)
                    response = response.strip()
                    # For ChatGLM, we use answer extraction in answer-only mode too.
                    ans, direct_extract = self.extract_cot_answer(row, response)
                else:  # zero-shot by extracting answer from distribution
                    ans = self.generate_dist(self.model, self.tokenizer, question, do_sample=False, max_length=2048, history=history)

                if ans == answers[row_index]:
                    correct_num += 1

            correct_ratio = 100 * correct_num / len(answers)

            return correct_ratio, None
        elif eval_type == "test":
            answers = {}
            for i, row in tqdm(test_df.iterrows(), total=len(test_df)):
                question = self.format_example(row)
                response, _ = self.model.chat(
                    self.tokenizer,
                    question,
                    history=None,
                )
                pred = self.extract_answer(response, row)
                answers[str(i)] = pred
            return None, answers
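For reference, a minimal sketch of driving the new evaluator directly. It assumes C-Eval-style validation data with question, A–D, and answer columns (the fields format_example reads); the model id and CSV path are examples, not part of the commit:

# usage sketch only; assumes a local C-Eval validation CSV
import pandas as pd
from evaluators.chatglm import ChatGLMEvaluator

choices = ["A", "B", "C", "D"]
evaluator = ChatGLMEvaluator(choices=choices,
                             model_path="THUDM/chatglm2-6b",  # example model
                             device="xpu",
                             qtype="sym_int4")

val_df = pd.read_csv("data/val/computer_network_val.csv")     # example path
acc, _ = evaluator.eval_subject("computer_network", val_df,
                                eval_type="validation",
                                few_shot=False, cot=False)    # zero-shot path
print(f"accuracy: {acc:.1f}%")

With few_shot=False this exercises generate_dist (logit comparison over A/B/C/D); with few_shot=True it instead chats with the few-shot history and extracts the answer from the generated text.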