-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathvis_matches.py
More file actions
72 lines (56 loc) · 2.41 KB
/
vis_matches.py
File metadata and controls
72 lines (56 loc) · 2.41 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import argparse
from pathlib import Path
from src.data.Data import Data
from src.model.ExReNet import ExReNet
from src.utils.Config import Config
from src.utils.Inference import Inference
import matplotlib.pyplot as plt
import numpy as np
import imageio
import cv2
# ---------------------------------------------------------------------------
# Command line: four positional arguments (config file, model directory and
# the two images to be matched).
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(description="")
for positional_name in ('config', 'model_dir', 'image1', 'image2'):
    parser.add_argument(positional_name)
args = parser.parse_args()

# ---------------------------------------------------------------------------
# Build the data description and the network from the config, then load the
# weights stored as "model.h5" inside the given model directory.
# ---------------------------------------------------------------------------
config = Config.from_file(args.config)
data = Data(config.get_with_prefix("data"))
model = ExReNet(config.get_with_prefix("model"), data)
weights_file = Path(args.model_dir) / "model.h5"
model.load_weights(str(weights_file))


def _read_resized(path):
    """Read the image at `path` and resize it to (data.image_size, data.image_size)."""
    return cv2.resize(imageio.imread(path), (data.image_size, data.image_size))


image1 = _read_resized(args.image1)
image2 = _read_resized(args.image2)

# Add a leading axis of length one to each image and divide by 255 before the
# forward pass (presumably the images are uint8, giving inputs in [0, 1] --
# TODO confirm against the training pipeline).
batch1 = image1[None] / 255.0
batch2 = image2[None] / 255.0
cam_pose, matched_coordinates, all_dots, matching = model(batch1, batch2, training=False)

print("Click on the left image to see the matched point in the other image.")

# Sample the last entry of `matched_coordinates` (first batch element) at
# every 4th pixel in both directions, producing one 2-vector per grid cell.
# NOTE(review): presumably these are pixel coordinates in the second image --
# confirm against ExReNet's outputs.
full_matching = np.zeros((32, 32, 2))
last_coordinates = matched_coordinates[-1]
for row, col in np.ndindex(32, 32):
    full_matching[row, col] = last_coordinates[0, row * 4, col * 4]

# Show both images side by side; the extent maps the picture onto the data
# range [0, 2 * size] x [0, size] with the y axis pointing downwards.
fig = plt.figure()
size = data.image_size
side_by_side = np.concatenate((image1, image2), axis=1)
plt.imshow(side_by_side, extent=[0, size * 2, size, 0])

# Faint grid with a spacing of 4 data units, vertical lines first and then
# horizontal ones.
for grid_x in range(0, size * 2, 4):
    plt.plot([grid_x, grid_x], [0, size], 'k-', linewidth=1, alpha=0.2)
for grid_y in range(0, size, 4):
    plt.plot([0, size * 2], [grid_y, grid_y], 'k-', linewidth=1, alpha=0.2)

# Disabled debugging aid that draws fixed matches for hand-picked cells:
# for (x,y) in [[13, 20]]:
#     plt.plot([x * 4 + 2, 128+ matching[y, x, 0] + 2], [y * 4 + 2, matching[y, x, 1] + 2], 'g-')

# Line artists of the currently displayed match; managed by the click handler.
lines = []
def onclick(event):
    """Draw the match for the grid cell that was clicked in the left image.

    Registered as a Matplotlib 'button_press_event' callback. A click inside
    the left 128 x 128 region removes the previously drawn match lines and
    draws a new segment from the centre of the clicked 4 x 4 cell to the
    matched position stored in `full_matching`, shifted by 128 in x so that
    it lands in the right-hand image. Any other click leaves the figure
    untouched.

    Args:
        event: The Matplotlib mouse event. `event.xdata` / `event.ydata` are
            the click position in data coordinates, or None when the click
            did not land inside the axes.
    """
    # Clicks outside the axes (figure margin, toolbar area) carry no data
    # coordinates. Without this guard the '%f' formatting and int() below
    # raise a TypeError inside the callback.
    if event.xdata is None or event.ydata is None:
        return

    print('button=%d, x=%d, y=%d, xdata=%f, ydata=%f' % (event.button, event.x, event.y, event.xdata, event.ydata))
    x = int(event.xdata)
    y = int(event.ydata)

    # Only the left image is clickable.
    if not (0 <= x < 128 and 0 <= y < 128):
        return

    # Drop the lines of the previously selected match before drawing anew.
    for line in lines:
        line.remove()
    lines.clear()

    # Convert pixel position to grid cell index (cells are 4 x 4 pixels).
    x //= 4
    y //= 4

    # Segment from the cell centre to the matched point in the right image.
    segment_xs = [x * 4 + 2, 128 + full_matching[y, x, 0] + 2]
    segment_ys = [y * 4 + 2, full_matching[y, x, 1] + 2]

    # A thick black line underneath a thinner default-coloured one gives the
    # segment a dark outline.
    lines.extend(plt.plot(segment_xs, segment_ys, 'k-', linewidth=4))
    lines.extend(plt.plot(segment_xs, segment_ys, linewidth=2))
    fig.canvas.draw()
# Hide the axis decorations so that only the image pair and overlays remain.
plt.axis('off')

# Route mouse button presses on the figure to `onclick`; the value returned
# by mpl_connect is kept in `cid`.
cid = fig.canvas.mpl_connect('button_press_event', onclick)

# Display the figure window.
plt.show()
# Alternative (disabled): write the figure to disk instead.
# plt.savefig("match_9.png", dpi=300)