-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy patheval.py
More file actions
64 lines (53 loc) · 2 KB
/
Copy patheval.py
File metadata and controls
64 lines (53 loc) · 2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import sys
import yaml
import pandas as pd
import os
import json
from PIL import Image
from utils.visualize import visualize_predictions
from utils.build_vocab import clean_txt
from metrics.BLEU import compute_bleu
from metrics.CIDEr import compute_cider
def run_eval(config):
    """Evaluate a trained captioning experiment from its saved test predictions.

    Reads ``saved_models/<EXPERIMENT_NAME>/test_results.csv`` (one row per
    reference caption, with ``image``, ``caption`` and ``generated_caption``
    columns), saves a 2x3 grid of randomly sampled predictions to ``viz.png``,
    and writes BLEU@4 and CIDEr scores to ``metrics.json`` in the same
    directory.

    Args:
        config: Parsed experiment configuration. Must contain
            ``EXPERIMENT_NAME`` and ``DATASET.TEST.IMAGE_DIR``.

    Raises:
        ValueError: If ``EXPERIMENT_NAME`` is missing from ``config`` or the
            results file contains no rows.
        FileNotFoundError: If the experiment has no ``test_results.csv``.
        KeyError: If ``DATASET.TEST.IMAGE_DIR`` is missing from ``config``.
    """
    experiment_name = config.get("EXPERIMENT_NAME")
    if experiment_name is None:
        # Without this check the path silently becomes "saved_models/None".
        raise ValueError("config is missing EXPERIMENT_NAME")
    experiment_path = f"saved_models/{experiment_name}"
    results_path = f"{experiment_path}/test_results.csv"
    if not os.path.exists(results_path):
        raise FileNotFoundError("no test results found for the passed experiment")

    df = pd.read_csv(results_path)
    if df.empty:
        raise ValueError(f"{results_path} contains no predictions")
    image_dir = config["DATASET"]["TEST"]["IMAGE_DIR"]

    # Visualize raw (uncleaned) captions; cleaning only applies to scoring.
    _save_sample_visualization(df, image_dir, f"{experiment_path}/viz.png")

    metrics = _compute_metrics(df)
    with open(f"{experiment_path}/metrics.json", "w", encoding="utf-8") as f:
        json.dump(metrics, f)


def _load_image(path, size):
    """Open the image at ``path`` as RGB resized to ``size``, closing the file."""
    with Image.open(path) as img:
        # convert()/resize() return new in-memory images, so the result stays
        # valid after the underlying file is closed.
        return img.convert("RGB").resize(size)


def _save_sample_visualization(df, image_dir, savepath, n_rows=2, n_cols=3,
                               image_size=(224, 224)):
    """Plot ``n_rows * n_cols`` randomly sampled rows of ``df`` to ``savepath``."""
    n_samples = n_rows * n_cols
    # Sample with replacement only when there are fewer rows than grid cells,
    # so visualize_predictions always receives exactly n_samples items
    # (sampling without replacement would raise ValueError).
    sampled = df.sample(n=n_samples, replace=len(df) < n_samples)
    visualize_predictions(
        savepath=savepath,
        images=[
            _load_image(f"{image_dir}/{name}", image_size)
            for name in sampled["image"]
        ],
        true_captions=sampled["caption"].tolist(),
        pred_captions=sampled["generated_caption"].tolist(),
        n_rows=n_rows,
        n_cols=n_cols,
    )


def _compute_metrics(df):
    """Return BLEU@4 and CIDEr over cleaned captions, grouped per image."""
    cleaned = df.assign(
        caption=df["caption"].apply(clean_txt),
        generated_caption=df["generated_caption"].apply(clean_txt),
    )
    all_refs, all_preds = [], []
    # groupby(sort=False) keeps images in first-appearance order and always
    # yields a DataFrame, even for an image with a single reference caption
    # (df.loc[image] would return a Series there and break .to_list()).
    for _, group in cleaned.groupby("image", sort=False):
        all_refs.append(group["caption"].tolist())
        # Presumably every row of an image carries the same prediction;
        # keep the first one.
        all_preds.append(group["generated_caption"].iloc[0])
    return {
        "BLEU@4": compute_bleu(all_refs, all_preds, n=4),
        # Note: compute_cider takes (preds, refs), the reverse of compute_bleu.
        "CIDEr": compute_cider(all_preds, all_refs, max_n=5),
    }
def main():
    """Command-line entry point: ``python eval.py <config.yaml>``.

    Loads the YAML config named on the command line and passes it to
    ``run_eval``.

    Returns:
        Process exit status: 0 on success, 2 on a usage or config-format error.
    """
    args = sys.argv[1:]
    if len(args) != 1:
        # Usage errors go to stderr with a non-zero status so shells/CI notice.
        print("Usage: python eval.py <config.yaml>", file=sys.stderr)
        return 2
    config_path = args[0]
    with open(config_path, "r", encoding="utf-8") as f:
        # safe_load never constructs arbitrary Python objects from the file.
        config = yaml.safe_load(f)
    if not isinstance(config, dict):
        # An empty file loads as None; run_eval needs a mapping.
        print(f"{config_path}: expected a YAML mapping at the top level",
              file=sys.stderr)
        return 2
    run_eval(config)
    return 0


if __name__ == "__main__":
    sys.exit(main())