-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpck_schematic.py
More file actions
173 lines (148 loc) · 6.08 KB
/
pck_schematic.py
File metadata and controls
173 lines (148 loc) · 6.08 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
"""
This script is used to create a schematic of the PCK (Percentage of Correct Keypoints) metric.
It is used to visualize the PCK metric for a given keypoint and a given threshold.
This is not directly used in the ground-truth pipeline, but is useful for understanding the PCK metric and will be a panel in the paper.
The script is hard-coded since it is just an example.
"""
# Produces: Figure 1A (PCK evaluation schematic panel).
# -----------------------------------------------------------------------
# NOTE: img_path and output_svg_path below are set to internal cluster
# paths used during the experiments. Update them for your environment.
# -----------------------------------------------------------------------
import cv2
import numpy as np
import pandas as pd
import plotnine as p9
def bgr_to_hex(bgr_color):
    """Convert an OpenCV-style ``(B, G, R)`` colour to a ``"#rrggbb"`` string.

    Only items 0-2 of ``bgr_color`` are read. Each is formatted as two
    lowercase hex digits, so values are expected to be integers in 0-255;
    anything else will not produce a valid 7-character colour string.
    """
    blue, green, red = bgr_color[0], bgr_color[1], bgr_color[2]
    return f"#{red:02x}{green:02x}{blue:02x}"
def get_circle_df(center_x, center_y, radius, color_hex, group_id, num_points=100):
    """Return the outline of a circle as a DataFrame of path vertices.

    The circle is sampled at ``num_points + 1`` angles from 0 to 2*pi
    inclusive, so the last vertex lands back on the first and ``geom_path``
    draws a closed loop.

    Args:
        center_x, center_y: Centre of the circle in data coordinates.
        radius: Circle radius, in the same units as the centre.
        color_hex: Value written to the ``color`` column of every row.
        group_id: Value written to the ``group`` column so each circle is
            drawn as its own path.
        num_points: Number of straight segments approximating the circle.

    Returns:
        DataFrame with columns ``x``, ``y``, ``color``, ``group`` and
        ``num_points + 1`` rows.

    Raises:
        ValueError: If ``num_points`` is less than 1.
    """
    if num_points < 1:
        raise ValueError(f"num_points must be >= 1, got {num_points}")
    # Vectorised over all angles; the extra sample (index num_points) is
    # angle 2*pi, which closes the circle path.
    angles = 2 * np.pi * np.arange(num_points + 1) / num_points
    df = pd.DataFrame({
        'x': center_x + radius * np.cos(angles),
        'y': center_y + radius * np.sin(angles),
    })
    df['color'] = color_hex
    df['group'] = group_id
    return df
# --- Configuration ---
# Internal cluster paths used for the experiments; update for your environment.
img_path = "/projects/kumar-lab/choij/maze-gt-evaluation/plots/pck_example_image.png"
output_svg_path = "/projects/kumar-lab/choij/maze-gt-evaluation/plots/pck_visualization_plotnine.svg"

# Load the example frame. cv2.imread reports failure by returning None rather
# than raising, so check explicitly.
img_bgr = cv2.imread(img_path)
if img_bgr is None:
    # SystemExit with a message prints it to stderr and exits with status 1,
    # so shells/pipelines can detect the failure (bare exit() exited with 0).
    raise SystemExit(f"Error: Could not load image from {img_path}")
# OpenCV loads BGR; convert to RGB for the hex colours built below.
img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
img_height, img_width = img_rgb.shape[:2]
# --- Prepare pixel DataFrame for background image ---
# One row per pixel: integer (x, y) position plus its "#rrggbb" colour, so the
# image can be drawn with geom_raster + scale_fill_identity. Rows are in
# row-major order, matching both meshgrid(...).ravel() and reshape(-1, 3).
x_coords = np.arange(img_width)
y_coords = np.arange(img_height)
xv, yv = np.meshgrid(x_coords, y_coords)
rgb_flat = img_rgb.reshape(-1, 3)
# .tolist() converts the whole array to plain Python ints in one pass;
# indexing NumPy scalars per pixel (px[0], px[1], px[2]) is much slower.
hex_colors = [f"#{r:02x}{g:02x}{b:02x}" for r, g, b in rgb_flat.tolist()]
pixel_df = pd.DataFrame({
    'x': xv.ravel(),
    'y': yv.ravel(),
    'color_val': hex_colors,
})
# --- End of pixel DataFrame preparation ---
# Keypoints as (x, y) pixel coordinates in the example image.
nose = (351, 280)
tail_base = (197, 219)

# Body length = Euclidean nose-to-tail-base distance in pixels; the PCK radii
# below are fractions of it.
body_length = np.linalg.norm(np.subtract(nose, tail_base))

# PCK thresholds as fractions of body length.
thresholds_pct = [0.1, 0.2, 0.5]

# Colours are written in OpenCV BGR order and converted to "#rrggbb" for
# plotnine. The circle colours are the green/blue/red of ColorBrewer Set1.
set1_green_bgr = (74, 175, 77)
set1_blue_bgr = (184, 126, 55)
set1_red_bgr = (28, 26, 228)
pck_colors_bgr = [set1_green_bgr, set1_blue_bgr, set1_red_bgr]
pck_colors_hex = list(map(bgr_to_hex, pck_colors_bgr))

