@@ -409,25 +409,25 @@ def scatter_value(
409
409
)
410
410
input_shape = input .shape
411
411
index_shape = index .shape
412
- if (len (input_shape ) != len (index_shape )):
413
- raise RuntimeError (
414
- f"The no of dimensions of input and index should be equal"
415
- )
412
+ if len (input_shape ) != len (index_shape ):
413
+ raise RuntimeError (f"The no of dimensions of input and index should be equal" )
416
414
ranks = len (input_shape )
417
415
dim = get_positive_dim (cast (int , dim ), ranks )
418
416
dynamic_shape = has_dynamic_shape (input .shape )
419
417
if dynamic_shape :
420
418
# Check whether slice target dim is dynamic shape dim
421
419
assert input .shape [dim ] != - 1 , "Can't scatter on negative shape dimension!"
422
-
420
+
423
421
input_dims = len (input .shape )
424
422
for i in range (0 , input_dims ):
425
423
if index [i ] >= input .shape [i ]:
426
424
raise RuntimeError (
427
425
f"cannot have index greater than the dimension length! { input .shape [dim ]} "
428
426
)
429
427
value_tensor = value * torch .ones (index .shape )
430
- scatter_layer = ctx .net .add_scatter (input , index , value_tensor , trt .tensorrt .ScatterModekELEMENT )
428
+ scatter_layer = ctx .net .add_scatter (
429
+ input , index , value_tensor , trt .tensorrt .ScatterModekELEMENT
430
+ )
431
431
scatter_layer .set_axis (dim )
432
432
set_layer_name (scatter_layer , target , name + "_scatter_layer" , source_ir )
433
433
out = scatter_layer .get_output (0 )
@@ -452,28 +452,26 @@ def scatter_src(
452
452
input_shape = input .shape
453
453
index_shape = index .shape
454
454
src_shape = src .shape
455
- if (len (input_shape ) != len (index_shape )):
456
- raise RuntimeError (
457
- f"The no of dimensions of input and index should be equal"
458
- )
459
- if (len (index_shape ) != len (src_shape )):
460
- raise RuntimeError (
461
- f"The no of dimensions of src and index should be equal"
462
- )
463
-
455
+ if len (input_shape ) != len (index_shape ):
456
+ raise RuntimeError (f"The no of dimensions of input and index should be equal" )
457
+ if len (index_shape ) != len (src_shape ):
458
+ raise RuntimeError (f"The no of dimensions of src and index should be equal" )
459
+
464
460
input_dims = len (input_shape )
465
461
dim = get_positive_dim (cast (int , dim ), input_dims )
466
462
dynamic_shape = has_dynamic_shape (input .shape )
467
463
if dynamic_shape :
468
464
# Check whether slice target dim is dynamic shape dim
469
465
assert input .shape [dim ] != - 1 , "Can't scatter on negative shape dimension!"
470
-
466
+
471
467
for i in range (0 , input_dims ):
472
468
if index [i ] >= input .shape [i ]:
473
469
raise RuntimeError (
474
470
f"cannot have index greater than the dimension length! { input .shape [dim ]} "
475
471
)
476
- scatter_layer = ctx .net .add_scatter (input , index , src , trt .tensorrt .ScatterModekELEMENT )
472
+ scatter_layer = ctx .net .add_scatter (
473
+ input , index , src , trt .tensorrt .ScatterModekELEMENT
474
+ )
477
475
scatter_layer .set_axis (dim )
478
476
set_layer_name (scatter_layer , target , name + "_scatter_layer" , source_ir )
479
477
out = scatter_layer .get_output (0 )
0 commit comments