File tree: 3 files changed, +5 −84 lines
  py/torch_tensorrt/dynamo/conversion
  tests/py/dynamo/conversion

File 1 of 3 (py/torch_tensorrt/dynamo/conversion):

@@ -692,9 +692,7 @@ def aten_ops_softmax(

 @dynamo_tensorrt_converter(
     torch.ops.aten.split.Tensor,
-    capability_validator=(
-        has_static_shapes_in_args([0]) and has_static_shapes_in_args([1])
-    ),
+    capability_validator=has_static_shapes_in_args([1]),
     supports_dynamic_shapes=True,
 )
 @dynamo_tensorrt_converter(
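Two details in this first hunk are worth calling out. In the removed form, `x and y` between two function objects evaluates to just `y` (validator factories return truthy functions), so the arg-0 check was never actually applied; the new single validator makes the effective behavior explicit. Functionally, requiring static shapes only on arg 1 (the split size) while leaving arg 0 (the input tensor) unconstrained is what lets `supports_dynamic_shapes=True` take effect. As a rough illustration only, a validator of this kind might look like the sketch below; `arg_has_static_shape` is a hypothetical helper, not the library's implementation, though `node.meta["val"]` and `torch.SymInt` are standard PT2/FX conventions:

import torch
from torch.fx.node import Node

def arg_has_static_shape(node: Node, index: int) -> bool:
    # Hypothetical sketch: accept the converter only when the given
    # positional arg carries no symbolic (dynamic) dimensions.
    arg = node.args[index]
    if not isinstance(arg, Node):
        return True  # Python scalars / int lists are static by definition.
    val = arg.meta.get("val", None)
    if val is None or not hasattr(val, "shape"):
        return True
    return not any(isinstance(d, torch.SymInt) for d in val.shape)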
@@ -905,30 +903,6 @@ def aten_ops_slice(
     )


-@dynamo_tensorrt_converter(torch.ops.aten.chunk.default)
-@enforce_tensor_types(
-    {
-        0: (TRTTensor,),
-    }
-)
-def aten_ops_chunk(
-    ctx: ConversionContext,
-    target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
-    name: str,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return impl.slice.chunk(
-        ctx,
-        target,
-        SourceIR.ATEN,
-        name,
-        args[0],
-        args[1],
-        args_bounds_check(args, 2, 0),
-    )
-
-
 @dynamo_tensorrt_converter(torch.ops.aten.cumsum.default, supports_dynamic_shapes=True)
 @enforce_tensor_types(
     {
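With the dedicated `aten.chunk.default` converter removed, chunk presumably reaches TensorRT through its decomposition into split: `torch.chunk(x, n, dim)` is equivalent to splitting with size `ceil(x.shape[dim] / n)`, which the relaxed split converter above can now accept even when the input has dynamic shapes. A quick eager-mode check of that equivalence:

import math
import torch

x = torch.arange(10)
chunks, dim = 3, 0
split_size = math.ceil(x.shape[dim] / chunks)  # ceil(10 / 3) = 4 -> pieces of 4, 4, 2
assert all(
    torch.equal(a, b)
    for a, b in zip(torch.chunk(x, chunks, dim), torch.split(x, split_size, dim))
)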
File 2 of 3 (py/torch_tensorrt/dynamo/conversion):

@@ -324,61 +324,6 @@ def expand(
     return layer.get_output(0)


-def chunk(
-    ctx: ConversionContext,
-    target: Target,
-    source_ir: Optional[SourceIR],
-    name: str,
-    input: TRTTensor,
-    chunks: int,
-    dim: int,
-) -> TRTTensor:
-    if chunks <= 0:
-        raise RuntimeError(
-            f"chunk expects `chunks` to be greater than 0, got: {chunks}"
-        )
-
-    shape = input.shape
-    dim = get_positive_dim(dim, len(shape))
-
-    if dim >= len(shape):
-        raise RuntimeError(
-            f"chunk expects `dim` to be less than the length of input shape, got: {dim}"
-        )
-
-    dynamic_shape = has_dynamic_shape(input.shape)
-    if dynamic_shape > 0:
-        # Check whether slice target dim is dynamic shape dim
-        assert input.shape[dim] != -1, "Can't chunk on dynamic shape dimension!"
-
-    size_dim = shape[dim]
-    chunk_size = math.ceil(size_dim / chunks)
-    result = []
-    start = 0
-    end = min(start + chunk_size, size_dim)
-    cnt = 0
-
-    while start < end:
-        result.append(
-            slice_op(
-                ctx,
-                target,
-                source_ir,
-                f"{name}_slice_{cnt}",
-                input,
-                dim,
-                start,
-                end,
-                1,
-            )
-        )
-        start = end
-        end = min(start + chunk_size, size_dim)
-        cnt += 1
-
-    return result
-
-
 def cumsum(
     ctx: ConversionContext,
     target: Target,
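The removed helper built its result as a sequence of fixed-size slices: `chunk_size = ceil(size_dim / chunks)`, one `slice_op` per window, with a shorter final window when the dimension does not divide evenly (note that its return annotation said `TRTTensor` even though it returned a Python list). The same windowing logic, restated as a self-contained eager-mode sketch and checked against `torch.chunk` (`chunk_via_slices` is a hypothetical name):

import math
import torch

def chunk_via_slices(t: torch.Tensor, chunks: int, dim: int = 0) -> list[torch.Tensor]:
    # Mirrors the removed loop: windows of ceil(size / chunks) along `dim`,
    # stopping once the dimension is exhausted.
    size = t.shape[dim]
    step = math.ceil(size / chunks)
    out, start = [], 0
    while start < size:
        end = min(start + step, size)
        out.append(t.narrow(dim, start, end - start))
        start = end
    return out

x = torch.randn(5, 10)
assert all(
    torch.equal(a, b)
    for a, b in zip(chunk_via_slices(x, 4, dim=1), torch.chunk(x, 4, dim=1))
)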
File 3 of 3 (tests/py/dynamo/conversion):

@@ -85,8 +85,10 @@ def forward(self, input):


 #######################Dynamic cases#######################
-#The tests are skipped for now. Will be addressed once https://github.com/pytorch/pytorch/issues/134663 is addressed
-@unittest.skip("Pending aten.split converter. Currently tested by E2E")
+# The tests are skipped for now. Will be addressed once https://github.com/pytorch/pytorch/issues/134663 is addressed
+@unittest.skip(
+    "Pending aten.split dynamic input torch.export guard bug. Issue- https://github.com/pytorch/pytorch/issues/134663"
+)
 class TestChunkDynamicConverter(DispatchTestCase):
     @parameterized.expand(
         [
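For context, the dynamic-shape scenario these skipped tests target can be set up with `torch.export`. The sketch below is illustrative only (the shapes and the dim name are assumptions, and it does not by itself trigger the guard bug tracked in the linked issue): it exports a module that chunks along a static dim while the batch dim stays dynamic, matching the static-split-size constraint above.

import torch

class ChunkModule(torch.nn.Module):
    def forward(self, x: torch.Tensor):
        return torch.chunk(x, 3, dim=1)

# Batch dim 0 is dynamic; the chunked dim 1 stays static, in line with the
# converter's requirement that the split size argument be static.
batch = torch.export.Dim("batch", min=1, max=32)
ep = torch.export.export(
    ChunkModule(), (torch.randn(4, 6),), dynamic_shapes={"x": {0: batch}}
)
print(ep.graph_module.graph)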