Skip to content

Commit c4de771

Browse files
committed
addressing review comments and changing test names
1 parent ae08f41 commit c4de771

File tree

3 files changed

+26
-94
lines changed

3 files changed

+26
-94
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -691,28 +691,21 @@ def aten_ops_clamp(
691691
)
692692

693693

694-
@dynamo_tensorrt_converter(torch.ops.aten.scatter.value)
695-
def aten_ops_scatter_value(
696-
ctx: ConversionContext,
697-
target: Target,
698-
args: Tuple[Argument, ...],
699-
kwargs: Dict[str, Argument],
700-
name: str,
701-
) -> Union[TRTTensor, Sequence[TRTTensor]]:
702-
return impl.select.scatter_value(
703-
ctx, target, SourceIR.ATEN, name, args[0], args[1], args[2], args[3]
704-
)
705-
706-
694+
@enforce_tensor_types(
695+
{
696+
0: (TRTTensor,),
697+
}
698+
)
707699
@dynamo_tensorrt_converter(torch.ops.aten.scatter.src)
708-
def aten_ops_scatter_src(
700+
@dynamo_tensorrt_converter(torch.ops.aten.scatter.value)
701+
def aten_ops_scatter(
709702
ctx: ConversionContext,
710703
target: Target,
711704
args: Tuple[Argument, ...],
712705
kwargs: Dict[str, Argument],
713706
name: str,
714707
) -> Union[TRTTensor, Sequence[TRTTensor]]:
715-
return impl.select.scatter_src(
708+
return impl.select.scatter(
716709
ctx, target, SourceIR.ATEN, name, args[0], args[1], args[2], args[3]
717710
)
718711

py/torch_tensorrt/dynamo/conversion/impl/select.py

Lines changed: 12 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -393,100 +393,38 @@ def index_select(
393393
return gather_layer.get_output(0)
394394

395395

396-
def scatter_value(
396+
def scatter(
397397
ctx: ConversionContext,
398398
target: Target,
399399
source_ir: Optional[SourceIR],
400400
name: str,
401401
input: TRTTensor,
402402
dim: int,
403403
index: Union[TRTTensor, np.ndarray, torch.Tensor],
404-
value: float,
404+
src: Union[TRTTensor, int, float],
405405
) -> TRTTensor:
406-
if not isinstance(input, TRTTensor):
407-
raise RuntimeError(
408-
f"scatter_tensor received input {input} that is not part "
409-
"of the TensorRT region!"
410-
)
411406
input_shape = input.shape
412407
index_shape = index.shape
413408
index_shape_list = list(index.shape)
414409
if not (isinstance(index, TRTTensor)):
415410
index = get_trt_tensor(ctx, index, f"_index_tensor")
416-
if len(input_shape) != len(index_shape):
417-
raise RuntimeError(f"The no of dimensions of input and index should be equal")
418411
dim = get_positive_dim(dim, len(input_shape))
419412
dynamic_shape = has_dynamic_shape(input.shape)
420413
if dynamic_shape:
421414
# Check whether slice target dim is dynamic shape dim
422415
assert input.shape[dim] != -1, "Can't scatter on negative shape dimension!"
423416

424-
input_dims = len(input_shape)
425-
for i in range(0, input_dims):
426-
if i != dim and (index_shape[i] >= input.shape[i]):
427-
raise RuntimeError(
428-
f"cannot have index size greater than the input size along dimension {dim}"
429-
)
430-
431-
value_tensor = get_trt_tensor(
432-
ctx, value * torch.ones(index_shape_list), name + "_value_tensor"
433-
)
434-
value_tensor = cast_trt_tensor(
435-
ctx, value_tensor, input.dtype, name + "_cast_value_tensor"
436-
)
437-
scatter_layer = ctx.net.add_scatter(
438-
input, index, value_tensor, trt.ScatterMode.ELEMENT
439-
)
440-
scatter_layer.axis = dim
441-
set_layer_name(scatter_layer, target, name + "_scatter_layer", source_ir)
442-
out = scatter_layer.get_output(0)
443-
return out
444-
445-
446-
def scatter_src(
447-
ctx: ConversionContext,
448-
target: Target,
449-
source_ir: Optional[SourceIR],
450-
name: str,
451-
input: TRTTensor,
452-
dim: Shape,
453-
index: Shape,
454-
src: TRTTensor,
455-
) -> TRTTensor:
456-
if not isinstance(input, TRTTensor):
457-
raise RuntimeError(
458-
f"scatter_tensor received input {input} that is not part "
459-
"of the TensorRT region!"
460-
)
461-
input_shape = input.shape
462-
index_shape = index.shape
463-
src_shape = src.shape
464-
if not (isinstance(index, TRTTensor)):
465-
index = get_trt_tensor(ctx, index, f"_index_tensor")
466-
if len(input_shape) != len(index_shape):
467-
raise RuntimeError(f"The no of dimensions of input and index should be equal")
468-
if len(index_shape) != len(src_shape):
469-
raise RuntimeError(f"The no of dimensions of src and index should be equal")
470-
471-
input_dims = len(input_shape)
472-
dim = get_positive_dim(cast(int, dim), input_dims)
473-
dynamic_shape = has_dynamic_shape(input.shape)
474-
if dynamic_shape:
475-
# Check whether slice target dim is dynamic shape dim
476-
assert input.shape[dim] != -1, "Can't scatter on negative shape dimension!"
477-
478-
for i in range(0, input_dims):
479-
if i != dim and (index_shape[i] >= input.shape[i]):
480-
raise RuntimeError(
481-
f"cannot have index size greater than the input size along dimension {dim}"
482-
)
483-
input_dtype = input.dtype
484-
# required for cases where src is a constant
485-
src_dtype = unified_dtype_converter(src.dtype, Frameworks.TRT)
486-
if input_dtype != src_dtype:
487-
raise RuntimeError(f"The type of input and src should be made")
488417
src_tensor = src
489-
if not (isinstance(src, TRTTensor)):
418+
# scatter.value
419+
if isinstance(src, int) or isinstance(src, float):
420+
src_tensor = get_trt_tensor(
421+
ctx, src * torch.ones(index_shape_list), name + "_value_tensor"
422+
)
423+
src_tensor = cast_trt_tensor(
424+
ctx, src_tensor, input.dtype, name + "_cast_value_tensor"
425+
)
426+
# scatter.src
427+
elif not (isinstance(src, TRTTensor)):
490428
src_tensor = get_trt_tensor(ctx, src, name + "_src_tensor")
491429

492430
scatter_layer = ctx.net.add_scatter(

tests/py/dynamo/conversion/test_scatter_aten.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import torch
2-
from harness import DispatchTestCase
32
from parameterized import parameterized
43
from torch.testing._internal.common_utils import run_tests
54
from torch_tensorrt import Input
65

6+
from .harness import DispatchTestCase
7+
78

89
class TestScatterValueConverter(DispatchTestCase):
910
@parameterized.expand(
@@ -87,25 +88,25 @@ class TestScatterSrcConverter(DispatchTestCase):
8788
@parameterized.expand(
8889
[
8990
(
90-
"scatter_zero_dim_indexOne_constant_src",
91+
"scatter_zero_dim_indexOne_src",
9192
0,
9293
torch.tensor([[0, 1, 2, 0]]),
9394
torch.tensor([[1, 2, 3, 4]], dtype=torch.int32),
9495
),
9596
(
96-
"scatter_zero_dim_indexTwo_constant_src",
97+
"scatter_zero_dim_indexTwo_src",
9798
0,
9899
torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]),
99100
torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=torch.int32),
100101
),
101102
(
102-
"scatter_one_dim_indexOne_constant_src",
103+
"scatter_one_dim_indexOne_src",
103104
1,
104105
torch.tensor([[0, 1, 2, 0]]),
105106
torch.tensor([[1, 2, 3, 1]], dtype=torch.int32),
106107
),
107108
(
108-
"scatter_one_dim_indexTwo_constant_src",
109+
"scatter_one_dim_indexTwo_src",
109110
1,
110111
torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]),
111112
torch.tensor([[1, 2, 3, 1], [5, 6, 5, 5]], dtype=torch.int32),

0 commit comments

Comments (0)