# Keypoint marker colours (BGR).
nose_color_bgr = (255, 0, 255)  # Magenta
tail_color_bgr = (255, 255, 0)  # Cyan
nose_color_hex, tail_color_hex = map(bgr_to_hex, (nose_color_bgr, tail_color_bgr))
# --- Data Preparation for Plotnine ---
# Keypoint markers: one row per keypoint, colour stored as a literal hex string.
keypoints_df = pd.DataFrame({
    'x': [nose[0], tail_base[0]],
    'y': [nose[1], tail_base[1]],
    'label': ['Nose', 'Tail Base'],
    'color': [nose_color_hex, tail_color_hex],
})

# Each threshold needs its own circle colour (and legend colour).
if len(thresholds_pct) != len(pck_colors_hex):
    raise ValueError(
        f"thresholds_pct ({len(thresholds_pct)}) and pck_colors_hex "
        f"({len(pck_colors_hex)}) must have the same length"
    )

# PCK circles: one closed path per threshold, centred on the nose, with
# radius = threshold fraction * body length.
pck_circles_list = [
    get_circle_df(nose[0], nose[1], body_length * t_pct, color_hex,
                  group_id=f"pck_circle_{i}")
    for i, (t_pct, color_hex) in enumerate(zip(thresholds_pct, pck_colors_hex))
]
all_pck_circles_df = pd.concat(pck_circles_list, ignore_index=True)

# Hand-drawn legend rendered with geom_text (plotnine's own legend is turned
# off). Threshold labels are derived from thresholds_pct so they cannot drift
# out of sync with the circles drawn above.
legend_line_spacing = 20  # vertical gap between legend lines, in pixels
legend_entries = [
    ("Magenta = Nose", nose_color_hex),
    ("Cyan = Tail Base", tail_color_hex),
] + [
    (f"{t_pct * 100:g}% body length", color_hex)
    for t_pct, color_hex in zip(thresholds_pct, pck_colors_hex)
]
legend_texts_data = [
    {"text": text, "y_pos": legend_line_spacing * (row + 1), "color": color}
    for row, (text, color) in enumerate(legend_entries)
]
legend_df = pd.DataFrame(legend_texts_data)
legend_df['x'] = 10  # Common x position for all legend items

# Font size for legend text (approximate conversion from OpenCV)
plotnine_font_size = 10  # pt

# Every 'color' value is already a hex string; scale_color_manual just needs
# each one mapped to itself.
all_colors_used = pd.concat([
    keypoints_df[['color']],
    all_pck_circles_df[['color']],
    legend_df[['color']],
])['color'].unique().tolist()
color_mapping_dict = {c: c for c in all_colors_used}
# --- Plotnine Visualization ---
# All colour columns hold literal hex strings: fill uses an identity scale and
# colour uses a manual scale mapping each hex string to itself.
p = (
    p9.ggplot()
    # Background image: one raster tile per pixel, centred on integer (x, y).
    + p9.geom_raster(data=pixel_df, mapping=p9.aes(x='x', y='y', fill='color_val'), interpolate=False)
    + p9.scale_fill_identity()
    # PCK circles
    + p9.geom_path(
        data=all_pck_circles_df,
        mapping=p9.aes(x='x', y='y', group='group', color='color'),
        size=1,  # Thickness of the circle lines
    )
    # Keypoints (solid filled circles)
    + p9.geom_point(
        data=keypoints_df,
        mapping=p9.aes(x='x', y='y', color='color'),
        size=3,  # Size of the points
    )
    # Legend text, anchored at its top-left corner
    + p9.geom_text(
        data=legend_df,
        mapping=p9.aes(x='x', y='y_pos', label='text', color='color'),
        size=plotnine_font_size,
        ha='left',
        va='top',
        nudge_x=0,
        nudge_y=0,
    )
    # Manual color scale for all elements that use the 'color' aesthetic
    + p9.scale_color_manual(values=color_mapping_dict)
    # Image coordinates have y growing downwards. plotnine orders continuous
    # limits ascending, so limits=(img_height, 0) on scale_y_continuous does
    # not flip the axis; scale_y_reverse does. Limits sit on pixel edges
    # (+/-0.5) because raster tiles are centred on integer coordinates.
    + p9.scale_x_continuous(limits=(-0.5, img_width - 0.5), expand=(0, 0))
    + p9.scale_y_reverse(limits=(-0.5, img_height - 0.5), expand=(0, 0))
    + p9.coord_fixed(ratio=1)  # Ensure 1:1 aspect ratio
    + p9.theme_void()  # No axes, gridlines, etc.
    + p9.theme(legend_position='none')  # Legend is drawn manually above
)
# --- Save ---
# Figure size in inches is chosen so that, at `dpi`, it matches the image's
# pixel dimensions.
dpi = 100
fig_width_inches, fig_height_inches = (dim / dpi for dim in (img_width, img_height))
p9.ggsave(
    p,
    filename=output_svg_path,
    format='svg',
    width=fig_width_inches,
    height=fig_height_inches,
    units='in',
    dpi=dpi,
    verbose=False,
)
print(f"Plotnine SVG visualization saved to {output_svg_path}")