Add ChatGLM C-Eval Evaluator (#10095)
* Add ChatGLM ceval evaluator
* Modify ChatGLM Evaluator Reference
parent 5e9710cec4
commit 3832eb0ce0

2 changed files with 232 additions and 4 deletions
@@ -24,6 +24,7 @@ from tqdm import tqdm
 from bigdl.llm.utils.common.log4Error import invalidInputError
 from evaluators.qwen import QwenEvaluator
 from evaluators.llama import LlamaEvaluator
+from evaluators.chatglm import ChatGLMEvaluator


 TASK_NAME_MAPPING = {
@@ -280,7 +281,6 @@ def main(args, evaluator):

 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("--model_family", type=str, default="llama")
     parser.add_argument("--model_path", type=str, default="meta-llama/Llama-2-7b-chat-hf")
     parser.add_argument("--eval_type", type=str, default="validation")
     parser.add_argument("--device", type=str, default="xpu")
@@ -289,22 +289,39 @@ if __name__ == "__main__":

     args = parser.parse_args()

-    if args.model_family == "llama":
+    # decide the model family
+    model_families = ['llama', 'qwen', 'chatglm']
+
+    model_family = None
+    for family in model_families:
+        if family in args.model_path.lower():
+            model_family = family
+
+    assert model_family is not None, f"Model {args.model_path}'s model family is not implemented"
+
+    if model_family == "llama":
         evaluator = LlamaEvaluator(
             choices=choices,
             model_path=args.model_path,
             device=args.device,
             qtype=args.qtype
         )
-    elif args.model_family == "qwen":
+    elif model_family == "qwen":
         evaluator = QwenEvaluator(
             choices=choices,
             model_path=args.model_path,
             device=args.device,
             qtype=args.qtype
         )
+    elif model_family == "chatglm":
+        evaluator = ChatGLMEvaluator(
+            choices=choices,
+            model_path=args.model_path,
+            device=args.device,
+            qtype=args.qtype
+        )
     else:
         invalidInputError(
             False,
-            "Invalid model_family, currently support llama and qwen only.")
+            "Invalid model_family, currently support llama, qwen, and chatglm only.")
     main(args, evaluator=evaluator)
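
Note: after this change the model family is no longer passed via --model_family; it is inferred from --model_path by a case-insensitive substring match against ['llama', 'qwen', 'chatglm']. A minimal standalone sketch of that detection logic (the model path below is only an illustrative example, not part of the diff):

    # Sketch of the family detection added above; "THUDM/chatglm2-6b" is an example path.
    model_path = "THUDM/chatglm2-6b"
    model_families = ['llama', 'qwen', 'chatglm']

    model_family = None
    for family in model_families:
        if family in model_path.lower():
            model_family = family

    assert model_family is not None, f"Model {model_path}'s model family is not implemented"
    print(model_family)  # prints "chatglm"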

python/llm/dev/benchmark/ceval/evaluators/chatglm.py (new file, 211 lines added)

@@ -0,0 +1,211 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# refer to https://github.com/THUDM/ChatGLM2-6B/blob/main/evaluation/evaluate_ceval.py

import re
import torch
from tqdm import tqdm
from thefuzz import process
from transformers import AutoTokenizer

from evaluators.evaluator import Evaluator
from bigdl.llm.transformers import AutoModelForCausalLM
from transformers.generation.utils import LogitsProcessorList
from transformers.generation.logits_process import LogitsProcessor

class InvalidScoreLogitsProcessor(LogitsProcessor):
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        if torch.isnan(scores).any() or torch.isinf(scores).any():
            scores.zero_()
            scores[..., 5] = 5e4
        return scores


class ChatGLMEvaluator(Evaluator):
    def __init__(self, choices, model_path="THUDM/chatglm-6b", device="xpu", qtype="sym_int4"):
        super(ChatGLMEvaluator, self).__init__(choices, model_path, device, qtype)
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_path,
            trust_remote_code=True
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_path,
            load_in_low_bit=self.qtype,
            optimize_model=True,
            use_cache=True,
            trust_remote_code=True
        ).eval().to(self.device)

    def generate_few_shot_prompt(self, subject, dev_df, cot=False):
        message = []
        k = self.k
        if self.k == -1:
            k = dev_df.shape[0]
        message.append(self.format_example(dev_df.iloc[0, :], cot=cot, add_prompt=f"以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n"))
        for i in range(1, k):
            message.append(self.format_example(dev_df.iloc[i, :], cot=cot))
        return message

    def format_example(self, line, include_answer=True, cot=False, add_prompt=''):
        example = add_prompt + line['question']
        # print(example)
        for choice in self.choices:
            example += f'\n{choice}. {line[f"{choice}"]}'
        example += '\n答案:'
        if include_answer:
            if cot:
                ans = "让我们一步一步思考,\n" + line["explanation"] + f"\n所以答案是{line['answer']}。"
            else:
                ans = line["answer"]
            m = (example, ans)
            return m
        return example

    def extract_cot_answer(self, line, gen_ans):
        m = re.findall(r'所以答案是(.+?)。', gen_ans, re.M)
        if len(m) > 0 and m[-1] in self.choices:
            return m[-1], True
        answer_patterns = [
            r'([ABCD])是正确的',
            r'选项([ABCD])正确',
            r'答案为([ABCD])',
            r'答案是([ABCD])',
            r'答案([ABCD])',
            r'选择([ABCD])',
            r'答案:([ABCD])',
            r'选择答案([ABCD])'
        ]
        # RE extraction
        for answer_pattern in answer_patterns:
            m = re.search(answer_pattern, gen_ans, re.M)
            if m:
                answer = m.group(1)
                return answer, False
        # only containing one choice-character
        m = re.findall(r'[ABCD]', gen_ans, re.M)
        if len(m) == 1:
            answer = m[0]
            return answer, False
        answer_word_counter = 0
        # only containing one choice-context
        for c in self.choices:
            if str(line[f'{c}']) in gen_ans:
                answer = c
                answer_word_counter += 1
        if answer_word_counter == 1:
            return answer, False
        return '-', False

    def build_prompt(self, text):
        return "[Round {}]\n\n问:{}\n\n答:".format(1, text)

    def generate_dist(self, model, tokenizer, query, history, max_length=2048,
                      do_sample=False, logits_processor=None):

        if history is None:
            history = []
        if logits_processor is None:
            logits_processor = LogitsProcessorList()
        logits_processor.append(InvalidScoreLogitsProcessor())

        if not history:
            prompt = query
        else:
            prompt = ""
            for i, (old_query, response) in enumerate(history):
                prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response)
            prompt += "[Round {}]\n问:{}\n答:".format(len(history), query)

        # first round prompt
        inputs = tokenizer([prompt], padding=True, return_tensors="pt",
                           truncation=True, max_length=max_length).to(model.device)

        # first round generation
        outputs = model.generate(**inputs, do_sample=do_sample, max_new_tokens=512)

        # organize intermediate_outputs
        intermediate_outputs = []
        for idx in range(len(outputs)):
            output = outputs.tolist()[idx][len(inputs["input_ids"][idx]):]
            response = tokenizer.decode(output)
            intermediate_outputs.append(response)

        # prepare second round prompt
        extraction_prompt = '综上所述,ABCD中正确的选项是:'
        answer_texts = [query + intermediate + "\n" + extraction_prompt for intermediate in intermediate_outputs]
        input_tokens = [self.build_prompt(answer_text) for answer_text in answer_texts]
        inputs = tokenizer(input_tokens, padding=True, return_tensors="pt", truncation=True, max_length=2048).to(model.device)

        # second round generation
        outputs = model(**inputs, return_last_logit=True)

        logits = outputs.logits[:, -1]
        choice_tokens = [tokenizer.encode(choice, add_special_tokens=False)[0] for choice in self.choices]
        logits = logits[:, choice_tokens]
        preds = logits.argmax(dim=-1)

        return self.choices[preds]

    @torch.no_grad()
    def eval_subject(
        self,
        subject_name,
        test_df,
        eval_type="validation", # "test","validation",
        dev_df=None,
        few_shot=False,
        cot=False,
    ):
        if eval_type == "validation":
            correct_num = 0

            if few_shot:
                history = self.generate_few_shot_prompt(subject_name, dev_df, cot=cot)
            else:
                history = []

            answers = list(test_df['answer'])

            for row_index, row in tqdm(test_df.iterrows(), total=len(test_df)):
                question = self.format_example(row, include_answer=False, cot=cot)

                if few_shot:
                    response, _ = self.model.chat(self.tokenizer, question, do_sample=False, history=history)
                    response = response.strip()
                    # For ChatGLM, we use answer extraction in answer-only mode too.
                    ans, direct_extract = self.extract_cot_answer(row, response)
                else:   # zero-shot by extracting answer from distribution
                    ans = self.generate_dist(self.model, self.tokenizer, question, do_sample=False, max_length=2048, history=history)

                if ans == answers[row_index]:
                    correct_num += 1

            correct_ratio = 100*correct_num/len(answers)

            return correct_ratio, None
        elif eval_type == "test":
            answers = {}
            for i, row in tqdm(test_df.iterrows(), total=len(test_df)):
                question = self.format_example(row)
                response, _ = self.model.chat(
                    self.tokenizer,
                    question,
                    history=None,
                )
                pred = self.extract_answer(response, row)
                answers[str(i)] = pred
            return None, answers
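
The answer-extraction cascade in extract_cot_answer is the piece most worth sanity-checking in isolation: it first looks for an explicit "所以答案是X。" conclusion, then for a list of common answer phrasings, then for a lone A/B/C/D character, and finally for the literal choice text. A minimal self-contained sketch of the first two stages (the sample response string is invented for illustration only):

    import re

    choices = ["A", "B", "C", "D"]
    gen_ans = "让我们一步一步思考,植物通过光合作用固定碳,所以答案是B。"   # invented sample generation

    # Stage 1: explicit chain-of-thought conclusion "所以答案是X。"
    m = re.findall(r'所以答案是(.+?)。', gen_ans, re.M)
    if m and m[-1] in choices:
        print(m[-1], True)        # prints "B True" (directly extracted)
    else:
        # Stage 2: fall back to common answer phrasings such as "答案是B"
        for pattern in [r'答案为([ABCD])', r'答案是([ABCD])', r'选择([ABCD])']:
            hit = re.search(pattern, gen_ans, re.M)
            if hit:
                print(hit.group(1), False)
                break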
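
In the zero-shot validation path (few_shot=False), generate_dist never parses the generated text: it runs one free-form generation round, appends the extraction prompt "综上所述,ABCD中正确的选项是:", then reads the logits at the last position, keeps only the token ids of the four choices, and returns the argmax. A toy sketch of that final selection step, with made-up token ids and logits standing in for tokenizer.encode(choice)[0] and the model's forward pass:

    import torch

    choices = ["A", "B", "C", "D"]
    choice_tokens = [316, 317, 318, 319]       # made-up ids for the "A".."D" tokens
    logits = torch.zeros(1, 1000)              # made-up last-position logits, batch size 1
    logits[0, 318] = 3.2                       # pretend the "C" token scored highest

    choice_logits = logits[:, choice_tokens]   # shape (1, 4): one score per choice
    pred = choice_logits.argmax(dim=-1).item() # index into `choices`
    print(choices[pred])                       # prints "C"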