
Commit 7c1b5ba

scatter_value and scatter_src converter
1 parent ac82540 commit 7c1b5ba

3 files changed: +122 −10 lines changed


py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 38 additions & 0 deletions
@@ -697,6 +697,44 @@ def aten_ops_clamp(
     )
 
 
+@dynamo_tensorrt_converter(torch.ops.aten.scatter.value)
+def aten_ops_scatter_value(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.select.scatter_value(
+        ctx, target, SourceIR.ATEN, name, args[0], args[1], args[2]
+    )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.scatter.src)
+def aten_ops_scatter_src(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.select.scatter_src(
+        ctx, target, SourceIR.ATEN, name, args[0], args[1], args[2]
+    )
+
+
+def aten_ops_select(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.select.select(
+        ctx, target, SourceIR.ATEN, name, args[0], args[1], args[2]
+    )
+
+
 @dynamo_tensorrt_converter(torch.ops.aten.select.int)
 def aten_ops_select(
     ctx: ConversionContext,
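
For context (not part of this commit): the two ATen overloads registered above have the schemas scatter.value(Tensor self, int dim, Tensor index, Scalar value) and scatter.src(Tensor self, int dim, Tensor index, Tensor src), so args[0], args[1], and args[2] correspond to self, dim, and index. A minimal eager-PyTorch sketch of the ops these converters target (tensor names and shapes are illustrative, chosen here for the example only):

import torch

# Eager-mode reference for the two overloads registered above.
x = torch.zeros(3, 5, dtype=torch.int64)
index = torch.tensor([[0, 1, 2, 0]])                        # 2-D index, same rank as x
src = torch.arange(1, 6, dtype=torch.int64).reshape(1, 5)   # tensor source for scatter.src

# aten.scatter.value(self, dim, index, value): `value` is a scalar
out_value = torch.ops.aten.scatter.value(x, 0, index, 7)

# aten.scatter.src(self, dim, index, src): `src` is a tensor
out_src = torch.ops.aten.scatter.src(x, 0, index, src)

print(out_value)
print(out_src)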

py/torch_tensorrt/dynamo/conversion/impl/select.py

Lines changed: 25 additions & 10 deletions
@@ -380,20 +380,25 @@ def scatter_value(
     input: TRTTensor,
     dim: Shape,
     index: Shape,
-    value: TRTTensor,
+    value: float,
 ) -> TRTTensor:
     if not isinstance(input, TRTTensor):
         raise RuntimeError(
             f"scatter_tensor received input {input} that is not part "
             "of the TensorRT region!"
         )
-
-    ranks = len(input.shape)
+    input_shape = input.shape
+    index_shape = index.shape
+    if (len(input_shape) != len(index_shape)):
+        raise RuntimeError(
+            f"The no of dimensions of input and index should be equal"
+        )
+    ranks = len(input_shape)
     dim = get_positive_dim(cast(int, dim), ranks)
     dynamic_shape = has_dynamic_shape(input.shape)
     if dynamic_shape:
         # Check whether slice target dim is dynamic shape dim
-        assert input.shape[dim] != -1, "Can't select on negative shape dimension!"
+        assert input.shape[dim] != -1, "Can't scatter on negative shape dimension!"
 
     input_dims = len(input.shape)
     for i in range(0, input_dims):
@@ -417,22 +422,32 @@ def scatter_src(
     input: TRTTensor,
     dim: Shape,
     index: Shape,
-    src: float,
+    src: TRTTensor,
 ) -> TRTTensor:
     if not isinstance(input, TRTTensor):
         raise RuntimeError(
             f"scatter_tensor received input {input} that is not part "
             "of the TensorRT region!"
         )
-
-    ranks = len(input.shape)
-    dim = get_positive_dim(cast(int, dim), ranks)
+    input_shape = input.shape
+    index_shape = index.shape
+    src_shape = src.shape
+    if (len(input_shape) != len(index_shape)):
+        raise RuntimeError(
+            f"The no of dimensions of input and index should be equal"
+        )
+    if (len(index_shape) != len(src_shape)):
+        raise RuntimeError(
+            f"The no of dimensions of src and index should be equal"
+        )
+
+    input_dims = len(input_shape)
+    dim = get_positive_dim(cast(int, dim), input_dims)
     dynamic_shape = has_dynamic_shape(input.shape)
     if dynamic_shape:
         # Check whether slice target dim is dynamic shape dim
-        assert input.shape[dim] != -1, "Can't select on negative shape dimension!"
+        assert input.shape[dim] != -1, "Can't scatter on negative shape dimension!"
 
-    input_dims = len(input.shape)
     for i in range(0, input_dims):
         if index[i] >= input.shape[i]:
             raise RuntimeError(
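
The rank checks added to scatter_value and scatter_src mirror a constraint of the underlying op: index (and, for the src variant, src) must have the same number of dimensions as the input. A small eager-mode sketch of the rank mismatch this validation rejects (not part of the commit; the exact error text may vary by PyTorch version):

import torch

# A 1-D index against a 2-D input violates the same-rank requirement
# that the added RuntimeError checks enforce on the TensorRT path.
x = torch.zeros(3, 5)
bad_index = torch.tensor([0, 1, 2, 0])  # rank 1, while x is rank 2

try:
    torch.ops.aten.scatter.value(x, 0, bad_index, 1.0)
except RuntimeError as err:
    print(err)  # eager PyTorch rejects the rank mismatch as well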
Lines changed: 59 additions & 0 deletions
@@ -0,0 +1,59 @@
+import torch
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt import Input
+
+from .harness import DispatchTestCase
+
+
+class TestScatterValueConverter(DispatchTestCase):
+    @parameterized.expand(
+        [
+            ("scatter_zero_dim_indexOne_value", 0, [[0, 1, 2, 0]], 1),
+            ("scatter_zero_dim_indexTwo_value", 0, [[0, 1, 2, 0], [1, 2, 1, 1]], 1),
+            ("scatter_one_dim_indexOne_value", 1, [[0, 1, 2, 0]], 1),
+            ("scatter_one_dim_indexTwo_value", 1, [[0, 1, 2, 0], [1, 2, 1, 1]], 1),
+        ]
+    )
+    def test_scatter(self, _, dim, index, value):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, input, src):
+                return torch.ops.aten.scatter.value(input, dim, index, value)
+
+        input = [torch.zeros(3, 5, dtype=torch.int32)]
+        self.run_test(
+            TestModule(),
+            input,
+        )
+
+
+class TestScatterSrcConverter(DispatchTestCase):
+    @parameterized.expand(
+        [
+            ("scatter_zero_dim_indexOne", 0, [[0, 1, 2, 0]]),
+            ("scatter_zero_dim_indexTwo", 0, [[0, 1, 2, 0], [1, 2, 1, 1]]),
+            ("scatter_one_dim_indexOne", 1, [[0, 1, 2, 0]]),
+            ("scatter_one_dim_indexTwo", 1, [[0, 1, 2, 0], [1, 2, 1, 1]]),
+        ]
+    )
+    def test_scatter(self, _, dim, index):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, input, src):
+                return torch.ops.aten.scatter.src(input, dim, index, src)
+
+        src = [torch.arange(1, 11).reshape((2, 5))]
+        input = torch.zeros(3, 5, dtype=src.dtype)
+        inputs = [input, src]
+        self.run_test(
+            TestModule(),
+            inputs,
+        )
+
+
+
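
As a reference for what the harness should ultimately compare against, a minimal eager-mode sketch reproducing the tensor math behind the dim=0, two-row-index cases parameterized above (values chosen to match the test data; this snippet is illustrative and not part of the committed test file):

import torch

# Expected eager results for the dim=0, two-row-index cases above.
input = torch.zeros(3, 5, dtype=torch.int64)
index = torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]])
src = torch.arange(1, 11, dtype=torch.int64).reshape(2, 5)

print(torch.ops.aten.scatter.src(input, 0, index, src))   # scatter.src case
print(torch.ops.aten.scatter.value(input, 0, index, 1))   # scatter.value case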
