
Commit b4b22c3

chunk converter validator (#3120)
1 parent 29b4913 commit b4b22c3

3 files changed: +105 −79 lines changed


py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 0 additions & 24 deletions

@@ -924,30 +924,6 @@ def aten_ops_slice(
     )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.chunk.default)
-@enforce_tensor_types(
-    {
-        0: (TRTTensor,),
-    }
-)
-def aten_ops_chunk(
-    ctx: ConversionContext,
-    target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
-    name: str,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return impl.slice.chunk(
-        ctx,
-        target,
-        SourceIR.ATEN,
-        name,
-        args[0],
-        args[1],
-        args_bounds_check(args, 2, 0),
-    )
-
-
 @dynamo_tensorrt_converter(torch.ops.aten.cumsum.default, supports_dynamic_shapes=True)
 @enforce_tensor_types(
     {
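With the dedicated converter entry removed above, torch.ops.aten.chunk.default no longer has its own registration in aten_ops_converters.py. The skip message in the new tests below refers to aten.split under torch.export, which suggests chunk reaches the converter library already decomposed into split/slice-style ops. A minimal sketch (assuming a recent PyTorch with torch.export available; the module, shapes, and inspection step are illustrative only, not part of this commit) to check what chunk traces to:

import torch


class ChunkModule(torch.nn.Module):
    # Hypothetical toy module, used only to inspect the exported graph.
    def forward(self, x):
        return torch.chunk(x, 3, dim=0)


ep = torch.export.export(ChunkModule(), (torch.randn(6, 4),))
# The printed graph is expected to contain split-family aten nodes rather than
# aten.chunk, per the PyTorch issue referenced in the skipped dynamic tests.
print(ep.graph)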

py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py

Lines changed: 0 additions & 55 deletions

@@ -324,61 +324,6 @@ def expand(
     return layer.get_output(0)
 
 
-def chunk(
-    ctx: ConversionContext,
-    target: Target,
-    source_ir: Optional[SourceIR],
-    name: str,
-    input: TRTTensor,
-    chunks: int,
-    dim: int,
-) -> TRTTensor:
-    if chunks <= 0:
-        raise RuntimeError(
-            f"chunk expects `chunks` to be greater than 0, got: {chunks}"
-        )
-
-    shape = input.shape
-    dim = get_positive_dim(dim, len(shape))
-
-    if dim >= len(shape):
-        raise RuntimeError(
-            f"chunk expects `dim` to be less than the length of input shape, got: {dim}"
-        )
-
-    dynamic_shape = has_dynamic_shape(input.shape)
-    if dynamic_shape > 0:
-        # Check whether slice target dim is dynamic shape dim
-        assert input.shape[dim] != -1, "Can't chunk on dynamic shape dimension!"
-
-    size_dim = shape[dim]
-    chunk_size = math.ceil(size_dim / chunks)
-    result = []
-    start = 0
-    end = min(start + chunk_size, size_dim)
-    cnt = 0
-
-    while start < end:
-        result.append(
-            slice_op(
-                ctx,
-                target,
-                source_ir,
-                f"{name}_slice_{cnt}",
-                input,
-                dim,
-                start,
-                end,
-                1,
-            )
-        )
-        start = end
-        end = min(start + chunk_size, size_dim)
-        cnt += 1
-
-    return result
-
-
 def cumsum(
     ctx: ConversionContext,
     target: Target,
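The deleted impl.slice.chunk helper above implemented chunk as a loop of slice ops with ceil-division sizing. For reference, a minimal pure-PyTorch sketch of that same loop (illustrative only; chunk_by_slices is a hypothetical name, not part of the codebase):

import math

import torch


def chunk_by_slices(t: torch.Tensor, chunks: int, dim: int = 0):
    # Same scheme as the removed helper: pieces of ceil(size / chunks) elements
    # along `dim`, with the last piece possibly smaller, matching torch.chunk.
    size = t.shape[dim]
    step = math.ceil(size / chunks)
    pieces = []
    start = 0
    while start < size:
        end = min(start + step, size)
        pieces.append(t.narrow(dim, start, end - start))
        start = end
    return pieces


# Example: a length-7 tensor chunked into 3 pieces -> sizes [3, 3, 1], as torch.chunk gives.
print([p.shape[0] for p in chunk_by_slices(torch.arange(7), 3)])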

tests/py/dynamo/conversion/test_chunk_aten.py

Lines changed: 105 additions & 0 deletions

@@ -1,6 +1,9 @@
+import unittest
+
 import torch
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt import Input
 
 from .harness import DispatchTestCase
 
@@ -27,6 +30,7 @@ def forward(self, input):
         self.run_test(
             TestChunk(),
             input,
+            use_dynamo_tracer=True,
         )
 
     @parameterized.expand(
@@ -51,6 +55,7 @@ def forward(self, input):
         self.run_test(
             TestChunk(),
             input,
+            use_dynamo_tracer=True,
         )
 
     @parameterized.expand(
@@ -75,6 +80,106 @@ def forward(self, input):
         self.run_test(
             TestChunk(),
             input,
+            use_dynamo_tracer=True,
+        )
+
+
+#######################Dynamic cases#######################
+# The tests are skipped for now. Will be addressed once https://github.com/pytorch/pytorch/issues/134663 is addressed
+@unittest.skip(
+    "Pending aten.split dynamic input torch.export guard bug. Issue- https://github.com/pytorch/pytorch/issues/134663"
+)
+class TestChunkDynamicConverter(DispatchTestCase):
+    @parameterized.expand(
+        [
+            ((1,), (1,), (3,), 3, 0),
+            ((3,), (3,), (4,), 3, 0),
+            ((4,), (4,), (6,), 3, 0),
+            ((6,), (6,), (9,), 3, 0),
+            ((3,), (3,), (4,), 1, -1),
+            ((3,), (3,), (4,), 3, -1),
+            ((3,), (3,), (4,), 4, -1),
+        ]
+    )
+    def test_chunk_1D(self, min_shape, opt_shape, max_shape, chunks, dim):
+        class TestChunk(torch.nn.Module):
+            def forward(self, input):
+                out = torch.ops.aten.chunk.default(input, chunks, dim)
+                return out
+
+        input_specs = [
+            Input(
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            TestChunk(),
+            input_specs,
+            use_dynamo_tracer=True,
+        )
+
+    @parameterized.expand(
+        [
+            ((3, 4), (3, 4), (4, 4), 1, 0),
+            ((3, 4), (3, 4), (4, 4), 3, 0),
+            ((3, 4), (3, 4), (4, 4), 4, 0),
+            ((3, 4), (3, 4), (4, 4), 2, -2),
+            ((3, 4), (3, 4), (4, 4), 6, -2),
+            ((3, 4), (3, 4), (4, 4), 3, 1),
+            ((3, 4), (3, 4), (4, 4), 4, 1),
+            ((3, 4), (3, 4), (4, 4), 5, -1),
+        ]
+    )
+    def test_chunk_2D(self, min_shape, opt_shape, max_shape, chunks, dim):
+        class TestChunk(torch.nn.Module):
+            def forward(self, input):
+                out = torch.ops.aten.chunk.default(input, chunks, dim)
+                return out
+
+        input_specs = [
+            Input(
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            TestChunk(),
+            input_specs,
+            use_dynamo_tracer=True,
+        )
+
+    @parameterized.expand(
+        [
+            ((3, 4, 2), (3, 4, 2), (4, 4, 2), 1, 0),
+            ((3, 4, 2), (3, 4, 2), (4, 4, 2), 3, -3),
+            ((3, 4, 2), (3, 4, 2), (4, 4, 2), 3, 1),
+            ((3, 4, 2), (3, 4, 2), (4, 4, 2), 4, 1),
+            ((3, 4, 2), (3, 4, 2), (4, 4, 2), 6, -2),
+            ((3, 4, 2), (3, 4, 2), (4, 4, 2), 1, 2),
+            ((3, 4, 2), (3, 4, 2), (4, 4, 2), 3, -1),
+            ((3, 4, 2), (3, 4, 2), (4, 4, 2), 4, -1),
+        ]
+    )
+    def test_chunk_3D(self, min_shape, opt_shape, max_shape, chunks, dim):
+        class TestChunk(torch.nn.Module):
+            def forward(self, input):
+                out = torch.ops.aten.chunk.default(input, chunks, dim)
+                return out
+
+        input_specs = [
+            Input(
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            TestChunk(),
+            input_specs,
+            use_dynamo_tracer=True,
         )
 
 
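A note on the parametrizations above that request more chunks than the dimension holds (e.g. chunks=4 on a size-3 dim): torch.chunk returns fewer pieces than requested in that case, since each piece holds ceil(size / chunks) elements. A small eager-mode check (illustrative only, not part of this commit):

import torch

x = torch.randn(3, 4)
# Requesting 4 chunks along a size-3 dim yields 3 chunks of size 1 each, because
# ceil(3 / 4) == 1; the tests exercise the same behavior through the TensorRT path.
print([tuple(t.shape) for t in torch.chunk(x, 4, dim=0)])  # [(1, 4), (1, 4), (1, 4)]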