@@ -406,46 +406,34 @@ def acc_ops_pad_with_slice_layer(
406
406
)
407
407
408
408
input_shape = input_val .shape
409
- pre_start = tuple (i - 1 for i in input_shape )
410
409
prefix_len = len (input_shape ) - len (pad ) // 2
411
- pre_shape = tuple (
412
- input_shape [ i ] + ( pad [- (i - prefix_len ) * 2 - 2 ] if i >= prefix_len else 0 )
410
+ start = tuple (
411
+ - pad [- (i - prefix_len ) * 2 - 2 ] if i >= prefix_len else 0
413
412
for i in range (0 , len (input_shape ))
414
413
)
415
- pre_stride = [- 1 ] * len (input_shape )
414
+
415
+ shape = tuple (
416
+ input_shape [i ]
417
+ + (
418
+ pad [- (i - prefix_len ) * 2 - 1 ] + pad [- (i - prefix_len ) * 2 - 2 ]
419
+ if i >= prefix_len
420
+ else 0
421
+ )
422
+ for i in range (0 , len (input_shape ))
423
+ )
424
+ stride = tuple ([1 ] * len (shape ))
416
425
417
426
layer = network .add_slice (
418
427
input_val ,
419
- pre_start ,
420
- pre_shape ,
421
- pre_stride ,
428
+ start ,
429
+ shape ,
430
+ stride ,
422
431
)
423
- layer .set_input (4 , value_const )
424
- layer .mode = trt .SliceMode .FILL
425
- set_layer_name (layer , target , f"pre_{ name } " )
426
- half_pad_output = layer .get_output (0 )
427
432
428
- shape = half_pad_output .shape
429
- mid_start = tuple (i - 1 for i in shape )
430
- mid_stride = [- 1 ] * len (shape )
431
- layer = network .add_slice (half_pad_output , mid_start , shape , mid_stride )
432
433
layer .set_input (4 , value_const )
433
434
layer .mode = trt .SliceMode .FILL
434
- set_layer_name (layer , target , f"transpose_{ name } " )
435
- transpose_output = layer .get_output (0 )
436
-
437
- shape = transpose_output .shape
438
- post_start = tuple ([0 ] * len (shape ))
439
- post_shape = tuple (
440
- shape [i ] + (pad [- (i - prefix_len ) * 2 - 1 ] if i >= prefix_len else 0 )
441
- for i in range (0 , len (shape ))
442
- )
443
- post_stride = tuple ([1 ] * len (shape ))
435
+ set_layer_name (layer , target , name )
444
436
445
- layer = network .add_slice (transpose_output , post_start , post_shape , post_stride )
446
- layer .set_input (4 , value_const )
447
- layer .mode = trt .SliceMode .FILL
448
- set_layer_name (layer , target , f"post_{ name } " )
449
437
return layer .get_output (0 )
450
438
451
439
0 commit comments