Skip to content

Commit 16c4e8a

Browse files
committed
chunk_validator
1 parent 4dbeafd commit 16c4e8a

File tree

2 files changed

+20
-1
lines changed

2 files changed

+20
-1
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from torch_tensorrt.dynamo.conversion.converter_utils import (
1818
enforce_tensor_types,
1919
get_positive_dim,
20+
has_dynamic_shape,
2021
is_only_operator_on_placeholder,
2122
)
2223
from torch_tensorrt.fx.types import TRTTensor
@@ -903,7 +904,22 @@ def aten_ops_slice(
903904
)
904905

905906

906-
@dynamo_tensorrt_converter(torch.ops.aten.chunk.default)
907+
def chunk_validator(node: Node) -> bool:
908+
meta_data = node.args[0].meta.get("tensor_meta")
909+
if meta_data is None:
910+
return False
911+
shape = meta_data.shape
912+
dynamic_shape = has_dynamic_shape(shape)
913+
if dynamic_shape:
914+
return False
915+
return True
916+
917+
918+
@dynamo_tensorrt_converter(
919+
torch.ops.aten.chunk.default,
920+
supports_dynamic_shapes=True,
921+
capability_validator=chunk_validator,
922+
)
907923
@enforce_tensor_types(
908924
{
909925
0: (TRTTensor,),

tests/py/dynamo/conversion/test_chunk_aten.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ def forward(self, input):
2727
self.run_test(
2828
TestChunk(),
2929
input,
30+
enable_passes=True,
3031
)
3132

3233
@parameterized.expand(
@@ -51,6 +52,7 @@ def forward(self, input):
5152
self.run_test(
5253
TestChunk(),
5354
input,
55+
enable_passes=True,
5456
)
5557

5658
@parameterized.expand(
@@ -75,6 +77,7 @@ def forward(self, input):
7577
self.run_test(
7678
TestChunk(),
7779
input,
80+
enable_passes=True,
7881
)
7982

8083

0 commit comments

Comments
 (0)