Skip to content

Commit 988cabe

Browse files
committed
removing chunk op, changing the skip test message, and removing the dynamic shape check of input in validator
1 parent 885046d commit 988cabe

File tree

3 files changed

+5
-84
lines changed

3 files changed

+5
-84
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 1 addition & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -692,9 +692,7 @@ def aten_ops_softmax(
692692

693693
@dynamo_tensorrt_converter(
694694
torch.ops.aten.split.Tensor,
695-
capability_validator=(
696-
has_static_shapes_in_args([0]) and has_static_shapes_in_args([1])
697-
),
695+
capability_validator=has_static_shapes_in_args([1]),
698696
supports_dynamic_shapes=True,
699697
)
700698
@dynamo_tensorrt_converter(
@@ -905,30 +903,6 @@ def aten_ops_slice(
905903
)
906904

907905

908-
@dynamo_tensorrt_converter(torch.ops.aten.chunk.default)
909-
@enforce_tensor_types(
910-
{
911-
0: (TRTTensor,),
912-
}
913-
)
914-
def aten_ops_chunk(
915-
ctx: ConversionContext,
916-
target: Target,
917-
args: Tuple[Argument, ...],
918-
kwargs: Dict[str, Argument],
919-
name: str,
920-
) -> Union[TRTTensor, Sequence[TRTTensor]]:
921-
return impl.slice.chunk(
922-
ctx,
923-
target,
924-
SourceIR.ATEN,
925-
name,
926-
args[0],
927-
args[1],
928-
args_bounds_check(args, 2, 0),
929-
)
930-
931-
932906
@dynamo_tensorrt_converter(torch.ops.aten.cumsum.default, supports_dynamic_shapes=True)
933907
@enforce_tensor_types(
934908
{

py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py

Lines changed: 0 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -324,61 +324,6 @@ def expand(
324324
return layer.get_output(0)
325325

326326

327-
def chunk(
328-
ctx: ConversionContext,
329-
target: Target,
330-
source_ir: Optional[SourceIR],
331-
name: str,
332-
input: TRTTensor,
333-
chunks: int,
334-
dim: int,
335-
) -> TRTTensor:
336-
if chunks <= 0:
337-
raise RuntimeError(
338-
f"chunk expects `chunks` to be greater than 0, got: {chunks}"
339-
)
340-
341-
shape = input.shape
342-
dim = get_positive_dim(dim, len(shape))
343-
344-
if dim >= len(shape):
345-
raise RuntimeError(
346-
f"chunk expects `dim` to be less than the length of input shape, got: {dim}"
347-
)
348-
349-
dynamic_shape = has_dynamic_shape(input.shape)
350-
if dynamic_shape > 0:
351-
# Check whether slice target dim is dynamic shape dim
352-
assert input.shape[dim] != -1, "Can't chunk on dynamic shape dimension!"
353-
354-
size_dim = shape[dim]
355-
chunk_size = math.ceil(size_dim / chunks)
356-
result = []
357-
start = 0
358-
end = min(start + chunk_size, size_dim)
359-
cnt = 0
360-
361-
while start < end:
362-
result.append(
363-
slice_op(
364-
ctx,
365-
target,
366-
source_ir,
367-
f"{name}_slice_{cnt}",
368-
input,
369-
dim,
370-
start,
371-
end,
372-
1,
373-
)
374-
)
375-
start = end
376-
end = min(start + chunk_size, size_dim)
377-
cnt += 1
378-
379-
return result
380-
381-
382327
def cumsum(
383328
ctx: ConversionContext,
384329
target: Target,

tests/py/dynamo/conversion/test_chunk_aten.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,10 @@ def forward(self, input):
8585

8686

8787
#######################Dynamic cases#######################
88-
#The tests are skipped for now. Will be addressed once https://github.com/pytorch/pytorch/issues/134663 is addressed
89-
@unittest.skip("Pending aten.split converter. Currently tested by E2E")
88+
# The tests are skipped for now. Will be addressed once https://github.com/pytorch/pytorch/issues/134663 is addressed
89+
@unittest.skip(
90+
"Pending aten.split dynamic input torch.export guard bug. Issue- https://github.com/pytorch/pytorch/issues/134663"
91+
)
9092
class TestChunkDynamicConverter(DispatchTestCase):
9193
@parameterized.expand(
9294
[

0 commit comments

Comments
 (0)