@@ -390,3 +390,76 @@ def index_select(
390
390
set_layer_name (gather_layer , target , f"{ name } _gather" , source_ir )
391
391
392
392
return gather_layer .get_output (0 )
393
+
394
+
395
+ def scatter_value (
396
+ ctx : ConversionContext ,
397
+ target : Target ,
398
+ source_ir : Optional [SourceIR ],
399
+ name : str ,
400
+ input : TRTTensor ,
401
+ dim : Shape ,
402
+ index : Shape ,
403
+ value : TRTTensor ,
404
+ ) -> TRTTensor :
405
+ if not isinstance (input , TRTTensor ):
406
+ raise RuntimeError (
407
+ f"scatter_tensor received input { input } that is not part "
408
+ "of the TensorRT region!"
409
+ )
410
+
411
+ ranks = len (input .shape )
412
+ dim = get_positive_dim (cast (int , dim ), ranks )
413
+ dynamic_shape = has_dynamic_shape (input .shape )
414
+ if dynamic_shape :
415
+ # Check whether slice target dim is dynamic shape dim
416
+ assert input .shape [dim ] != - 1 , "Can't select on negative shape dimension!"
417
+
418
+ input_dims = len (input .shape )
419
+ for i in range (0 , input_dims ):
420
+ if index [i ] >= input .shape [i ]:
421
+ raise RuntimeError (
422
+ f"cannot have index greater than the dimension length! { input .shape [dim ]} "
423
+ )
424
+ value_tensor = value * torch .ones (index .shape )
425
+ scatter_layer = ctx .net .add_scatter (input , index , value_tensor , trt .tensorrt .ScatterModekELEMENT )
426
+ scatter_layer .set_axis (dim )
427
+ set_layer_name (scatter_layer , target , name + "_scatter_layer" , source_ir )
428
+ out = scatter_layer .get_output (0 )
429
+ return out
430
+
431
+
432
+ def scatter_src (
433
+ ctx : ConversionContext ,
434
+ target : Target ,
435
+ source_ir : Optional [SourceIR ],
436
+ name : str ,
437
+ input : TRTTensor ,
438
+ dim : Shape ,
439
+ index : Shape ,
440
+ src : float ,
441
+ ) -> TRTTensor :
442
+ if not isinstance (input , TRTTensor ):
443
+ raise RuntimeError (
444
+ f"scatter_tensor received input { input } that is not part "
445
+ "of the TensorRT region!"
446
+ )
447
+
448
+ ranks = len (input .shape )
449
+ dim = get_positive_dim (cast (int , dim ), ranks )
450
+ dynamic_shape = has_dynamic_shape (input .shape )
451
+ if dynamic_shape :
452
+ # Check whether slice target dim is dynamic shape dim
453
+ assert input .shape [dim ] != - 1 , "Can't select on negative shape dimension!"
454
+
455
+ input_dims = len (input .shape )
456
+ for i in range (0 , input_dims ):
457
+ if index [i ] >= input .shape [i ]:
458
+ raise RuntimeError (
459
+ f"cannot have index greater than the dimension length! { input .shape [dim ]} "
460
+ )
461
+ scatter_layer = ctx .net .add_scatter (input , index , src , trt .tensorrt .ScatterModekELEMENT )
462
+ scatter_layer .set_axis (dim )
463
+ set_layer_name (scatter_layer , target , name + "_scatter_layer" , source_ir )
464
+ out = scatter_layer .get_output (0 )
465
+ return out
0 commit comments