@@ -370,3 +370,76 @@ def index(
370
370
reshape_output = reshape_layer .get_output (0 )
371
371
372
372
return reshape_output
373
+
374
+
375
+ def scatter_value (
376
+ ctx : ConversionContext ,
377
+ target : Target ,
378
+ source_ir : Optional [SourceIR ],
379
+ name : str ,
380
+ input : TRTTensor ,
381
+ dim : Shape ,
382
+ index : Shape ,
383
+ value : TRTTensor ,
384
+ ) -> TRTTensor :
385
+ if not isinstance (input , TRTTensor ):
386
+ raise RuntimeError (
387
+ f"scatter_tensor received input { input } that is not part "
388
+ "of the TensorRT region!"
389
+ )
390
+
391
+ ranks = len (input .shape )
392
+ dim = get_positive_dim (cast (int , dim ), ranks )
393
+ dynamic_shape = has_dynamic_shape (input .shape )
394
+ if dynamic_shape :
395
+ # Check whether slice target dim is dynamic shape dim
396
+ assert input .shape [dim ] != - 1 , "Can't select on negative shape dimension!"
397
+
398
+ input_dims = len (input .shape )
399
+ for i in range (0 , input_dims ):
400
+ if index [i ] >= input .shape [i ]:
401
+ raise RuntimeError (
402
+ f"cannot have index greater than the dimension length! { input .shape [dim ]} "
403
+ )
404
+ value_tensor = value * torch .ones (index .shape )
405
+ scatter_layer = ctx .net .add_scatter (input , index , value_tensor , trt .tensorrt .ScatterModekELEMENT )
406
+ scatter_layer .set_axis (dim )
407
+ set_layer_name (scatter_layer , target , name + "_scatter_layer" , source_ir )
408
+ out = scatter_layer .get_output (0 )
409
+ return out
410
+
411
+
412
+ def scatter_src (
413
+ ctx : ConversionContext ,
414
+ target : Target ,
415
+ source_ir : Optional [SourceIR ],
416
+ name : str ,
417
+ input : TRTTensor ,
418
+ dim : Shape ,
419
+ index : Shape ,
420
+ src : float ,
421
+ ) -> TRTTensor :
422
+ if not isinstance (input , TRTTensor ):
423
+ raise RuntimeError (
424
+ f"scatter_tensor received input { input } that is not part "
425
+ "of the TensorRT region!"
426
+ )
427
+
428
+ ranks = len (input .shape )
429
+ dim = get_positive_dim (cast (int , dim ), ranks )
430
+ dynamic_shape = has_dynamic_shape (input .shape )
431
+ if dynamic_shape :
432
+ # Check whether slice target dim is dynamic shape dim
433
+ assert input .shape [dim ] != - 1 , "Can't select on negative shape dimension!"
434
+
435
+ input_dims = len (input .shape )
436
+ for i in range (0 , input_dims ):
437
+ if index [i ] >= input .shape [i ]:
438
+ raise RuntimeError (
439
+ f"cannot have index greater than the dimension length! { input .shape [dim ]} "
440
+ )
441
+ scatter_layer = ctx .net .add_scatter (input , index , src , trt .tensorrt .ScatterModekELEMENT )
442
+ scatter_layer .set_axis (dim )
443
+ set_layer_name (scatter_layer , target , name + "_scatter_layer" , source_ir )
444
+ out = scatter_layer .get_output (0 )
445
+ return out
0 commit comments