Skip to content

Commit 3f4a2e8

Browse files
committed
Addressing review comments and checking in tests
1 parent 6cf1d67 commit 3f4a2e8

File tree

3 files changed

+171
-41
lines changed

3 files changed

+171
-41
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -352,22 +352,45 @@ def aten_ops_softmax(
352352
return impl.normalization.softmax(
353353
network, target, SourceIR.ATEN, name, args[0], args[1]
354354
)
355+
356+
def dynamic_unsupported_split(node: torch.fx.Node) -> bool:
357+
# Validate that none of the inputs to the node have Dynamic shapes
358+
assert isinstance(
359+
node, torch.fx.Node
360+
), "Inputs to validator functions must be FX Nodes"
361+
362+
if isinstance(node.args[1], torch.fx.Node):
363+
if getattr(node.args[1].meta["val"], "_has_symbolic_sizes_strides", True):
364+
return False
365+
return True
355366

356367

357368
@dynamo_tensorrt_converter(
358-
torch.ops.aten.split.default, capability_validator=dynamic_unsupported
369+
torch.ops.aten.split.Tensor, capability_validator=dynamic_unsupported_split
370+
)
371+
@dynamo_tensorrt_converter(
372+
torch.ops.aten.split.sizes, capability_validator=dynamic_unsupported_split
359373
)
360374
@dynamo_tensorrt_converter(
361-
torch.ops.aten.split.sizes, capability_validator=dynamic_unsupported
375+
torch.ops.aten.split_with_sizes.default,
376+
capability_validator=dynamic_unsupported_split,
362377
)
363378
def aten_ops_split(
364379
network: TRTNetwork,
365380
target: Target,
366381
args: Tuple[Argument, ...],
367382
kwargs: Dict[str, Argument],
368383
name: str,
369-
) -> Union[TRTTensor, Sequence[TRTTensor]]:
370-
return impl.split(network, target, SourceIR.ATEN, name, args[0], args[1], args[2])
384+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
385+
return impl.split.split(
386+
network,
387+
target,
388+
SourceIR.ATEN,
389+
name,
390+
input=args[0],
391+
split_size_or_sections=args[1],
392+
dim=args_bounds_check(args, 2, 0),
393+
)
371394

372395

373396
@dynamo_tensorrt_converter(torch.ops.aten.where.self) # type: ignore[misc]

py/torch_tensorrt/dynamo/conversion/impl/split.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@ def split(
2020
source_ir: Optional[SourceIR],
2121
name: str,
2222
input: TRTTensor,
23-
split_size_or_sections: Union[int, List(int)],
24-
dim: Optional[Any] = 0,
23+
split_size_or_sections: Union[int, List[int]],
24+
dim: int = 0,
2525
) -> Union[TRTTensor, Sequence[TRTTensor]]:
2626
if not isinstance(input, TRTTensor):
2727
raise RuntimeError(
@@ -30,16 +30,12 @@ def split(
3030

3131
dim = cast(int, dim)
3232
dynamic_shape = has_dynamic_shape(input.shape)
33-
if network.has_implicit_batch_dimension:
34-
assert dim != 0, "Can't split on batch dim when it's implicit!"
35-
dim -= 1
36-
else:
37-
if dynamic_shape > 0:
38-
# Check whether slice target dim is dynamic shape dim
39-
assert input.shape[dim] != -1, "Can't chunk on dynamic shape dimension!"
33+
if dynamic_shape > 0:
34+
# Check whether slice target dim is dynamic shape dim
35+
assert input.shape[dim] != -1, "Can't chunk on dynamic shape dimension!"
4036

4137
split_sizes = []
42-
if type(split_size_or_sections) == int:
38+
if isinstance(split_size_or_sections, int):
4339
split_sizes.append(cast(int, split_size_or_sections))
4440
else:
4541
for split_size_or_section in split_size_or_sections:
@@ -48,12 +44,16 @@ def split(
4844
start = [0] * len(input.shape)
4945
stride = [1] * len(start)
5046
offset = 0
51-
5247
if len(split_sizes) == 1:
53-
num_splits = input.shape[dim] + split_sizes[0] - 1 // split_sizes[0]
48+
num_splits = (input.shape[dim] + split_sizes[0] - 1) // split_sizes[0]
5449
split_sizes = [split_sizes[0]] * num_splits
5550
else:
5651
num_splits = len(split_sizes)
52+
sum_split_sizes = sum(split_sizes)
53+
if sum_split_sizes != input.shape[dim]:
54+
raise RuntimeError(
55+
f"split sizes don't add up to the tensor's size in the given dimension"
56+
)
5757

5858
if num_splits < 1:
5959
raise RuntimeError(
@@ -69,7 +69,7 @@ def split(
6969
start[dim] = offset
7070
if dynamic_shape:
7171
shape = get_shape_with_dynamic_shape(
72-
network, shape, input, target, f"{name}_shape_{i}"
72+
network, target, source_ir, f"{name}_shape_{i}", shape, input
7373
)
7474
layer = network.add_slice(
7575
input, start=start, shape=[] if dynamic_shape else shape, stride=stride

tests/py/dynamo/converters/test_split_aten.py

Lines changed: 131 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,14 @@
33
from parameterized import parameterized
44
from torch.testing._internal.common_utils import run_tests
55
from torch_tensorrt import Input
6+
from torch_tensorrt.dynamo.conversion import UnsupportedOperatorException
67

78

89
# FIXME: check about implicit and explicit batch
910
class TestSplitConverterNoDim(DispatchTestCase):
1011
@parameterized.expand(
1112
[
1213
("split_size_or_sections_no_dim", 2),
13-
("split_size_or_sections_list_no_dim", [1, 4]),
14-
("split_size_or_sections_list_no_dim_not_full_split", [1, 3]),
1514
]
1615
)
1716
def test_split(self, _, split_size_or_tensor):
@@ -23,20 +22,62 @@ def forward(self, input):
2322
out = torch.split(input, split_size_or_tensor)
2423
return out
2524

26-
input = torch.arange(10).reshape(5, 2)
25+
input = [torch.randn(10).reshape(5, 2)]
2726
self.run_test(
2827
TestModule(),
2928
input,
30-
expected_ops={torch.ops.aten.split.default},
29+
expected_ops={torch.ops.aten.split.Tensor},
30+
disable_passes=True,
3131
)
3232

33+
@parameterized.expand(
34+
[
35+
("split_size_or_sections_list_no_dim_list", [1, 4]),
36+
]
37+
)
38+
def test_split_list(self, _, split_size_or_tensor):
39+
class TestModule(torch.nn.Module):
40+
def __init__(self):
41+
super().__init__()
42+
43+
def forward(self, input):
44+
out = torch.split(input, split_size_or_tensor)
45+
return out
46+
47+
input = [torch.randn(10).reshape(5, 2)]
48+
self.run_test(
49+
TestModule(),
50+
input,
51+
expected_ops={torch.ops.aten.split_with_sizes.default},
52+
disable_passes=True,
53+
)
3354

34-
class TestSplitConverterWithDim(DispatchTestCase):
3555
@parameterized.expand(
3656
[
37-
("split_size_or_sections_dim", 2, 1),
38-
("split_size_or_sections_list_dim", [1, 4], 1),
39-
("split_size_or_sections_list_dim_not_full_split", [1, 3], 1),
57+
("split_size_or_sections_list_no_dim_not_full_list", [1, 3]),
58+
]
59+
)
60+
def test_split_not_full_list(self, _, split_size_or_tensor):
61+
class TestModule(torch.nn.Module):
62+
def __init__(self):
63+
super().__init__()
64+
65+
def forward(self, input):
66+
out = torch.split(input, split_size_or_tensor)
67+
return out
68+
69+
input = [torch.randn(10).reshape(5, 2)]
70+
with self.assertRaises(RuntimeError):
71+
self.run_test(
72+
TestModule(),
73+
input,
74+
expected_ops={torch.ops.aten.split_with_sizes.default},
75+
disable_passes=True,
76+
)
77+
78+
@parameterized.expand(
79+
[
80+
("split_size_or_sections_dims", 2, 1),
4081
]
4182
)
4283
def test_split(self, _, split_size_or_tensor, dim):
@@ -45,25 +86,90 @@ def __init__(self):
4586
super().__init__()
4687

4788
def forward(self, input):
48-
out = torch.split(split_size_or_tensor, dim)
89+
out = torch.split(input, split_size_or_tensor, dim)
90+
return out
91+
92+
input = [torch.randn(10).reshape(5, 2)]
93+
self.run_test(
94+
TestModule(),
95+
input,
96+
expected_ops={torch.ops.aten.split.Tensor},
97+
disable_passes=True,
98+
)
99+
100+
@parameterized.expand(
101+
[
102+
("split_size_or_sections_list_dims", [1, 1], 1),
103+
]
104+
)
105+
def test_split_dim(self, _, split_size_or_tensor, dim):
106+
class TestModule(torch.nn.Module):
107+
def __init__(self):
108+
super().__init__()
109+
110+
def forward(self, input):
111+
out = torch.split(input, split_size_or_tensor, dim)
49112
return out
50113

51-
input = torch.arange(10).reshape(2, 5)
114+
input = [torch.randn(10).reshape(5, 2)]
52115
self.run_test(
53116
TestModule(),
54117
input,
55-
expected_ops={torch.ops.aten.split.default},
118+
expected_ops={torch.ops.aten.split_with_sizes.default},
119+
disable_passes=True,
56120
)
57121

122+
@parameterized.expand(
123+
[
124+
("split_size_or_sections_list_dims", [1, 1], 1),
125+
]
126+
)
127+
def test_split_dim_list(self, _, split_size_or_tensor, dim):
128+
class TestModule(torch.nn.Module):
129+
def __init__(self):
130+
super().__init__()
131+
132+
def forward(self, input):
133+
out = torch.split(input, split_size_or_tensor, dim)
134+
return out
135+
136+
input = [torch.randn(10).reshape(5, 2)]
137+
self.run_test(
138+
TestModule(),
139+
input,
140+
expected_ops={torch.ops.aten.split_with_sizes.default},
141+
disable_passes=True,
142+
)
58143

59-
class TestSplitConverterDynamicShape(DispatchTestCase):
60144
@parameterized.expand(
61145
[
62-
("select_split_size_or_sections_dim", 2, 1),
63-
("select_split_size_or_sections_list_dim", [1, 4], 1),
146+
("split_size_or_sections_list_dims_not_full_list", [1, 1], 1),
64147
]
65148
)
66-
def test_split(self, _, split_size_or_tensor, dim):
149+
def test_split_dim_list(self, _, split_size_or_tensor, dim):
150+
class TestModule(torch.nn.Module):
151+
def __init__(self):
152+
super().__init__()
153+
154+
def forward(self, input):
155+
out = torch.split(input, split_size_or_tensor, dim)
156+
return out
157+
158+
input = [torch.randn(15).reshape(5, 3)]
159+
with self.assertRaises(RuntimeError):
160+
self.run_test(
161+
TestModule(),
162+
input,
163+
expected_ops={torch.ops.aten.split_with_sizes.default},
164+
disable_passes=True,
165+
)
166+
167+
@parameterized.expand(
168+
[
169+
("select_split_size_or_sections_dim_dynamic_shape", 2, 1),
170+
]
171+
)
172+
def test_split_dynamic(self, _, split_size_or_tensor, dim):
67173
class TestModule(torch.nn.Module):
68174
def __init__(self):
69175
super().__init__()
@@ -82,17 +188,16 @@ def forward(self, input):
82188
self.run_test_with_dynamic_shape(
83189
TestModule(),
84190
input_specs,
85-
expected_ops={torch.ops.aten.split.default},
191+
expected_ops={torch.ops.aten.split.Tensor},
192+
disable_passes=True,
86193
)
87194

88-
89-
class TestSplitSymIntConverterImplicitBatch(DispatchTestCase):
90195
@parameterized.expand(
91196
[
92197
("select_chunk_dim", 6, 0),
93198
]
94199
)
95-
def test_chunk(self, _, chunk, dim):
200+
def test_split_dynamic(self, _, chunk, dim):
96201
class TestModule(torch.nn.Module):
97202
def __init__(self):
98203
super().__init__()
@@ -102,11 +207,13 @@ def forward(self, input):
102207
return out
103208

104209
input = [torch.randn(11)]
105-
self.run_test(
106-
TestModule(),
107-
input,
108-
expected_ops={torch.ops.aten.split.default},
109-
)
210+
with self.assertRaises(UnsupportedOperatorException):
211+
self.run_test(
212+
TestModule(),
213+
input,
214+
expected_ops={torch.ops.aten.split.Tensor},
215+
disable_passes=True,
216+
)
110217

111218

112219
if __name__ == "__main__":

0 commit comments

Comments
 (0)