separate make_csv from the file
This commit is contained in:
parent
bfa98666a6
commit
6fe5344fa6
1 changed files with 10 additions and 49 deletions
|
|
@ -15,7 +15,7 @@
|
||||||
#
|
#
|
||||||
"""
|
"""
|
||||||
Usage:
|
Usage:
|
||||||
python make_table_results.py <input_dir>
|
python make_csv.py <input_dir> <output_dir>
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
@ -32,43 +32,6 @@ logging.basicConfig(level=logging.INFO)
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def make_table(result_dict):
|
|
||||||
"""Generate table of results."""
|
|
||||||
md_writer = MarkdownTableWriter()
|
|
||||||
latex_writer = LatexTableWriter()
|
|
||||||
md_writer.headers = ["Model", "Precision", "Arc", "Hellaswag", "MMLU", "TruthfulQA","Winogrande", "GSM8K"]
|
|
||||||
latex_writer.headers = ["Model", "Precision", "Arc", "Hellaswag", "MMLU", "TruthfulQA","Winogrande", "GSM8K"]
|
|
||||||
|
|
||||||
tasks = ["arc", "hellaswag", "mmlu", "truthfulqa", "winogrande", "gsm8k"]
|
|
||||||
values = []
|
|
||||||
for model, model_results in result_dict.items():
|
|
||||||
for precision, prec_results in model_results.items():
|
|
||||||
value = [model, precision]
|
|
||||||
for task in tasks:
|
|
||||||
|
|
||||||
task_results = prec_results.get(task, None)
|
|
||||||
if task_results is None:
|
|
||||||
value.append("")
|
|
||||||
else:
|
|
||||||
m = task_to_metric[task]
|
|
||||||
results = task_results["results"]
|
|
||||||
if len(results) > 1:
|
|
||||||
result = results[task]
|
|
||||||
else:
|
|
||||||
result = list(results.values())[0]
|
|
||||||
value.append("%.2f" % (result[m] * 100))
|
|
||||||
values.append(value)
|
|
||||||
model = ""
|
|
||||||
precision = ""
|
|
||||||
|
|
||||||
md_writer.value_matrix = values
|
|
||||||
latex_writer.value_matrix = values
|
|
||||||
|
|
||||||
# todo: make latex table look good
|
|
||||||
# print(latex_writer.dumps())
|
|
||||||
|
|
||||||
return md_writer.dumps()
|
|
||||||
|
|
||||||
def make_csv(result_dict, output_path=None):
|
def make_csv(result_dict, output_path=None):
|
||||||
current_date = datetime.datetime.now().strftime("%Y-%m-%d")
|
current_date = datetime.datetime.now().strftime("%Y-%m-%d")
|
||||||
file_name = f'results_{current_date}.csv'
|
file_name = f'results_{current_date}.csv'
|
||||||
|
|
@ -102,7 +65,7 @@ def merge_results(path):
|
||||||
# for each dir, load json files
|
# for each dir, load json files
|
||||||
print('Read from', path)
|
print('Read from', path)
|
||||||
merged_results = dict()
|
merged_results = dict()
|
||||||
for dirpath, dirnames, filenames in os.walk(sys.argv[1]):
|
for dirpath, dirnames, filenames in os.walk(path):
|
||||||
# skip dirs without files
|
# skip dirs without files
|
||||||
if not filenames:
|
if not filenames:
|
||||||
continue
|
continue
|
||||||
|
|
@ -120,19 +83,17 @@ def merge_results(path):
|
||||||
|
|
||||||
|
|
||||||
def main(*args):
|
def main(*args):
|
||||||
if len(args) > 1:
|
assert len(args) > 2, \
|
||||||
input_path = args[1]
|
"""Usage:
|
||||||
else:
|
python make_csv.py <input_dir> <output_dir>
|
||||||
raise ValueError("Input path is required")
|
"""
|
||||||
|
|
||||||
|
input_path = args[1]
|
||||||
|
output_path = args[2]
|
||||||
|
|
||||||
if len(args) > 2:
|
|
||||||
output_path = args[2] # use the third argument as the output path
|
|
||||||
else:
|
|
||||||
output_path = "./" # default to current directory
|
|
||||||
|
|
||||||
merged_results = merge_results(input_path)
|
merged_results = merge_results(input_path)
|
||||||
make_csv(merged_results, output_path)
|
make_csv(merged_results, output_path)
|
||||||
print(make_table(merged_results))
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
Loading…
Reference in a new issue