CDM-Demo / evaluation.py
wufan's picture
Upload 47 files
685cc58 verified
raw
history blame
No virus
11 kB
import sys
import os
import re
import json
import time
import shutil
import argparse
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from datetime import datetime
from multiprocessing import Pool
from multiprocessing.dummy import Pool as ThreadPool
from PIL import Image, ImageDraw
from skimage.measure import ransac
from modules.latex2bbox_color import latex2bbox_color
from modules.tokenize_latex.tokenize_latex import tokenize_latex
from modules.visual_matcher import HungarianMatcher, SimpleAffineTransform
def gen_color_list(num=10, gap=15):
num += 1
single_num = 255 // gap + 1
max_num = single_num ** 3
num = min(num, max_num)
color_list = []
for idx in range(num):
R = idx // single_num**2
GB = idx % single_num**2
G = GB // single_num
B = GB % single_num
color_list.append((R*gap, G*gap, B*gap))
return color_list[1:]
def update_inliers(ori_inliers, sub_inliers):
inliers = np.copy(ori_inliers)
sub_idx = -1
for idx in range(len(ori_inliers)):
if ori_inliers[idx] == False:
sub_idx += 1
if sub_inliers[sub_idx] == True:
inliers[idx] = True
return inliers
def reshape_inliers(ori_inliers, sub_inliers):
inliers = np.copy(ori_inliers)
sub_idx = -1
for idx in range(len(ori_inliers)):
if ori_inliers[idx] == False:
sub_idx += 1
if sub_inliers[sub_idx] == True:
inliers[idx] = True
else:
inliers[idx] = False
return inliers
def gen_token_order(box_list):
new_box_list = copy.deepcopy(box_list)
for idx, box in enumerate(new_box_list):
new_box_list[idx]['order'] = idx / len(new_box_list)
return new_box_list
def evaluation(data_root, user_id="test"):
data_root = os.path.join(data_root, user_id)
gt_box_dir = os.path.join(data_root, "gt")
pred_box_dir = os.path.join(data_root, "pred")
match_vis_dir = os.path.join(data_root, "vis_match")
os.makedirs(match_vis_dir, exist_ok=True)
max_iter = 3
min_samples = 3
residual_threshold = 25
max_trials = 50
metrics_per_img = {}
gt_basename_list = [item.split(".")[0] for item in os.listdir(os.path.join(gt_box_dir, 'bbox'))]
for basename in tqdm(gt_basename_list):
gt_valid, pred_valid = True, True
if not os.path.exists(os.path.join(gt_box_dir, 'bbox', basename+".jsonl")):
gt_valid = False
else:
with open(os.path.join(gt_box_dir, 'bbox', basename+".jsonl"), 'r') as f:
box_gt = []
for line in f:
info = json.loads(line)
if info['bbox']:
box_gt.append(info)
if not box_gt:
gt_valid = False
if not gt_valid:
continue
if not os.path.exists(os.path.join(pred_box_dir, 'bbox', basename+".jsonl")):
pred_valid = False
else:
with open(os.path.join(pred_box_dir, 'bbox', basename+".jsonl"), 'r') as f:
box_pred = []
for line in f:
info = json.loads(line)
if info['bbox']:
box_pred.append(info)
if not box_pred:
pred_valid = False
if not pred_valid:
metrics_per_img[basename] = {
"recall": 0,
"precision": 0,
"F1_score": 0,
}
continue
gt_img_path = os.path.join(gt_box_dir, 'vis', basename+"_base.png")
pred_img_path = os.path.join(pred_box_dir, 'vis', basename+"_base.png")
img_gt = Image.open(gt_img_path)
img_pred = Image.open(pred_img_path)
matcher = HungarianMatcher()
matched_idxes = matcher(box_gt, box_pred, img_gt.size, img_pred.size)
src = []
dst = []
for (idx1, idx2) in matched_idxes:
x1min, y1min, x1max, y1max = box_gt[idx1]['bbox']
x2min, y2min, x2max, y2max = box_pred[idx2]['bbox']
x1_c, y1_c = float((x1min+x1max)/2), float((y1min+y1max)/2)
x2_c, y2_c = float((x2min+x2max)/2), float((y2min+y2max)/2)
src.append([y1_c, x1_c])
dst.append([y2_c, x2_c])
src = np.array(src)
dst = np.array(dst)
if src.shape[0] <= min_samples:
inliers = np.array([True for _ in matched_idxes])
else:
inliers = np.array([False for _ in matched_idxes])
for i in range(max_iter):
if src[inliers==False].shape[0] <= min_samples:
break
model, inliers_1 = ransac((src[inliers==False], dst[inliers==False]), SimpleAffineTransform, min_samples=min_samples, residual_threshold=residual_threshold, max_trials=max_trials, random_state=42)
if inliers_1 is not None and inliers_1.any():
inliers = update_inliers(inliers, inliers_1)
else:
break
if len(inliers[inliers==True]) >= len(matched_idxes):
break
for idx, (a,b) in enumerate(matched_idxes):
if inliers[idx] == True and matcher.cost['token'][a, b] == 1:
inliers[idx] = False
final_match_num = len(inliers[inliers==True])
recall = round(final_match_num/(len(box_gt)), 3)
precision = round(final_match_num/(len(box_pred)), 3)
F1_score = round(2*final_match_num/(len(box_gt)+len(box_pred)), 3)
metrics_per_img[basename] = {
"recall": recall,
"precision": precision,
"F1_score": F1_score,
}
if True:
gap = 5
W1, H1 = img_gt.size
W2, H2 = img_pred.size
H = H1 + H2 + gap
W = max(W1, W2)
vis_img = Image.new('RGB', (W, H), (255, 255, 255))
vis_img.paste(img_gt, (0, 0))
vis_img.paste(Image.new('RGB', (W, gap), (120, 120, 120)), (0, H1))
vis_img.paste(img_pred, (0, H1+gap))
match_img = vis_img.copy()
match_draw = ImageDraw.Draw(match_img)
gt_matched_idx = {
a: flag
for (a,b), flag in
zip(matched_idxes, inliers)
}
pred_matched_idx = {
b: flag
for (a,b), flag in
zip(matched_idxes, inliers)
}
for idx, box in enumerate(box_gt):
if idx in gt_matched_idx and gt_matched_idx[idx]==True:
color = "green"
else:
color = "red"
x_min, y_min, x_max, y_max = box['bbox']
match_draw.rectangle([x_min-1, y_min-1, x_max+1, y_max+1], fill=None, outline=color, width=2)
for idx, box in enumerate(box_pred):
if idx in pred_matched_idx and pred_matched_idx[idx]==True:
color = "green"
else:
color = "red"
x_min, y_min, x_max, y_max = box['bbox']
match_draw.rectangle([x_min-1, y_min-1+H1+gap, x_max+1, y_max+1+H1+gap], fill=None, outline=color, width=2)
vis_img.save(os.path.join(match_vis_dir, basename+"_base.png"))
match_img.save(os.path.join(match_vis_dir, basename+".png"))
score_list = [val['F1_score'] for _, val in metrics_per_img.items()]
exp_list = [1 if score==1 else 0 for score in score_list]
metrics_res = {
"mean_score": round(np.mean(score_list), 3),
"exp_rate": round(np.mean(exp_list), 3),
"details": metrics_per_img
}
metric_res_path = os.path.join(data_root, "metrics_res.json")
with open(metric_res_path, "w") as f:
f.write(json.dumps(metrics_res, indent=2))
return metrics_res, metric_res_path, match_vis_dir
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--input', '-i', type=str, default="assets/example/input_example.json")
parser.add_argument('--output', '-o', type=str, default="output")
parser.add_argument('--pools', '-p', type=int, default=240)
args = parser.parse_args()
print(args)
json_input, data_root, pool_num = args.input, args.output, args.pools
temp_dir = os.path.join(data_root, "temp_dir")
exp_name = os.path.basename(json_input).split('.')[0]
with open(json_input, "r") as f:
input_data = json.load(f)
img_ids = []
groundtruths = []
predictions = []
for idx, item in enumerate(input_data):
if "img_id" in item:
img_ids.append(item["img_id"])
else:
img_ids.append(f"sample_{idx}")
groundtruths.append(item['gt'])
predictions.append(item['pred'])
a = time.time()
user_id = exp_name
total_color_list = gen_color_list(num=5800)
data_root = os.path.join(data_root, user_id)
output_dir_info = {}
input_args = []
for subset, latex_list in zip(['gt', 'pred'], [groundtruths, predictions]):
sub_temp_dir = os.path.join(temp_dir, f"{exp_name}_{subset}")
os.makedirs(sub_temp_dir, exist_ok=True)
output_path = os.path.join(data_root, subset)
output_dir_info[output_path] = []
os.makedirs(os.path.join(output_path, 'bbox'), exist_ok=True)
os.makedirs(os.path.join(output_path, 'vis'), exist_ok=True)
for idx, latex in tqdm(enumerate(latex_list), desc=f"collect {subset} latex ..."):
basename = img_ids[idx]
input_arg = latex, basename, output_path, sub_temp_dir, total_color_list
input_args.append(input_arg)
if pool_num > 1:
print(datetime.now().strftime('%Y-%m-%d %H:%M:%S'), "using processpool, pool num:", pool_num, ", job num:", len(input_args))
myP = Pool(args.pools)
for input_arg in input_args:
myP.apply_async(latex2bbox_color, args=(input_arg,))
myP.close()
myP.join()
else:
for input_arg in input_args:
latex2bbox_color(input_arg)
b = time.time()
print(datetime.now().strftime('%Y-%m-%d %H:%M:%S'), "extract bbox done, time cost:", round(b-a, 3), "s")
for subset in ['gt', 'pred']:
shutil.rmtree(os.path.join(temp_dir, f"{exp_name}_{subset}"))
c = time.time()
metrics_res, metric_res_path, match_vis_dir = evaluation(args.output, exp_name)
d = time.time()
print(datetime.now().strftime('%Y-%m-%d %H:%M:%S'), "calculate metrics done, time cost:", round(d-c, 3), "s")
print(f"=> process done, mean f1 score: {metrics_res['mean_score']}.")
print(f"=> more details of metrics are saved in `{metric_res_path}`")
print(f"=> visulization images are saved under `{match_vis_dir}`")