21
21
set_layer_name ,
22
22
)
23
23
from torch_tensorrt .fx .types import Shape , TRTTensor
24
+ from torch_tensorrt .fx .utils import Frameworks , unified_dtype_converter
24
25
25
26
_LOGGER : logging .Logger = logging .getLogger (__name__ )
26
27
@@ -398,8 +399,8 @@ def scatter_value(
398
399
source_ir : Optional [SourceIR ],
399
400
name : str ,
400
401
input : TRTTensor ,
401
- dim : Shape ,
402
- index : Shape ,
402
+ dim : int ,
403
+ index : Union [ TRTTensor , np . ndarray , torch . Tensor ] ,
403
404
value : float ,
404
405
) -> TRTTensor :
405
406
if not isinstance (input , TRTTensor ):
@@ -409,26 +410,34 @@ def scatter_value(
409
410
)
410
411
input_shape = input .shape
411
412
index_shape = index .shape
413
+ index_shape_list = list (index .shape )
414
+ if not (isinstance (index , TRTTensor )):
415
+ index = get_trt_tensor (ctx , index , f"_index_tensor" )
412
416
if len (input_shape ) != len (index_shape ):
413
417
raise RuntimeError (f"The no of dimensions of input and index should be equal" )
414
- ranks = len (input_shape )
415
- dim = get_positive_dim (cast (int , dim ), ranks )
418
+ dim = get_positive_dim (dim , len (input_shape ))
416
419
dynamic_shape = has_dynamic_shape (input .shape )
417
420
if dynamic_shape :
418
421
# Check whether slice target dim is dynamic shape dim
419
422
assert input .shape [dim ] != - 1 , "Can't scatter on negative shape dimension!"
420
423
421
- input_dims = len (input . shape )
424
+ input_dims = len (input_shape )
422
425
for i in range (0 , input_dims ):
423
- if index [i ] >= input .shape [i ]:
426
+ if i != dim and ( index_shape [i ] >= input .shape [i ]) :
424
427
raise RuntimeError (
425
- f"cannot have index greater than the dimension length! { input . shape [ dim ] } "
428
+ f"cannot have index size greater than the input size along dimension { dim } "
426
429
)
427
- value_tensor = value * torch .ones (index .shape )
430
+
431
+ value_tensor = get_trt_tensor (
432
+ ctx , value * torch .ones (index_shape_list ), name + "_value_tensor"
433
+ )
434
+ value_tensor = cast_trt_tensor (
435
+ ctx , value_tensor , input .dtype , name + "_cast_value_tensor"
436
+ )
428
437
scatter_layer = ctx .net .add_scatter (
429
- input , index , value_tensor , trt .tensorrt . ScatterModekELEMENT
438
+ input , index , value_tensor , trt .ScatterMode . ELEMENT
430
439
)
431
- scatter_layer .set_axis ( dim )
440
+ scatter_layer .axis = dim
432
441
set_layer_name (scatter_layer , target , name + "_scatter_layer" , source_ir )
433
442
out = scatter_layer .get_output (0 )
434
443
return out
@@ -452,6 +461,8 @@ def scatter_src(
452
461
input_shape = input .shape
453
462
index_shape = index .shape
454
463
src_shape = src .shape
464
+ if not (isinstance (index , TRTTensor )):
465
+ index = get_trt_tensor (ctx , index , f"_index_tensor" )
455
466
if len (input_shape ) != len (index_shape ):
456
467
raise RuntimeError (f"The no of dimensions of input and index should be equal" )
457
468
if len (index_shape ) != len (src_shape ):
@@ -465,14 +476,23 @@ def scatter_src(
465
476
assert input .shape [dim ] != - 1 , "Can't scatter on negative shape dimension!"
466
477
467
478
for i in range (0 , input_dims ):
468
- if index [i ] >= input .shape [i ]:
479
+ if i != dim and ( index_shape [i ] >= input .shape [i ]) :
469
480
raise RuntimeError (
470
- f"cannot have index greater than the dimension length! { input . shape [ dim ] } "
481
+ f"cannot have index size greater than the input size along dimension { dim } "
471
482
)
483
+ input_dtype = input .dtype
484
+ # required for cases where src is a constant
485
+ src_dtype = unified_dtype_converter (src .dtype , Frameworks .TRT )
486
+ if input_dtype != src_dtype :
487
+ raise RuntimeError (f"The type of input and src should be made" )
488
+ src_tensor = src
489
+ if not (isinstance (src , TRTTensor )):
490
+ src_tensor = get_trt_tensor (ctx , src , name + "_src_tensor" )
491
+
472
492
scatter_layer = ctx .net .add_scatter (
473
- input , index , src , trt .tensorrt . ScatterModekELEMENT
493
+ input , index , src_tensor , trt .ScatterMode . ELEMENT
474
494
)
475
- scatter_layer .set_axis ( dim )
495
+ scatter_layer .axis = dim
476
496
set_layer_name (scatter_layer , target , name + "_scatter_layer" , source_ir )
477
497
out = scatter_layer .get_output (0 )
478
498
return out
0 commit comments