@@ -393,100 +393,38 @@ def index_select(
393
393
return gather_layer .get_output (0 )
394
394
395
395
396
- def scatter_value (
396
+ def scatter (
397
397
ctx : ConversionContext ,
398
398
target : Target ,
399
399
source_ir : Optional [SourceIR ],
400
400
name : str ,
401
401
input : TRTTensor ,
402
402
dim : int ,
403
403
index : Union [TRTTensor , np .ndarray , torch .Tensor ],
404
- value : float ,
404
+ src : Union [ TRTTensor , int , float ] ,
405
405
) -> TRTTensor :
406
- if not isinstance (input , TRTTensor ):
407
- raise RuntimeError (
408
- f"scatter_tensor received input { input } that is not part "
409
- "of the TensorRT region!"
410
- )
411
406
input_shape = input .shape
412
407
index_shape = index .shape
413
408
index_shape_list = list (index .shape )
414
409
if not (isinstance (index , TRTTensor )):
415
410
index = get_trt_tensor (ctx , index , f"_index_tensor" )
416
- if len (input_shape ) != len (index_shape ):
417
- raise RuntimeError (f"The no of dimensions of input and index should be equal" )
418
411
dim = get_positive_dim (dim , len (input_shape ))
419
412
dynamic_shape = has_dynamic_shape (input .shape )
420
413
if dynamic_shape :
421
414
# Check whether slice target dim is dynamic shape dim
422
415
assert input .shape [dim ] != - 1 , "Can't scatter on negative shape dimension!"
423
416
424
- input_dims = len (input_shape )
425
- for i in range (0 , input_dims ):
426
- if i != dim and (index_shape [i ] >= input .shape [i ]):
427
- raise RuntimeError (
428
- f"cannot have index size greater than the input size along dimension { dim } "
429
- )
430
-
431
- value_tensor = get_trt_tensor (
432
- ctx , value * torch .ones (index_shape_list ), name + "_value_tensor"
433
- )
434
- value_tensor = cast_trt_tensor (
435
- ctx , value_tensor , input .dtype , name + "_cast_value_tensor"
436
- )
437
- scatter_layer = ctx .net .add_scatter (
438
- input , index , value_tensor , trt .ScatterMode .ELEMENT
439
- )
440
- scatter_layer .axis = dim
441
- set_layer_name (scatter_layer , target , name + "_scatter_layer" , source_ir )
442
- out = scatter_layer .get_output (0 )
443
- return out
444
-
445
-
446
- def scatter_src (
447
- ctx : ConversionContext ,
448
- target : Target ,
449
- source_ir : Optional [SourceIR ],
450
- name : str ,
451
- input : TRTTensor ,
452
- dim : Shape ,
453
- index : Shape ,
454
- src : TRTTensor ,
455
- ) -> TRTTensor :
456
- if not isinstance (input , TRTTensor ):
457
- raise RuntimeError (
458
- f"scatter_tensor received input { input } that is not part "
459
- "of the TensorRT region!"
460
- )
461
- input_shape = input .shape
462
- index_shape = index .shape
463
- src_shape = src .shape
464
- if not (isinstance (index , TRTTensor )):
465
- index = get_trt_tensor (ctx , index , f"_index_tensor" )
466
- if len (input_shape ) != len (index_shape ):
467
- raise RuntimeError (f"The no of dimensions of input and index should be equal" )
468
- if len (index_shape ) != len (src_shape ):
469
- raise RuntimeError (f"The no of dimensions of src and index should be equal" )
470
-
471
- input_dims = len (input_shape )
472
- dim = get_positive_dim (cast (int , dim ), input_dims )
473
- dynamic_shape = has_dynamic_shape (input .shape )
474
- if dynamic_shape :
475
- # Check whether slice target dim is dynamic shape dim
476
- assert input .shape [dim ] != - 1 , "Can't scatter on negative shape dimension!"
477
-
478
- for i in range (0 , input_dims ):
479
- if i != dim and (index_shape [i ] >= input .shape [i ]):
480
- raise RuntimeError (
481
- f"cannot have index size greater than the input size along dimension { dim } "
482
- )
483
- input_dtype = input .dtype
484
- # required for cases where src is a constant
485
- src_dtype = unified_dtype_converter (src .dtype , Frameworks .TRT )
486
- if input_dtype != src_dtype :
487
- raise RuntimeError (f"The type of input and src should be made" )
488
417
src_tensor = src
489
- if not (isinstance (src , TRTTensor )):
418
+ # scatter.value
419
+ if isinstance (src , int ) or isinstance (src , float ):
420
+ src_tensor = get_trt_tensor (
421
+ ctx , src * torch .ones (index_shape_list ), name + "_value_tensor"
422
+ )
423
+ src_tensor = cast_trt_tensor (
424
+ ctx , src_tensor , input .dtype , name + "_cast_value_tensor"
425
+ )
426
+ # scatter.src
427
+ elif not (isinstance (src , TRTTensor )):
490
428
src_tensor = get_trt_tensor (ctx , src , name + "_src_tensor" )
491
429
492
430
scatter_layer = ctx .net .add_scatter (
0 commit comments