Skip to content

Commit e68328a

Browse files
lanluo-nvidia, wenbingl, and claude
authored
cherrypick: split index.Tensor converter for bool vs int indexing (#4123) (#4133)
Co-authored-by: Wenbing Li <10278425+wenbingl@users.noreply.github.com> Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 2fb7b37 commit e68328a

2 files changed

Lines changed: 217 additions & 2 deletions

File tree

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 52 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from torch_tensorrt.dynamo.conversion import impl
1717
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
1818
from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
19+
ConverterPriority,
1920
dynamo_tensorrt_converter,
2021
has_static_shapes_in_args,
2122
)
@@ -568,12 +569,29 @@ def index_nonbool_validator(
568569
return True
569570

570571

572+
def index_has_bool_indices(
    node: Node, settings: Optional[CompilationSettings] = None
) -> bool:
    """Report whether any index tensor attached to ``node`` has a boolean dtype.

    ``node.args[1]`` is the sequence of (possibly ``None``) index tensors of an
    ``aten.index.Tensor`` call; each real index carries a fake-tensor value under
    ``meta["val"]``.
    """
    for candidate in node.args[1]:
        if candidate is None:
            continue
        fake_val = candidate.meta.get("val")
        if fake_val is None:
            # No shape/dtype metadata recorded for this index; treat as non-bool.
            continue
        if fake_val.dtype == torch.bool:
            return True
    return False
583+
584+
585+
# Integer indexing: output shape is deterministic (depends on index tensor
586+
# shape, not values), so no output allocator is needed. This is the common
587+
# case and is checked first via HIGH priority.
571588
@dynamo_tensorrt_converter(
572589
torch.ops.aten.index.Tensor,
573590
capability_validator=lambda node, settings: index_dtype_validator(node, settings)
574-
and index_nonbool_validator(node, settings),
591+
and not index_has_bool_indices(node, settings),
592+
priority=ConverterPriority.HIGH,
575593
supports_dynamic_shapes=True,
576-
requires_output_allocator=True,
594+
requires_output_allocator=False,
577595
)
578596
@enforce_tensor_types(
579597
{
@@ -597,6 +615,38 @@ def aten_ops_index(
597615
)
598616

599617

618+
# Boolean indexing: internally uses nonzero() which produces data-dependent
# output shapes, so an output allocator is required.
@dynamo_tensorrt_converter(
    torch.ops.aten.index.Tensor,
    # Fires only when the dtype validator passes AND at least one index tensor
    # is boolean; the integer path is registered separately without the
    # output-allocator requirement.
    capability_validator=lambda node, settings: index_dtype_validator(node, settings)
    and index_nonbool_validator(node, settings)
    and index_has_bool_indices(node, settings),
    supports_dynamic_shapes=True,
    requires_output_allocator=True,
)
@enforce_tensor_types(
    {
        0: (TRTTensor,),
    }
)
def aten_ops_index_bool(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Convert ``aten.index.Tensor`` when at least one index tensor is boolean.

    Delegates to ``impl.select.index`` with the input tensor (``args[0]``) and
    the index list (``args[1]``). Registered with
    ``requires_output_allocator=True`` because boolean-mask indexing yields a
    data-dependent output shape (see the comment above the decorator).
    """
    return impl.select.index(
        ctx,
        target,
        SourceIR.ATEN,
        name,
        args[0],
        args[1],
    )
648+
649+
600650
@dynamo_tensorrt_converter(torch.ops.aten.tanh.default, supports_dynamic_shapes=True)
601651
def aten_ops_tanh(
602652
ctx: ConversionContext,
Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
"""Tests for the bool/int index.Tensor converter split (commit b91168b).
2+
3+
Verifies that:
4+
1. `index_has_bool_indices` validator correctly distinguishes bool vs int indices.
5+
2. Integer-indexed `aten.index.Tensor` routes to the converter WITHOUT output allocator.
6+
3. Boolean-indexed `aten.index.Tensor` routes to the converter WITH output allocator.
7+
4. Both paths produce correct results.
8+
"""
9+
import unittest
10+
from unittest.mock import MagicMock
11+
12+
import torch
13+
import torch.nn as nn
14+
import torch_tensorrt
15+
from parameterized import parameterized
16+
from torch.testing._internal.common_utils import run_tests
17+
from torch_tensorrt import ENABLED_FEATURES
18+
19+
from torch_tensorrt.dynamo.conversion.aten_ops_converters import (
20+
index_has_bool_indices,
21+
index_nonbool_validator,
22+
)
23+
24+
from .harness import DispatchTestCase
25+
26+
27+
def _make_index_node(indices):
    """Build a mock FX Node whose ``args[1]`` holds index entries with ``meta['val']`` set.

    ``None`` entries are passed through unchanged, mirroring how
    ``aten.index.Tensor`` represents skipped dimensions.
    """

    def _wrap(index_tensor):
        # Attach the tensor as fake-tensor metadata, as the validators expect.
        if index_tensor is None:
            return None
        wrapped = MagicMock()
        wrapped.meta = {"val": index_tensor}
        return wrapped

    node = MagicMock()
    node.args = (MagicMock(), [_wrap(t) for t in indices])
    return node
40+
41+
42+
class TestIndexHasBoolIndicesValidator(unittest.TestCase):
    """Unit tests for the index_has_bool_indices validator function."""

    def test_int_indices_returns_false(self):
        has_bool = index_has_bool_indices(
            _make_index_node([torch.tensor([0, 1, 2])])
        )
        self.assertFalse(has_bool)

    def test_bool_indices_returns_true(self):
        has_bool = index_has_bool_indices(
            _make_index_node([torch.tensor([True, False, True])])
        )
        self.assertTrue(has_bool)

    def test_none_with_int_indices_returns_false(self):
        has_bool = index_has_bool_indices(
            _make_index_node([None, torch.tensor([0, 1])])
        )
        self.assertFalse(has_bool)

    def test_none_with_bool_indices_returns_true(self):
        has_bool = index_has_bool_indices(
            _make_index_node([None, torch.tensor([True, False])])
        )
        self.assertTrue(has_bool)

    def test_mixed_int_and_bool_returns_true(self):
        """If any index is bool, the function should return True."""
        mixed = [torch.tensor([0, 1]), torch.tensor([True, False])]
        self.assertTrue(index_has_bool_indices(_make_index_node(mixed)))

    def test_all_none_returns_false(self):
        self.assertFalse(index_has_bool_indices(_make_index_node([None, None])))

    def test_empty_indices_returns_false(self):
        self.assertFalse(index_has_bool_indices(_make_index_node([])))
75+
76+
77+
class TestIndexNonboolValidatorConsistency(unittest.TestCase):
    """Verify index_nonbool_validator and index_has_bool_indices interact correctly."""

    def test_int_index_nonbool_true_has_bool_false(self):
        int_node = _make_index_node([torch.tensor([0, 1])])
        self.assertTrue(index_nonbool_validator(int_node))
        self.assertFalse(index_has_bool_indices(int_node))

    def test_bool_index_has_bool_true(self):
        bool_node = _make_index_node([torch.tensor([True, False])])
        self.assertTrue(index_has_bool_indices(bool_node))

    @unittest.skipUnless(
        ENABLED_FEATURES.tensorrt_rtx,
        "index_nonbool_validator only rejects bool on tensorrt_rtx",
    )
    def test_bool_index_nonbool_false_on_rtx(self):
        bool_node = _make_index_node([torch.tensor([True, False])])
        self.assertFalse(index_nonbool_validator(bool_node))

    @unittest.skipIf(
        ENABLED_FEATURES.tensorrt_rtx,
        "On non-RTX, index_nonbool_validator always passes",
    )
    def test_bool_index_nonbool_true_on_non_rtx(self):
        """On non-RTX, nonbool_validator passes even for bool indices;
        the bool/int split is handled by index_has_bool_indices instead."""
        bool_node = _make_index_node([torch.tensor([True, False])])
        self.assertTrue(index_nonbool_validator(bool_node))
        self.assertTrue(index_has_bool_indices(bool_node))
107+
108+
109+
class TestIndexIntConverterNoOutputAllocator(DispatchTestCase):
    """Integer indexing should work correctly (routed to non-output-allocator converter)."""

    # Each case: (test name suffix, index list, input tensor). `None` entries
    # mean "take the whole dimension", as in `x[:, idx]`.
    @parameterized.expand(
        [
            ("int_1d_index", [torch.tensor([0, 1])], torch.randn(3, 4)),
            (
                "int_2d_with_none",
                [None, torch.tensor([0, 1])],
                torch.randn(2, 3),
            ),
            (
                "int_multi_index",
                [torch.tensor([0, 1]), torch.tensor([1, 0])],
                torch.randn(3, 3),
            ),
        ]
    )
    def test_int_index(self, _, index, input_tensor):
        # `index` is captured from the enclosing scope by the module's forward.
        class IndexModule(nn.Module):
            def forward(self, x):
                return torch.ops.aten.index.Tensor(x, index)

        # harness compiles the module through the dynamo converter path and
        # compares TensorRT output against eager PyTorch.
        self.run_test(IndexModule(), [input_tensor])
133+
134+
135+
@unittest.skipIf(
    ENABLED_FEATURES.tensorrt_rtx,
    "Skipped on tensorrt_rtx due to nonzero not supported",
)
class TestIndexBoolConverterWithOutputAllocator(DispatchTestCase):
    """Boolean indexing should work correctly (routed to output-allocator converter)."""

    # Each case: (test name suffix, index list, input tensor). Boolean masks
    # select a data-dependent number of rows, hence the output-allocator path.
    @parameterized.expand(
        [
            (
                "bool_1d_mask",
                [torch.tensor([True, False, True])],
                torch.randn(3, 4),
            ),
            (
                "bool_mask_with_none",
                [None, torch.tensor([True, False])],
                torch.randn(2, 2),
            ),
        ]
    )
    def test_bool_index(self, _, index, input_tensor):
        # `index` is captured from the enclosing scope by the module's forward.
        class BoolIndexModule(nn.Module):
            def forward(self, x):
                return torch.ops.aten.index.Tensor(x, index)

        self.run_test(BoolIndexModule(), [input_tensor])
162+
163+
164+
# Allow running this test file directly (outside of a pytest invocation).
if __name__ == "__main__":
    run_tests()

0 commit comments

Comments
 (0)