ipex-llm/python/llm/dev/benchmark/LongBench/metrics.py
Xu, Shuo ee33b93464
Longbench: NV code to ipex-llm (#11662)
* add nv longbench

* LongBench: NV code to ipex-llm

* ammend

* add more models support

* ammend

* optimize LongBench's user experience

* ammend

* ammend

* fix typo

* ammend

* remove cuda related information & add a readme

* add license to python scripts & polish the readme

* ammend

* ammend

---------

Co-authored-by: cyita <yitastudy@gmail.com>
Co-authored-by: ATMxsp01 <shou.xu@intel.com>
Co-authored-by: leonardozcm <leonardo1997zcm@gmail.com>
2024-09-18 15:55:14 +08:00

165 lines
5.8 KiB
Python

#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This file is adapted from
# https://github.com/THUDM/LongBench/blob/main/metrics.py
# and
# https://github.com/FasterDecoding/SnapKV/blob/main/experiments/LongBench/metrics.py
import re
import string
import jieba
from fuzzywuzzy import fuzz
import difflib
from typing import List
from collections import Counter
from rouge import Rouge
def normalize_answer(s):
    """Canonicalise an English answer string before token comparison.

    The steps run in this order: lowercase the text, delete every ASCII
    punctuation character, replace the standalone articles "a", "an" and
    "the" with a space, and finally collapse all whitespace runs into
    single spaces (also trimming both ends).

    Args:
        s: The raw answer text.

    Returns:
        The normalised string.
    """
    strip_punctuation = str.maketrans("", "", string.punctuation)
    lowered = s.lower()
    without_punctuation = lowered.translate(strip_punctuation)
    # Articles are replaced by a space (not removed) so neighbours stay apart.
    without_articles = re.sub(r"\b(a|an|the)\b", " ", without_punctuation)
    return " ".join(without_articles.split())
# Full-width / CJK punctuation stripped from Chinese answers. These must stay
# as the full-width code points: if they are folded to their ASCII look-alikes
# the literal is cut short at the first '"' and the rest becomes a comment.
_ZH_PUNCTUATION = "！？｡。＂＃＄％＆＇（）＊＋，－／：；＜＝＞＠［＼］＾＿｀｛｜｝～｟｠｢｣､、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏."
# Built once at import time; the function is called once per token elsewhere.
_ZH_ALL_PUNCTUATION = frozenset(string.punctuation + _ZH_PUNCTUATION)


def normalize_zh_answer(s):
    """Lower text and remove punctuation and all whitespace.

    Both ASCII punctuation (``string.punctuation``) and the full-width / CJK
    punctuation listed in ``_ZH_PUNCTUATION`` are removed. Unlike the English
    normaliser, whitespace is removed entirely rather than collapsed.

    Args:
        s: The raw answer text.

    Returns:
        The normalised string with no punctuation and no whitespace.
    """
    lowered = s.lower()
    without_punctuation = "".join(
        ch for ch in lowered if ch not in _ZH_ALL_PUNCTUATION
    )
    return "".join(without_punctuation.split())
def count_score(prediction, ground_truth, **kwargs):
    """Score a counting answer by the share of predicted numbers that match.

    Every run of digits in ``prediction`` is extracted; the score is the
    fraction of those runs whose text equals ``str(ground_truth)``.

    Args:
        prediction: Model output text.
        ground_truth: Expected count (compared via its string form).
        **kwargs: Ignored; accepted for a uniform metric signature.

    Returns:
        A float in [0, 1]; 0.0 when the prediction contains no digits.
    """
    found = re.findall(r"\d+", prediction)
    if not found:
        return 0.0
    expected = str(ground_truth)
    hits = sum(1 for candidate in found if candidate == expected)
    return float(hits / len(found))
def retrieval_score(prediction, ground_truth, **kwargs):
    """Score a passage-retrieval answer against a "Paragraph N" reference.

    The target id is the first number following "Paragraph " in
    ``ground_truth``. The score is the fraction of digit runs in
    ``prediction`` that equal that id.

    Args:
        prediction: Model output text.
        ground_truth: Reference text containing "Paragraph <number>".
        **kwargs: Ignored; accepted for a uniform metric signature.

    Returns:
        A float in [0, 1]; 0.0 when the prediction contains no digits.

    Raises:
        IndexError: If ``ground_truth`` contains no "Paragraph <number>".
    """
    target_id = re.findall(r'Paragraph (\d+)', ground_truth)[0]
    candidates = re.findall(r"\d+", prediction)
    if not candidates:
        return 0.0
    hits = candidates.count(str(target_id))
    return float(hits / len(candidates))
def retrieval_zh_score(prediction, ground_truth, **kwargs):
    """Score a Chinese passage-retrieval answer against a "段落N" reference.

    The target id is the first number following "段落" in ``ground_truth``.
    The score is the fraction of digit runs in ``prediction`` that equal
    that id.

    Args:
        prediction: Model output text.
        ground_truth: Reference text containing "段落<number>".
        **kwargs: Ignored; accepted for a uniform metric signature.

    Returns:
        A float in [0, 1]; 0.0 when the prediction contains no digits.

    Raises:
        IndexError: If ``ground_truth`` contains no "段落<number>".
    """
    target_id = re.findall(r'段落(\d+)', ground_truth)[0]
    candidates = re.findall(r"\d+", prediction)
    if not candidates:
        return 0.0
    hits = candidates.count(str(target_id))
    return float(hits / len(candidates))
def code_sim_score(prediction, ground_truth, **kwargs):
    """Fuzzy-match the first plain code line of a prediction to the reference.

    Leading newlines are dropped, then the first line containing none of
    a backtick, ``#`` or ``//`` is selected (an empty string if every line
    contains one of them). The score is ``fuzz.ratio`` of that line against
    ``ground_truth``, divided by 100.

    Args:
        prediction: Model output text, possibly multi-line.
        ground_truth: Reference code line.
        **kwargs: Ignored; accepted for a uniform metric signature.

    Returns:
        ``fuzz.ratio(...) / 100`` for the selected line.
    """
    skip_markers = ('`', '#', '//')
    candidate_lines = prediction.lstrip('\n').split('\n')
    first_code_line = next(
        (
            line
            for line in candidate_lines
            if not any(marker in line for marker in skip_markers)
        ),
        "",
    )
    return fuzz.ratio(first_code_line, ground_truth) / 100
def classification_score(prediction, ground_truth, **kwargs):
    """Score a classification answer by exact class-name mention.

    Every class name that occurs as a substring of ``prediction`` counts as
    a match. Matches that are merely a proper substring of ``ground_truth``
    are discarded, since they would be found inside the correct label
    itself. If ``ground_truth`` is among the remaining matches the score is
    ``1 / number_of_matches``; otherwise it is 0.0.

    Args:
        prediction: Model output text.
        ground_truth: The correct class name.
        **kwargs: Must contain ``all_classes``, an iterable of class names.

    Returns:
        A float in [0, 1].

    Raises:
        KeyError: If ``all_classes`` is not supplied.
    """
    all_classes = kwargs["all_classes"]
    mentioned = [name for name in all_classes if name in prediction]
    # Build a new list instead of calling .remove() while iterating: removing
    # during iteration skips the element after each removal, which left
    # spurious substring matches in place and deflated the score.
    kept = [
        name
        for name in mentioned
        if not (name in ground_truth and name != ground_truth)
    ]
    if ground_truth in kept:
        return 1.0 / len(kept)
    return 0.0
def rouge_score(prediction, ground_truth, **kwargs):
    """Return the ROUGE-L F-measure between a prediction and a reference.

    Args:
        prediction: Model output text.
        ground_truth: Reference text.
        **kwargs: Ignored; accepted for a uniform metric signature.

    Returns:
        The ``rouge-l`` ``f`` value reported by ``Rouge.get_scores``, or 0.0
        if the scorer raises while computing it.
    """
    rouge = Rouge()
    try:
        scores = rouge.get_scores([prediction], [ground_truth], avg=True)
    except Exception:
        # Best-effort: a scoring failure counts as a zero score. Catch
        # Exception rather than using a bare except so KeyboardInterrupt and
        # SystemExit still propagate and a long benchmark run can be stopped.
        return 0.0
    return scores["rouge-l"]["f"]
def rouge_zh_score(prediction, ground_truth, **kwargs):
    """Return the ROUGE-L F-measure for Chinese text.

    Both strings are segmented with ``jieba.cut`` (``cut_all=False``) and the
    segments are joined with single spaces before being passed to
    ``rouge_score``.

    Args:
        prediction: Model output text.
        ground_truth: Reference text.
        **kwargs: Ignored; accepted for a uniform metric signature.

    Returns:
        The value returned by ``rouge_score`` on the segmented strings.
    """
    segmented_prediction = " ".join(jieba.cut(prediction, cut_all=False))
    segmented_reference = " ".join(jieba.cut(ground_truth, cut_all=False))
    return rouge_score(segmented_prediction, segmented_reference)
def f1_score(prediction, ground_truth, **kwargs):
    """Compute the bag-of-tokens F1 between two token sequences.

    The overlap is the multiset intersection of the two sequences, so a
    repeated token only counts as often as it appears in both.

    Args:
        prediction: Sequence of predicted tokens.
        ground_truth: Sequence of reference tokens.
        **kwargs: Ignored; accepted for a uniform metric signature.

    Returns:
        The harmonic mean of precision and recall, or 0 when the sequences
        share no tokens.
    """
    shared = Counter(prediction) & Counter(ground_truth)
    overlap = sum(shared.values())
    if overlap == 0:
        return 0
    precision = overlap / len(prediction)
    recall = overlap / len(ground_truth)
    return (2 * precision * recall) / (precision + recall)
def qa_f1_score(prediction, ground_truth, **kwargs):
    """Compute token-level F1 for an English QA answer.

    Both strings go through ``normalize_answer`` and are split on
    whitespace; the resulting token lists are scored with ``f1_score``.

    Args:
        prediction: Model output text.
        ground_truth: Reference answer text.
        **kwargs: Ignored; accepted for a uniform metric signature.

    Returns:
        The value returned by ``f1_score`` for the two token lists.
    """
    predicted_tokens = normalize_answer(prediction).split()
    reference_tokens = normalize_answer(ground_truth).split()
    return f1_score(predicted_tokens, reference_tokens)
def qa_f1_zh_score(prediction, ground_truth, **kwargs):
    """Compute token-level F1 for a Chinese QA answer.

    Each string is segmented with ``jieba.cut`` (``cut_all=False``), every
    segment is passed through ``normalize_zh_answer``, and segments that
    normalise to the empty string are dropped. The two remaining token lists
    are scored with ``f1_score``.

    Args:
        prediction: Model output text.
        ground_truth: Reference answer text.
        **kwargs: Ignored; accepted for a uniform metric signature.

    Returns:
        The value returned by ``f1_score`` for the two token lists.
    """
    def tokenize(text):
        # Segment, normalise each piece, and keep only non-empty results.
        normalized = (
            normalize_zh_answer(piece)
            for piece in jieba.cut(text, cut_all=False)
        )
        return [piece for piece in normalized if piece]

    predicted_tokens = tokenize(prediction)
    reference_tokens = tokenize(ground_truth)
    return f1_score(predicted_tokens, reference_tokens)