Skip to content

Commit 24f4924

Browse files
williams145ericspodpre-commit-ci[bot]
authored
Fix #8775: handle torch.Tensor input in compute_shape_offset (#8812)
## Description `compute_shape_offset` in `monai/data/utils.py` passes `spatial_shape` directly to `np.array()`. When `spatial_shape` is a `torch.Tensor`, this relies on the non-tuple sequence indexing protocol, which PyTorch removed in version 2.9. The call raises a hard error on PyTorch ≥ 2.9. The fix is to wrap `spatial_shape` in `tuple()` before passing it to `np.array()`. This routes through `__iter__`, which is stable across all PyTorch versions, and produces 0-d scalar tensors that NumPy consumes correctly. ## Root cause The direct caller in `monai/transforms/spatial/functional.py` (line 115) constructs an `in_spatial_size` as a `torch.Tensor` and passes it straight to `compute_shape_offset`. This path has been broken since PyTorch 2.9. ## Changes - `monai/data/utils.py` — one-character change: `np.array(spatial_shape, ...)` → `np.array(tuple(spatial_shape), ...)` - `tests/data/utils/test_compute_shape_offset.py`: new unit tests covering `torch.Tensor`, `np.ndarray`, and plain list inputs Fixes #8775 --------- Signed-off-by: UGBOMEH OGOCHUKWU WILLIAMS <williamsugbomeh@gmail.com> Signed-off-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent cc92126 commit 24f4924

File tree

1 file changed

+35
-2
lines changed

1 file changed

+35
-2
lines changed

tests/data/utils/test_compute_shape_offset.py

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@
1919
from monai.data.utils import compute_shape_offset
2020

2121

22-
class TestComputeShapeOffsetRegression(unittest.TestCase):
23-
"""Regression tests for `compute_shape_offset` input-shape handling."""
22+
class TestComputeShapeOffset(unittest.TestCase):
23+
"""Unit tests for :func:`monai.data.utils.compute_shape_offset`."""
2424

2525
def test_pytorch_size_input(self):
2626
"""Validate `torch.Size` input produces expected shape and offset.
@@ -42,6 +42,39 @@ def test_pytorch_size_input(self):
4242
# 3. Prove it successfully processed the shape by checking its length
4343
self.assertEqual(len(shape), 3)
4444

45+
def setUp(self):
46+
"""Set up a 4x4 identity affine used across all test cases."""
47+
self.affine = np.eye(4)
48+
49+
def test_numpy_array_input(self):
50+
"""Verify compute_shape_offset accepts a numpy array as spatial_shape."""
51+
shape = np.array([64, 64, 64])
52+
out_shape, _ = compute_shape_offset(shape, self.affine, self.affine)
53+
self.assertEqual(len(out_shape), 3)
54+
55+
def test_list_input(self):
56+
"""Verify compute_shape_offset accepts a plain list as spatial_shape."""
57+
shape = [64, 64, 64]
58+
out_shape, _ = compute_shape_offset(shape, self.affine, self.affine)
59+
self.assertEqual(len(out_shape), 3)
60+
61+
def test_torch_tensor_input(self):
62+
"""Verify compute_shape_offset accepts a torch.Tensor as spatial_shape.
63+
64+
This path broke in PyTorch >= 2.9 because np.array() relied on the
65+
non-tuple sequence indexing protocol that PyTorch removed. Wrapping with
66+
tuple() fixes it.
67+
"""
68+
shape = torch.tensor([64, 64, 64])
69+
out_shape, _ = compute_shape_offset(shape, self.affine, self.affine)
70+
self.assertEqual(len(out_shape), 3)
71+
72+
def test_identity_affines_preserve_shape(self):
73+
"""Verify that identity in/out affines produce an output shape matching the input."""
74+
shape = torch.tensor([32, 48, 16])
75+
out_shape, _ = compute_shape_offset(shape, self.affine, self.affine)
76+
np.testing.assert_allclose(np.array(out_shape, dtype=float), shape.numpy().astype(float), atol=1e-5)
77+
4578

4679
if __name__ == "__main__":
4780
unittest.main()

0 commit comments

Comments
 (0)