Skip to content

Commit 3e668bf

Browse files
authored
Expose IGridSampleLayer (#2290)
1 parent c50cab4 commit 3e668bf

File tree

4 files changed

+230
-5
lines changed

4 files changed

+230
-5
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -330,6 +330,34 @@ def aten_ops_fmod(
330330
return impl.elementwise.fmod(ctx, target, SourceIR.ATEN, name, args[0], args[1])
331331

332332

333+
@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler)
334+
@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler_2d)
335+
@enforce_tensor_types(
336+
{
337+
0: (TRTTensor,),
338+
1: (TRTTensor,),
339+
}
340+
)
341+
def aten_ops_grid(
342+
ctx: ConversionContext,
343+
target: Target,
344+
args: Tuple[Argument, ...],
345+
kwargs: Dict[str, Argument],
346+
name: str,
347+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
348+
return impl.grid.grid(
349+
ctx,
350+
target,
351+
SourceIR.ATEN,
352+
name,
353+
input=args[0],
354+
grid=args[1],
355+
interpolation_mode=args[2],
356+
padding_mode=args[3],
357+
align_corners=args[4],
358+
)
359+
360+
333361
@dynamo_tensorrt_converter(torch.ops.aten.relu.default)
334362
def aten_ops_relu(
335363
ctx: ConversionContext,
@@ -759,12 +787,12 @@ def aten_ops_cumsum(
759787
)
760788

761789

762-
@dynamo_tensorrt_converter(torch.ops.aten.tile.default) # type: ignore[misc]
790+
@dynamo_tensorrt_converter(torch.ops.aten.tile.default)
763791
@enforce_tensor_types(
764792
{
765793
0: (TRTTensor,),
766794
}
767-
) # type: ignore[misc]
795+
)
768796
def aten_ops_tile(
769797
ctx: ConversionContext,
770798
target: Target,
@@ -782,7 +810,7 @@ def aten_ops_tile(
782810
)
783811

784812

785-
@dynamo_tensorrt_converter(torch.ops.aten.permute.default) # type: ignore[misc]
813+
@dynamo_tensorrt_converter(torch.ops.aten.permute.default)
786814
@enforce_tensor_types(
787815
{
788816
0: (TRTTensor,),
@@ -2000,14 +2028,14 @@ def aten_ops_argmax(
20002028
)
20012029

20022030

2003-
@dynamo_tensorrt_converter(torch.ops.aten.addmm.default) # type: ignore[misc]
2031+
@dynamo_tensorrt_converter(torch.ops.aten.addmm.default)
20042032
@enforce_tensor_types(
20052033
{
20062034
0: (TRTTensor,),
20072035
1: (np.ndarray, torch.Tensor, TRTTensor),
20082036
2: (np.ndarray, torch.Tensor, TRTTensor),
20092037
}
2010-
) # type: ignore[misc]
2038+
)
20112039
def aten_ops_addmm(
20122040
ctx: ConversionContext,
20132041
target: Target,

py/torch_tensorrt/dynamo/conversion/impl/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
deconv,
1313
elementwise,
1414
embedding,
15+
grid,
1516
linear,
1617
matmul,
1718
normalization,
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
from typing import Optional, Sequence
2+
3+
import tensorrt as trt
4+
import torch
5+
from torch.fx.node import Target
6+
from torch_tensorrt.dynamo._SourceIR import SourceIR
7+
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
8+
from torch_tensorrt.dynamo.conversion.converter_utils import cast_trt_tensor
9+
from torch_tensorrt.fx.converters.converter_utils import set_layer_name
10+
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
11+
12+
# nearest, linear, cubic
13+
GridSamplerInterpolationMode = {
14+
0: trt.InterpolationMode.NEAREST,
15+
1: trt.InterpolationMode.LINEAR,
16+
2: trt.InterpolationMode.CUBIC,
17+
}
18+
19+
# zeros, border, reflection
20+
GridSamplerSampling = {
21+
0: trt.SampleMode.FILL,
22+
1: trt.SampleMode.CLAMP,
23+
2: trt.SampleMode.REFLECT,
24+
}
25+
26+
27+
def grid(
28+
ctx: ConversionContext,
29+
target: Target,
30+
source_ir: Optional[SourceIR],
31+
name: str,
32+
input: TRTTensor,
33+
grid: TRTTensor,
34+
interpolation_mode: int,
35+
padding_mode: int,
36+
align_corners: bool,
37+
) -> TRTTensor:
38+
grid_layer = ctx.net.add_grid_sample(input, grid)
39+
assert interpolation_mode in GridSamplerInterpolationMode
40+
grid_layer.interpolation_mode = GridSamplerInterpolationMode.get(
41+
interpolation_mode, None
42+
)
43+
assert padding_mode in GridSamplerSampling
44+
grid_layer.sample_mode = GridSamplerSampling.get(padding_mode, None)
45+
grid_layer.align_corners = align_corners
46+
set_layer_name(grid_layer, target, name + "_grid_layer", source_ir)
47+
return grid_layer.get_output(0)
Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
import pytest
2+
import torch
3+
import torch.nn as nn
4+
from .harness import DispatchTestCase
5+
from parameterized import parameterized
6+
from torch.testing._internal.common_utils import run_tests
7+
from torch_tensorrt import Input
8+
9+
grid_sampler_ops = [
10+
(
11+
"input_grid_interpolation_nearest_sample_fill",
12+
(lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 0, 0, True)),
13+
[1, 1, 5, 5],
14+
[1, 5, 2, 2],
15+
),
16+
(
17+
"input_grid_interpolation_nearest_sample_clamp",
18+
(lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 0, 1, True)),
19+
[1, 1, 5, 5],
20+
[1, 5, 2, 2],
21+
),
22+
(
23+
"input_grid_interpolation_nearest_sample_reflect",
24+
(lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 0, 2, True)),
25+
[1, 1, 5, 5],
26+
[1, 5, 2, 2],
27+
),
28+
(
29+
"input_grid_interpolation_linear_sample_fill",
30+
(lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 0, 0, True)),
31+
[1, 1, 5, 5],
32+
[1, 5, 2, 2],
33+
),
34+
(
35+
"input_grid_interpolation_linear_sample_clamp",
36+
(lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 0, 1, True)),
37+
[1, 1, 5, 5],
38+
[1, 5, 2, 2],
39+
),
40+
(
41+
"input_grid_interpolation_linear_sample_reflect",
42+
(lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 0, 2, True)),
43+
[1, 1, 5, 5],
44+
[1, 5, 2, 2],
45+
),
46+
(
47+
"input_grid_interpolation_cubic_sample_fill",
48+
(lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 0, 0, True)),
49+
[1, 1, 5, 5],
50+
[1, 5, 2, 2],
51+
),
52+
(
53+
"input_grid_interpolation_cubic_sample_clamp",
54+
(lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 0, 1, True)),
55+
[1, 1, 5, 5],
56+
[1, 5, 2, 2],
57+
),
58+
(
59+
"input_grid_interpolation_cubic_sample_reflect",
60+
(lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 0, 2, True)),
61+
[1, 1, 5, 5],
62+
[1, 5, 2, 2],
63+
),
64+
(
65+
"input_grid_interpolation_nearest_sample_fill_2d",
66+
(lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 0, 0, True)),
67+
[1, 1, 5, 5],
68+
[1, 5, 2, 2],
69+
),
70+
(
71+
"input_grid_interpolation_nearest_sample_clamp_2d",
72+
(lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 0, 1, True)),
73+
[1, 1, 5, 5],
74+
[1, 5, 2, 2],
75+
),
76+
(
77+
"input_grid_interpolation_nearest_sample_reflect_2d",
78+
(lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 0, 2, True)),
79+
[1, 1, 5, 5],
80+
[1, 5, 2, 2],
81+
),
82+
(
83+
"input_grid_interpolation_linear_sample_fill_2d",
84+
(lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 0, 0, True)),
85+
[1, 1, 5, 5],
86+
[1, 5, 2, 2],
87+
),
88+
(
89+
"input_grid_interpolation_linear_sample_clamp_2d",
90+
(lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 0, 1, True)),
91+
[1, 1, 5, 5],
92+
[1, 5, 2, 2],
93+
),
94+
(
95+
"input_grid_interpolation_linear_sample_reflect_2d",
96+
(lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 0, 2, True)),
97+
[1, 1, 5, 5],
98+
[1, 5, 2, 2],
99+
),
100+
(
101+
"input_grid_interpolation_cubic_sample_fill_2d",
102+
(lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 0, 0, True)),
103+
[1, 1, 5, 5],
104+
[1, 5, 2, 2],
105+
),
106+
(
107+
"input_grid_interpolation_cubic_sample_clamp_2d",
108+
(lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 0, 1, True)),
109+
[1, 1, 5, 5],
110+
[1, 5, 2, 2],
111+
),
112+
(
113+
"input_grid_interpolation_cubic_sample_reflect_2d",
114+
(lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 0, 2, True)),
115+
[1, 1, 5, 5],
116+
[1, 5, 2, 2],
117+
),
118+
]
119+
120+
121+
class TestGridConverter(DispatchTestCase):
122+
@parameterized.expand(
123+
[
124+
(
125+
grid_sampler_op[0],
126+
grid_sampler_op[1],
127+
grid_sampler_op[2],
128+
grid_sampler_op[3],
129+
)
130+
for grid_sampler_op in grid_sampler_ops
131+
]
132+
)
133+
def test_grid(self, _, op, input_shape, dim_shape):
134+
class TestModule(nn.Module):
135+
def __init__(self, grid_sampler_op):
136+
super().__init__()
137+
self.grid_sampler_op = grid_sampler_op
138+
139+
def forward(self, x):
140+
grid = torch.randint(-1, 1, dim_shape, dtype=torch.float32)
141+
return self.grid_sampler_op(x, grid)
142+
143+
inputs = [torch.randn(input_shape, dtype=torch.float32)]
144+
grid_model = TestModule(op)
145+
self.run_test(grid_model, inputs)
146+
147+
148+
if __name__ == "__main__":
149+
run_tests()

0 commit comments

Comments
 (0)