Skip to content

Commit cd5815b

Browse files
committed
chunk_validator
1 parent 4dbeafd commit cd5815b

File tree

2 files changed

+106
-1
lines changed

2 files changed

+106
-1
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -692,7 +692,9 @@ def aten_ops_softmax(
692692

693693
@dynamo_tensorrt_converter(
694694
torch.ops.aten.split.Tensor,
695-
capability_validator=has_static_shapes_in_args([1]),
695+
capability_validator=(
696+
has_static_shapes_in_args([0]) and has_static_shapes_in_args([1])
697+
),
696698
supports_dynamic_shapes=True,
697699
)
698700
@dynamo_tensorrt_converter(

tests/py/dynamo/conversion/test_chunk_aten.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1+
import unittest
2+
13
import torch
24
from parameterized import parameterized
35
from torch.testing._internal.common_utils import run_tests
6+
from torch_tensorrt import Input
47

58
from .harness import DispatchTestCase
69

@@ -27,6 +30,7 @@ def forward(self, input):
2730
self.run_test(
2831
TestChunk(),
2932
input,
33+
use_dynamo_tracer=True,
3034
)
3135

3236
@parameterized.expand(
@@ -51,6 +55,7 @@ def forward(self, input):
5155
self.run_test(
5256
TestChunk(),
5357
input,
58+
use_dynamo_tracer=True,
5459
)
5560

5661
@parameterized.expand(
@@ -75,6 +80,104 @@ def forward(self, input):
7580
self.run_test(
7681
TestChunk(),
7782
input,
83+
use_dynamo_tracer=True,
84+
)
85+
86+
87+
#######################Dynamic cases################
88+
####The tests are skipped for now. Will be addressed once https://github.com/pytorch/pytorch/issues/134663 is addressed
89+
@unittest.skip("Pending aten.split converter. Currently tested by E2E")
90+
class TestChunkDynamicConverter(DispatchTestCase):
91+
@parameterized.expand(
92+
[
93+
((1,), (1,), (3,), 3, 0),
94+
((3,), (3,), (4,), 3, 0),
95+
((4,), (4,), (6,), 3, 0),
96+
((6,), (6,), (9,), 3, 0),
97+
((3,), (3,), (4,), 1, -1),
98+
((3,), (3,), (4,), 3, -1),
99+
((3,), (3,), (4,), 4, -1),
100+
]
101+
)
102+
def test_chunk_1D(self, min_shape, opt_shape, max_shape, chunks, dim):
103+
class TestChunk(torch.nn.Module):
104+
def forward(self, input):
105+
out = torch.ops.aten.chunk.default(input, chunks, dim)
106+
return out
107+
108+
input_specs = [
109+
Input(
110+
min_shape=min_shape,
111+
opt_shape=opt_shape,
112+
max_shape=max_shape,
113+
),
114+
]
115+
self.run_test_with_dynamic_shape(
116+
TestChunk(),
117+
input_specs,
118+
use_dynamo_tracer=True,
119+
)
120+
121+
@parameterized.expand(
122+
[
123+
((3, 4), (3, 4), (4, 4), 1, 0),
124+
((3, 4), (3, 4), (4, 4), 3, 0),
125+
((3, 4), (3, 4), (4, 4), 4, 0),
126+
((3, 4), (3, 4), (4, 4), 2, -2),
127+
((3, 4), (3, 4), (4, 4), 6, -2),
128+
((3, 4), (3, 4), (4, 4), 3, 1),
129+
((3, 4), (3, 4), (4, 4), 4, 1),
130+
((3, 4), (3, 4), (4, 4), 5, -1),
131+
]
132+
)
133+
def test_chunk_2D(self, min_shape, opt_shape, max_shape, chunks, dim):
134+
class TestChunk(torch.nn.Module):
135+
def forward(self, input):
136+
out = torch.ops.aten.chunk.default(input, chunks, dim)
137+
return out
138+
139+
input_specs = [
140+
Input(
141+
min_shape=min_shape,
142+
opt_shape=opt_shape,
143+
max_shape=max_shape,
144+
),
145+
]
146+
self.run_test_with_dynamic_shape(
147+
TestChunk(),
148+
input_specs,
149+
use_dynamo_tracer=True,
150+
)
151+
152+
@parameterized.expand(
153+
[
154+
((3, 4, 2), (3, 4, 2), (4, 4, 2), 1, 0),
155+
((3, 4, 2), (3, 4, 2), (4, 4, 2), 3, -3),
156+
((3, 4, 2), (3, 4, 2), (4, 4, 2), 3, 1),
157+
((3, 4, 2), (3, 4, 2), (4, 4, 2), 4, 1),
158+
((3, 4, 2), (3, 4, 2), (4, 4, 2), 6, -2),
159+
((3, 4, 2), (3, 4, 2), (4, 4, 2), 1, 2),
160+
((3, 4, 2), (3, 4, 2), (4, 4, 2), 3, -1),
161+
((3, 4, 2), (3, 4, 2), (4, 4, 2), 4, -1),
162+
]
163+
)
164+
def test_chunk_3D(self, min_shape, opt_shape, max_shape, chunks, dim):
165+
class TestChunk(torch.nn.Module):
166+
def forward(self, input):
167+
out = torch.ops.aten.chunk.default(input, chunks, dim)
168+
return out
169+
170+
input_specs = [
171+
Input(
172+
min_shape=min_shape,
173+
opt_shape=opt_shape,
174+
max_shape=max_shape,
175+
),
176+
]
177+
self.run_test_with_dynamic_shape(
178+
TestChunk(),
179+
input_specs,
180+
use_dynamo_tracer=True,
78181
)
79182

80183

0 commit comments

Comments
 (0)