Skip to content

Commit 6fbc0ec

Browse files
committed
scatter_value and scatter_src converter
1 parent 628fab7 commit 6fbc0ec

File tree

3 files changed

+122
-10
lines changed

3 files changed

+122
-10
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -691,6 +691,44 @@ def aten_ops_clamp(
691691
)
692692

693693

694+
@dynamo_tensorrt_converter(torch.ops.aten.scatter.value)
def aten_ops_scatter_value(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Convert ``aten.scatter.value(input, dim, index, value)`` to TensorRT.

    Bug fix: the scalar ``value`` (``args[3]``) was previously dropped, so
    ``impl.select.scatter_value`` — whose signature ends in
    ``(input, dim, index, value)`` — was called one argument short.
    """
    return impl.select.scatter_value(
        ctx,
        target,
        SourceIR.ATEN,
        name,
        args[0],  # input tensor
        args[1],  # dim
        args[2],  # index
        args[3],  # scalar fill value
    )
707+
@dynamo_tensorrt_converter(torch.ops.aten.scatter.src)
def aten_ops_scatter_src(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Convert ``aten.scatter.src(input, dim, index, src)`` to TensorRT.

    Bug fix: the source tensor ``src`` (``args[3]``) was previously dropped,
    so ``impl.select.scatter_src`` — whose signature ends in
    ``(input, dim, index, src)`` — was called one argument short.
    """
    return impl.select.scatter_src(
        ctx,
        target,
        SourceIR.ATEN,
        name,
        args[0],  # input tensor
        args[1],  # dim
        args[2],  # index
        args[3],  # source tensor
    )
720+
# NOTE(review): removed an undecorated duplicate of `aten_ops_select` that
# was dead code — it had no `@dynamo_tensorrt_converter` registration and was
# immediately shadowed by the decorated `aten_ops_select` defined just below
# for `torch.ops.aten.select.int`. It appears to be a leftover from copying
# the scatter converters above.
694732
@dynamo_tensorrt_converter(torch.ops.aten.select.int)
695733
def aten_ops_select(
696734
ctx: ConversionContext,

py/torch_tensorrt/dynamo/conversion/impl/select.py

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -400,20 +400,25 @@ def scatter_value(
400400
input: TRTTensor,
401401
dim: Shape,
402402
index: Shape,
403-
value: TRTTensor,
403+
value: float,
404404
) -> TRTTensor:
405405
if not isinstance(input, TRTTensor):
406406
raise RuntimeError(
407407
f"scatter_tensor received input {input} that is not part "
408408
"of the TensorRT region!"
409409
)
410-
411-
ranks = len(input.shape)
410+
input_shape = input.shape
411+
index_shape = index.shape
412+
if (len(input_shape) != len(index_shape)):
413+
raise RuntimeError(
414+
f"The no of dimensions of input and index should be equal"
415+
)
416+
ranks = len(input_shape)
412417
dim = get_positive_dim(cast(int, dim), ranks)
413418
dynamic_shape = has_dynamic_shape(input.shape)
414419
if dynamic_shape:
415420
# Check whether slice target dim is dynamic shape dim
416-
assert input.shape[dim] != -1, "Can't select on negative shape dimension!"
421+
assert input.shape[dim] != -1, "Can't scatter on negative shape dimension!"
417422

418423
input_dims = len(input.shape)
419424
for i in range(0, input_dims):
@@ -437,22 +442,32 @@ def scatter_src(
437442
input: TRTTensor,
438443
dim: Shape,
439444
index: Shape,
440-
src: float,
445+
src: TRTTensor,
441446
) -> TRTTensor:
442447
if not isinstance(input, TRTTensor):
443448
raise RuntimeError(
444449
f"scatter_tensor received input {input} that is not part "
445450
"of the TensorRT region!"
446451
)
447-
448-
ranks = len(input.shape)
449-
dim = get_positive_dim(cast(int, dim), ranks)
452+
input_shape = input.shape
453+
index_shape = index.shape
454+
src_shape = src.shape
455+
if (len(input_shape) != len(index_shape)):
456+
raise RuntimeError(
457+
f"The no of dimensions of input and index should be equal"
458+
)
459+
if (len(index_shape) != len(src_shape)):
460+
raise RuntimeError(
461+
f"The no of dimensions of src and index should be equal"
462+
)
463+
464+
input_dims = len(input_shape)
465+
dim = get_positive_dim(cast(int, dim), input_dims)
450466
dynamic_shape = has_dynamic_shape(input.shape)
451467
if dynamic_shape:
452468
# Check whether slice target dim is dynamic shape dim
453-
assert input.shape[dim] != -1, "Can't select on negative shape dimension!"
469+
assert input.shape[dim] != -1, "Can't scatter on negative shape dimension!"
454470

455-
input_dims = len(input.shape)
456471
for i in range(0, input_dims):
457472
if index[i] >= input.shape[i]:
458473
raise RuntimeError(
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
import torch
2+
from parameterized import parameterized
3+
from torch.testing._internal.common_utils import run_tests
4+
from torch_tensorrt import Input
5+
6+
from .harness import DispatchTestCase
7+
8+
9+
class TestScatterValueConverter(DispatchTestCase):
    """Tests for the ``aten.scatter.value`` TensorRT converter."""

    @parameterized.expand(
        [
            ("scatter_zero_dim_indexOne_value", 0, [[0, 1, 2, 0]], 1),
            ("scatter_zero_dim_indexTwo_value", 0, [[0, 1, 2, 0], [1, 2, 1, 1]], 1),
            ("scatter_one_dim_indexOne_value", 1, [[0, 1, 2, 0]], 1),
            ("scatter_one_dim_indexTwo_value", 1, [[0, 1, 2, 0], [1, 2, 1, 1]], 1),
        ]
    )
    def test_scatter(self, _, dim, index, value):
        # `aten.scatter.value` takes a Tensor index, not a Python list.
        index = torch.tensor(index, dtype=torch.int64)

        class TestModule(torch.nn.Module):
            # Bug fix: the module previously declared `forward(self, input, src)`
            # while run_test was handed only one input; the value overload uses
            # a scalar, so `src` had no business being a runtime input.
            def forward(self, input):
                # dim/index/value are captured from the test as constants.
                return torch.ops.aten.scatter.value(input, dim, index, value)

        inputs = [torch.zeros(3, 5, dtype=torch.int32)]
        self.run_test(
            TestModule(),
            inputs,
        )
33+
class TestScatterSrcConverter(DispatchTestCase):
    """Tests for the ``aten.scatter.src`` TensorRT converter."""

    @parameterized.expand(
        [
            ("scatter_zero_dim_indexOne", 0, [[0, 1, 2, 0]]),
            ("scatter_zero_dim_indexTwo", 0, [[0, 1, 2, 0], [1, 2, 1, 1]]),
            ("scatter_one_dim_indexOne", 1, [[0, 1, 2, 0]]),
            ("scatter_one_dim_indexTwo", 1, [[0, 1, 2, 0], [1, 2, 1, 1]]),
        ]
    )
    def test_scatter(self, _, dim, index):
        # `aten.scatter.src` takes a Tensor index, not a Python list.
        index = torch.tensor(index, dtype=torch.int64)

        class TestModule(torch.nn.Module):
            def forward(self, input, src):
                # dim/index are captured from the test as constants.
                return torch.ops.aten.scatter.src(input, dim, index, src)

        # Bug fix: `src` was previously built as a one-element *list*, after
        # which `src.dtype` raised AttributeError and `inputs = [input, src]`
        # nested that list instead of passing a tensor.
        src = torch.arange(1, 11).reshape((2, 5))
        input = torch.zeros(3, 5, dtype=src.dtype)
        inputs = [input, src]
        self.run_test(
            TestModule(),
            inputs,
        )
59+

0 commit comments

Comments
 (0)