File tree Expand file tree Collapse file tree 2 files changed +20
-1
lines changed
py/torch_tensorrt/dynamo/conversion
tests/py/dynamo/conversion Expand file tree Collapse file tree 2 files changed +20
-1
lines changed Original file line number Diff line number Diff line change 17
17
from torch_tensorrt .dynamo .conversion .converter_utils import (
18
18
enforce_tensor_types ,
19
19
get_positive_dim ,
20
+ has_dynamic_shape ,
20
21
is_only_operator_on_placeholder ,
21
22
)
22
23
from torch_tensorrt .fx .types import TRTTensor
@@ -903,7 +904,22 @@ def aten_ops_slice(
903
904
)
904
905
905
906
906
- @dynamo_tensorrt_converter (torch .ops .aten .chunk .default )
907
+ def chunk_validator (node : Node ) -> bool :
908
+ meta_data = node .args [0 ].meta .get ("tensor_meta" )
909
+ if meta_data is None :
910
+ return False
911
+ shape = meta_data .shape
912
+ dynamic_shape = has_dynamic_shape (shape )
913
+ if dynamic_shape :
914
+ return False
915
+ return True
916
+
917
+
918
+ @dynamo_tensorrt_converter (
919
+ torch .ops .aten .chunk .default ,
920
+ supports_dynamic_shapes = True ,
921
+ capability_validator = chunk_validator ,
922
+ )
907
923
@enforce_tensor_types (
908
924
{
909
925
0 : (TRTTensor ,),
Original file line number Diff line number Diff line change @@ -27,6 +27,7 @@ def forward(self, input):
27
27
self .run_test (
28
28
TestChunk (),
29
29
input ,
30
+ enable_passes = True ,
30
31
)
31
32
32
33
@parameterized .expand (
@@ -51,6 +52,7 @@ def forward(self, input):
51
52
self .run_test (
52
53
TestChunk (),
53
54
input ,
55
+ enable_passes = True ,
54
56
)
55
57
56
58
@parameterized .expand (
@@ -75,6 +77,7 @@ def forward(self, input):
75
77
self .run_test (
76
78
TestChunk (),
77
79
input ,
80
+ enable_passes = True ,
78
81
)
79
82
80
83
You can’t perform that action at this time.
0 commit comments