Commit 35b5d03

Aten scatter converter (#2664)

apbose authored and laikhtewari committed

1 parent 10698e2

4 files changed: +274 −10 lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 20 additions & 0 deletions
@@ -691,6 +691,26 @@ def aten_ops_clamp(
     )
 
 
+@dynamo_tensorrt_converter(torch.ops.aten.scatter.src)
+@dynamo_tensorrt_converter(torch.ops.aten.scatter.value)
+@enforce_tensor_types(
+    {
+        0: (TRTTensor,),
+        2: (TRTTensor,),
+    }
+)
+def aten_ops_scatter(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.select.scatter(
+        ctx, target, SourceIR.ATEN, name, args[0], args[1], args[2], args[3]
+    )
+
+
 @dynamo_tensorrt_converter(torch.ops.aten.select.int)
 def aten_ops_select(
     ctx: ConversionContext,
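For context, the two overloads registered above differ only in their final argument: scatter.value takes a Python scalar, while scatter.src takes a tensor. A quick eager-mode reference showing the semantics the converter must reproduce (standard PyTorch, not part of this diff):

    import torch

    x = torch.zeros(3, 5, dtype=torch.int64)
    index = torch.tensor([[0, 1, 2, 0]])

    # scatter.value: writes one scalar at every indexed position
    out_value = torch.ops.aten.scatter.value(x, 0, index, 1)

    # scatter.src: writes elements drawn from a source tensor instead
    src = torch.tensor([[10, 20, 30, 40]])
    out_src = torch.ops.aten.scatter.src(x, 0, index, src)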

py/torch_tensorrt/dynamo/conversion/impl/select.py

Lines changed: 38 additions & 0 deletions
@@ -390,3 +390,41 @@ def index_select(
     set_layer_name(gather_layer, target, f"{name}_gather", source_ir)
 
     return gather_layer.get_output(0)
+
+
+def scatter(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: TRTTensor,
+    dim: int,
+    index: Union[TRTTensor, np.ndarray, torch.Tensor],
+    src: Union[TRTTensor, int, float],
+) -> TRTTensor:
+    input_shape = input.shape
+    index_shape = index.shape
+    index_shape_list = list(index_shape)
+    if index.dtype == trt.int64:
+        index = cast_trt_tensor(ctx, index, trt.int32, name + "_cast_index_tensor")
+    dim = get_positive_dim(dim, len(input_shape))
+    src_tensor = src
+    # scatter.value
+    if isinstance(src, int) or isinstance(src, float):
+        src_tensor = get_trt_tensor(
+            ctx, src * np.ones(index_shape_list), name + "_value_tensor"
+        )
+        src_tensor = cast_trt_tensor(
+            ctx, src_tensor, input.dtype, name + "_cast_value_tensor"
+        )
+    # scatter.src
+    elif not (isinstance(src, TRTTensor)):
+        src_tensor = get_trt_tensor(ctx, src, name + "_src_tensor")
+
+    scatter_layer = ctx.net.add_scatter(
+        input, index, src_tensor, trt.ScatterMode.ELEMENT
+    )
+    scatter_layer.axis = dim
+    set_layer_name(scatter_layer, target, name + "_scatter_layer", source_ir)
+    out = scatter_layer.get_output(0)
+    return out
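A note on the layer choice: trt.ScatterMode.ELEMENT follows ONNX ScatterElements-style semantics, which line up with torch.Tensor.scatter_. A rough pure-NumPy sketch of the 2-D case, for illustration only (scatter_element_2d is a hypothetical helper, not part of this codebase):

    import numpy as np

    def scatter_element_2d(data, indices, updates, axis):
        # Start from a copy of `data`; each entry of `indices` redirects
        # the matching element of `updates` along `axis`.
        out = data.copy()
        rows, cols = indices.shape
        for i in range(rows):
            for j in range(cols):
                if axis == 0:
                    out[indices[i, j], j] = updates[i, j]
                else:
                    out[i, indices[i, j]] = updates[i, j]
        return out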

tests/py/dynamo/conversion/harness.py

Lines changed: 39 additions & 10 deletions
@@ -1,5 +1,3 @@
-# type: ignore
-
 import logging
 import time
 import unittest
@@ -50,16 +48,20 @@ def setUp(self):
     def run_test(
         self,
         mod,
-        inputs,
+        fx_inputs,
+        trt_interpreter_inputs,
         interpreter,
         rtol,
         atol,
         check_dtype=True,
     ):
         with torch.no_grad():
-            cuda_inputs = []
-            for i in inputs:
-                cuda_inputs.append(i.cuda())
+            cuda_fx_inputs = []
+            cuda_trt_inputs = []
+            for i in trt_interpreter_inputs:
+                cuda_trt_inputs.append(i.cuda())
+            for i in fx_inputs:
+                cuda_fx_inputs.append(i.cuda())
 
             mod.eval()
             start = time.perf_counter()
@@ -73,13 +75,13 @@ def run_test(
             )
 
             mod = mod.cuda()
-            ref_outputs = mod(*cuda_inputs)
+            ref_outputs = mod(*cuda_fx_inputs)
 
             torch.cuda.synchronize()
             start_event = torch.cuda.Event(enable_timing=True)
             end_event = torch.cuda.Event(enable_timing=True)
             start_event.record()
-            outputs = trt_mod(*cuda_inputs)
+            outputs = trt_mod(*cuda_trt_inputs)
             end_event.record()
             torch.cuda.synchronize()
             _LOGGER.info(
@@ -220,6 +222,7 @@ def run_test(
         check_dtype=True,
         use_dynamo_tracer=False,
         enable_passes=False,
+        int32_reqd=False,
     ):
         mod.eval()
         mod = self.generate_graph(
@@ -237,6 +240,30 @@ def run_test(
             debug=True,
         )
 
+        num_inputs = len(inputs)
+        trt_inputs = inputs
+        dtype_to_change = []
+        if int32_reqd:
+            dtype_to_change = [torch.int64, torch.float64]
+        else:
+            dtype_to_change = [
+                torch.float64,
+            ]
+        for num_input in range(num_inputs):
+            input = inputs[num_input]
+            if input.dtype in dtype_to_change:
+                dtype_32bit = (
+                    torch.float32 if (input.dtype == torch.float64) else torch.int32
+                )
+                trt_inputs = (
+                    list(trt_inputs[:num_input])
+                    + [
+                        input.to(dtype_32bit),
+                    ]
+                    + list(trt_inputs[num_input + 1 :])
+                )
+
+        trt_input_specs = [Input.from_tensor(i) for i in trt_inputs]
         input_specs = [Input.from_tensor(i) for i in inputs]
 
         output_dtypes = None
@@ -254,13 +281,15 @@ def run_test(
 
         interp = TRTInterpreter(
             mod,
-            input_specs,
+            trt_input_specs,
            output_dtypes=output_dtypes,
             compilation_settings=compilation_settings,
         )
+
        super().run_test(
             mod,
             inputs,
+            trt_inputs,
             interp,
             rtol,
             atol,
@@ -335,4 +364,4 @@ def run_test_with_dynamic_shape(
         # Since the lowering is based on optimal shape. We need to test with
         # different shape(for ex. max shape) for testing dynamic shape
         inputs_max = [spec.example_tensor("max_shape") for spec in input_specs]
-        super().run_test(mod, inputs_max, interp, rtol, atol)
+        super().run_test(mod, inputs_max, inputs_max, interp, rtol, atol)
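The new int32_reqd path boils down to a dtype remap before engine construction: TensorRT's 64-bit support is limited, so float64 inputs (and int64 ones when the flag is set) are downcast to their 32-bit counterparts for the TRT interpreter, while the FX reference module still runs on the originals. A condensed sketch of that remapping (downcast_inputs is a hypothetical helper name; the logic mirrors the loop above):

    import torch

    def downcast_inputs(inputs, int32_reqd=False):
        # float64 always maps to float32; int64 only when int32_reqd is set
        dtype_map = {torch.float64: torch.float32}
        if int32_reqd:
            dtype_map[torch.int64] = torch.int32
        return [
            t.to(dtype_map[t.dtype]) if t.dtype in dtype_map else t
            for t in inputs
        ]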
tests/py/dynamo/conversion/test_scatter_aten.py

Lines changed: 177 additions & 0 deletions
@@ -0,0 +1,177 @@
+import torch
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt import Input
+
+from .harness import DispatchTestCase
+
+
+class TestScatterValueConverter(DispatchTestCase):
+    @parameterized.expand(
+        [
+            (
+                "scatter_zero_dim_indexOne_constant_value",
+                0,
+                torch.tensor([[0, 1, 2, 0]]),
+                1,
+            ),
+            (
+                "scatter_zero_dim_indexTwo_constant_value",
+                0,
+                torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]),
+                1,
+            ),
+            (
+                "scatter_one_dim_indexOne_constant_value",
+                1,
+                torch.tensor([[0, 1, 2, 0]]),
+                1,
+            ),
+            (
+                "scatter_one_dim_indexTwo_constant_value",
+                1,
+                torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]),
+                1,
+            ),
+        ]
+    )
+    def test_scatter_index_constant(self, _, dim, index, value):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, input):
+                return torch.ops.aten.scatter.value(input, dim, index, value)
+
+        input = torch.zeros(3, 5, dtype=torch.int32)
+        inputs = [input]
+        self.run_test(TestModule(), inputs, int32_reqd=True)
+
+    @parameterized.expand(
+        [
+            ("scatter_zero_dim_indexOne_value", 0, torch.tensor([[0, 1, 2, 0]]), 1),
+            (
+                "scatter_zero_dim_indexTwo_value",
+                0,
+                torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]),
+                1,
+            ),
+            ("scatter_one_dim_indexOne_value", 1, torch.tensor([[0, 1, 2, 0]]), 1),
+            (
+                "scatter_one_dim_indexTwo_value",
+                1,
+                torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]),
+                1,
+            ),
+        ]
+    )
+    def test_scatter_index_input(self, _, dim, index, value):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, input, index):
+                return torch.ops.aten.scatter.value(input, dim, index, value)
+
+        input = torch.zeros(3, 5, dtype=torch.int32)
+        inputs = [input, index]
+        self.run_test(TestModule(), inputs, int32_reqd=True)
+
+
+class TestScatterSrcConverter(DispatchTestCase):
+    @parameterized.expand(
+        [
+            (
+                "scatter_zero_dim_indexOne_src",
+                0,
+                torch.tensor([[0, 1, 2, 0]]),
+                torch.tensor([[1, 2, 3, 4]], dtype=torch.int32),
+            ),
+            (
+                "scatter_zero_dim_indexTwo_src",
+                0,
+                torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]),
+                torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=torch.int32),
+            ),
+            (
+                "scatter_one_dim_indexOne_src",
+                1,
+                torch.tensor([[0, 1, 2, 0]]),
+                torch.tensor([[1, 2, 3, 1]], dtype=torch.int32),
+            ),
+            (
+                "scatter_one_dim_indexTwo_src",
+                1,
+                torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]),
+                torch.tensor([[1, 2, 3, 1], [5, 6, 5, 5]], dtype=torch.int32),
+            ),
+            (
+                "scatter_one_dim_indexOne_constant_src",
+                1,
+                torch.tensor([[0, 1, 2, 0]]),
+                torch.tensor([[1, 2, 3, 4]], dtype=torch.int32),
+            ),
+            (
+                "scatter_one_dim_indexTwo_constant_src",
+                1,
+                torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]),
+                torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=torch.int32),
+            ),
+        ]
+    )
+    def test_scatter_index_constant(self, _, dim, index, src):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, input):
+                return torch.ops.aten.scatter.src(input, dim, index, src)
+
+        input = torch.zeros(3, 5, dtype=torch.int32)
+        inputs = [input]
+        scatter = TestModule()
+        self.run_test(scatter, inputs, int32_reqd=True)
+
+    @parameterized.expand(
+        [
+            (
+                "scatter_zero_dim_indexOne_constant_src",
+                0,
+                torch.tensor([[0, 1, 2, 0]]),
+                torch.tensor([[1, 2, 3, 4]], dtype=torch.int32),
+            ),
+            (
+                "scatter_zero_dim_indexTwo_constant_src",
+                0,
+                torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]),
+                torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=torch.int32),
+            ),
+            (
+                "scatter_one_dim_indexOne_constant_src",
+                1,
+                torch.tensor([[0, 1, 2, 0]]),
+                torch.tensor([[1, 2, 3, 1]], dtype=torch.int32),
+            ),
+            (
+                "scatter_one_dim_indexTwo_constant_src",
+                1,
+                torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]),
+                torch.tensor([[1, 2, 3, 1], [5, 6, 5, 5]], dtype=torch.int32),
+            ),
+        ]
+    )
+    def test_scatter_index_input(self, _, dim, index, src):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, input, index):
+                return torch.ops.aten.scatter.src(input, dim, index, src)
+
+        input = torch.zeros(3, 5, dtype=torch.int32)
+        inputs = [input, index]
+        self.run_test(TestModule(), inputs, int32_reqd=True)
+
+
+if __name__ == "__main__":
+    run_tests()
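As a sanity check on the expected behavior, the first scatter.value case can be worked out in eager mode: with dim=0, each column j of the index picks the row that receives the value, so out[index[0][j], j] = 1 (standard PyTorch, shown here only to document what the TRT output must match):

    import torch

    x = torch.zeros(3, 5, dtype=torch.int32)
    index = torch.tensor([[0, 1, 2, 0]])
    print(torch.ops.aten.scatter.value(x, 0, index, 1))
    # tensor([[1, 0, 0, 1, 0],
    #         [0, 1, 0, 0, 0],
    #         [0, 0, 1, 0, 0]], dtype=torch.int32)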
