separate make_csv from the file
This commit is contained in:
parent
bfa98666a6
commit
6fe5344fa6
1 changed files with 10 additions and 49 deletions
|
|
@ -15,7 +15,7 @@
|
|||
#
|
||||
"""
|
||||
Usage:
|
||||
python make_table_results.py <input_dir>
|
||||
python make_csv.py <input_dir> <output_dir>
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
|
@ -32,43 +32,6 @@ logging.basicConfig(level=logging.INFO)
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def make_table(result_dict):
|
||||
"""Generate table of results."""
|
||||
md_writer = MarkdownTableWriter()
|
||||
latex_writer = LatexTableWriter()
|
||||
md_writer.headers = ["Model", "Precision", "Arc", "Hellaswag", "MMLU", "TruthfulQA","Winogrande", "GSM8K"]
|
||||
latex_writer.headers = ["Model", "Precision", "Arc", "Hellaswag", "MMLU", "TruthfulQA","Winogrande", "GSM8K"]
|
||||
|
||||
tasks = ["arc", "hellaswag", "mmlu", "truthfulqa", "winogrande", "gsm8k"]
|
||||
values = []
|
||||
for model, model_results in result_dict.items():
|
||||
for precision, prec_results in model_results.items():
|
||||
value = [model, precision]
|
||||
for task in tasks:
|
||||
|
||||
task_results = prec_results.get(task, None)
|
||||
if task_results is None:
|
||||
value.append("")
|
||||
else:
|
||||
m = task_to_metric[task]
|
||||
results = task_results["results"]
|
||||
if len(results) > 1:
|
||||
result = results[task]
|
||||
else:
|
||||
result = list(results.values())[0]
|
||||
value.append("%.2f" % (result[m] * 100))
|
||||
values.append(value)
|
||||
model = ""
|
||||
precision = ""
|
||||
|
||||
md_writer.value_matrix = values
|
||||
latex_writer.value_matrix = values
|
||||
|
||||
# todo: make latex table look good
|
||||
# print(latex_writer.dumps())
|
||||
|
||||
return md_writer.dumps()
|
||||
|
||||
def make_csv(result_dict, output_path=None):
|
||||
current_date = datetime.datetime.now().strftime("%Y-%m-%d")
|
||||
file_name = f'results_{current_date}.csv'
|
||||
|
|
@ -102,7 +65,7 @@ def merge_results(path):
|
|||
# for each dir, load json files
|
||||
print('Read from', path)
|
||||
merged_results = dict()
|
||||
for dirpath, dirnames, filenames in os.walk(sys.argv[1]):
|
||||
for dirpath, dirnames, filenames in os.walk(path):
|
||||
# skip dirs without files
|
||||
if not filenames:
|
||||
continue
|
||||
|
|
@ -120,19 +83,17 @@ def merge_results(path):
|
|||
|
||||
|
||||
def main(*args):
|
||||
if len(args) > 1:
|
||||
input_path = args[1]
|
||||
else:
|
||||
raise ValueError("Input path is required")
|
||||
|
||||
if len(args) > 2:
|
||||
output_path = args[2] # use the third argument as the output path
|
||||
else:
|
||||
output_path = "./" # default to current directory
|
||||
assert len(args) > 2, \
|
||||
"""Usage:
|
||||
python make_csv.py <input_dir> <output_dir>
|
||||
"""
|
||||
|
||||
input_path = args[1]
|
||||
output_path = args[2]
|
||||
|
||||
|
||||
merged_results = merge_results(input_path)
|
||||
make_csv(merged_results, output_path)
|
||||
print(make_table(merged_results))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
Loading…
Reference in a new issue