Skip to content

Commit a2b7d79

Browse files
timblakelycopybara-github
authored andcommitted
Optimizations for finding decision points on sparse volumes
- If the number of voxels is below a certain threshold do not count it as a segment to expand. - If number of unique segments in a subvolume is <2 do not expand. PiperOrigin-RevId: 888191240
1 parent dccf620 commit a2b7d79

2 files changed

Lines changed: 63 additions & 5 deletions

File tree

ffn/utils/decision_point.py

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,14 @@
2424
from scipy import ndimage
2525

2626

27-
def find_decision_points(seg: np.ndarray,
28-
voxel_size: Sequence[float],
29-
max_distance: Optional[float] = None,
30-
subvol_box: Optional[bounding_box.BoundingBox] = None
31-
) -> dict[tuple[int, int], tuple[float, np.ndarray]]:
27+
def find_decision_points(
28+
seg: np.ndarray,
29+
voxel_size: Sequence[float],
30+
max_distance: Optional[float] = None,
31+
subvol_box: Optional[bounding_box.BoundingBox] = None,
32+
optimize_sparse: bool = False,
33+
sparse_noise_threshold: int = 0,
34+
) -> dict[tuple[int, int], tuple[float, np.ndarray]]:
3235
"""Identifies decision points in a segmentation subvolume.
3336
3437
Args:
@@ -39,12 +42,34 @@ def find_decision_points(seg: np.ndarray,
3942
subvol_box: selector for a subvolume within `seg` within which
4043
to search for decision points; the whole subvolume is always used
4144
to compute the distance transform
45+
optimize_sparse: if True, first counts the number of segments in `seg`
46+
and returns early if there are fewer than 2.
47+
sparse_noise_threshold: if > 0 and `optimize_sparse` is True, ignores
48+
components with voxel counts <= this threshold when counting segments.
4249
4350
Returns:
4451
dict from segment ID pairs to tuples of:
4552
approximate physical distance from the segment to the decision point
4653
(x, y, z) decision point
4754
"""
55+
if optimize_sparse:
56+
# Use pandas to quickly get the unique labels and their counts (excluding 0)
57+
counts = pd.Series(seg.ravel()).value_counts(sort=False)
58+
if 0 in counts:
59+
counts = counts.drop(0)
60+
61+
if sparse_noise_threshold > 0:
62+
small_segs = counts[counts <= sparse_noise_threshold].index.to_numpy()
63+
if len(small_segs) > 0:
64+
# Zero out small segments and update counts
65+
seg[np.isin(seg, small_segs)] = 0
66+
counts = counts.drop(small_segs)
67+
68+
if len(counts) <= 1:
69+
# If there are 0 or 1 unique segments (excluding background),
70+
# they cannot possibly touch another segment.
71+
return {}
72+
4873
# EDT is the Euclidean Distance Transform, specifying how far voxels added
4974
# in 'expanded_seg' are from the seeds in 'seg'.
5075
expanded_seg, edt = labels.watershed_expand(seg, voxel_size, max_distance)

ffn/utils/tests/decision_point_test.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,39 @@ def test_find_decision_point(self):
5555
self.assertIn((1, 2), points)
5656
self.assertLen(points, 1)
5757

58+
def test_find_decision_point_optimize_sparse(self):
59+
# 2 segments, but one is very small and should be filtered out
60+
seg = np.zeros((100, 80, 60), dtype=np.uint64)
61+
seg[:40, :, :] = 1
62+
# 2 voxels of label 2
63+
seg[60, 0, 0] = 2
64+
seg[61, 0, 0] = 2
65+
66+
# Without optimization, 1 and 2 might connect if they are grown far enough.
67+
points = decision_point.find_decision_points(seg, (1, 1, 1))
68+
self.assertIn((1, 2), points)
69+
70+
# With optimization but threshold 0, they still connect (no size filtering)
71+
points = decision_point.find_decision_points(
72+
seg, (1, 1, 1), optimize_sparse=True, sparse_noise_threshold=0
73+
)
74+
self.assertIn((1, 2), points)
75+
76+
# With optimization and threshold >= 2, label 2 is zeroed.
77+
# We're left with label 1 (size > 2), so 1 segment -> returns empty
78+
points = decision_point.find_decision_points(
79+
seg, (1, 1, 1), optimize_sparse=True, sparse_noise_threshold=2
80+
)
81+
self.assertEmpty(points)
82+
83+
# With just 1 segment from the start
84+
seg = np.zeros((100, 80, 60), dtype=np.uint64)
85+
seg[:40, :, :] = 1
86+
points = decision_point.find_decision_points(
87+
seg, (1, 1, 1), optimize_sparse=True, sparse_noise_threshold=0
88+
)
89+
self.assertEmpty(points)
90+
5891

5992
if __name__ == '__main__':
6093
absltest.main()

0 commit comments

Comments
 (0)