Skip to content

Commit b298710

Browse files
committed
Adding grid_sampler 2d cases (no 3d cases)
1 parent 5c49595 commit b298710

File tree

2 files changed

+139
-67
lines changed

2 files changed

+139
-67
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -332,7 +332,8 @@ def aten_ops_fmod(
332332

333333
@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler) # type: ignore[misc]
334334
@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler_2d) # type: ignore[misc]
335-
@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler_3d) # type: ignore[misc]
335+
# commented this for now, see py/dynamo/conversion/tests/test_grid_aten. Should this be removed altogether?
336+
# @dynamo_tensorrt_converter(torch.ops.aten.grid_sampler_3d) # type: ignore[misc]
336337
@enforce_tensor_types(
337338
{
338339
0: (TRTTensor,),

tests/py/dynamo/conversion/test_grid_aten.py

Lines changed: 137 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -1,88 +1,159 @@
11
import pytest
22
import torch
33
import torch.nn as nn
4-
from .harness import DispatchTestCase
4+
from harness import DispatchTestCase
55
from parameterized import parameterized
66
from torch.testing._internal.common_utils import run_tests
77
from torch_tensorrt import Input
88

9+
grid_sampler_ops = [
10+
(
11+
"input_grid_interpolation_nearest_sample_fill",
12+
(lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 0, 0, True)),
13+
[1, 1, 5, 5],
14+
[1, 5, 2, 2],
15+
),
16+
(
17+
"input_grid_interpolation_nearest_sample_clamp",
18+
(lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 0, 1, True)),
19+
[1, 1, 5, 5],
20+
[1, 5, 2, 2],
21+
),
22+
(
23+
"input_grid_interpolation_nearest_sample_reflect",
24+
(lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 0, 2, True)),
25+
[1, 1, 5, 5],
26+
[1, 5, 2, 2],
27+
),
28+
(
29+
"input_grid_interpolation_linear_sample_fill",
30+
(lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 0, 0, True)),
31+
[1, 1, 5, 5],
32+
[1, 5, 2, 2],
33+
),
34+
(
35+
"input_grid_interpolation_linear_sample_clamp",
36+
(lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 0, 1, True)),
37+
[1, 1, 5, 5],
38+
[1, 5, 2, 2],
39+
),
40+
(
41+
"input_grid_interpolation_linear_sample_reflect",
42+
(lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 0, 2, True)),
43+
[1, 1, 5, 5],
44+
[1, 5, 2, 2],
45+
),
46+
(
47+
"input_grid_interpolation_cubic_sample_fill",
48+
(lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 0, 0, True)),
49+
[1, 1, 5, 5],
50+
[1, 5, 2, 2],
51+
),
52+
(
53+
"input_grid_interpolation_cubic_sample_clamp",
54+
(lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 0, 1, True)),
55+
[1, 1, 5, 5],
56+
[1, 5, 2, 2],
57+
),
58+
(
59+
"input_grid_interpolation_cubic_sample_reflect",
60+
(lambda x, grid: torch.ops.aten.grid_sampler(x, grid, 0, 2, True)),
61+
[1, 1, 5, 5],
62+
[1, 5, 2, 2],
63+
),
64+
(
65+
"input_grid_interpolation_nearest_sample_fill_2d",
66+
(lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 0, 0, True)),
67+
[1, 1, 5, 5],
68+
[1, 5, 2, 2],
69+
),
70+
(
71+
"input_grid_interpolation_nearest_sample_clamp_2d",
72+
(lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 0, 1, True)),
73+
[1, 1, 5, 5],
74+
[1, 5, 2, 2],
75+
),
76+
(
77+
"input_grid_interpolation_nearest_sample_reflect_2d",
78+
(lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 0, 2, True)),
79+
[1, 1, 5, 5],
80+
[1, 5, 2, 2],
81+
),
82+
(
83+
"input_grid_interpolation_linear_sample_fill_2d",
84+
(lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 0, 0, True)),
85+
[1, 1, 5, 5],
86+
[1, 5, 2, 2],
87+
),
88+
(
89+
"input_grid_interpolation_linear_sample_clamp_2d",
90+
(lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 0, 1, True)),
91+
[1, 1, 5, 5],
92+
[1, 5, 2, 2],
93+
),
94+
(
95+
"input_grid_interpolation_linear_sample_reflect_2d",
96+
(lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 0, 2, True)),
97+
[1, 1, 5, 5],
98+
[1, 5, 2, 2],
99+
),
100+
(
101+
"input_grid_interpolation_cubic_sample_fill_2d",
102+
(lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 0, 0, True)),
103+
[1, 1, 5, 5],
104+
[1, 5, 2, 2],
105+
),
106+
(
107+
"input_grid_interpolation_cubic_sample_clamp_2d",
108+
(lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 0, 1, True)),
109+
[1, 1, 5, 5],
110+
[1, 5, 2, 2],
111+
),
112+
(
113+
"input_grid_interpolation_cubic_sample_reflect_2d",
114+
(lambda x, grid: torch.ops.aten.grid_sampler_2d(x, grid, 0, 2, True)),
115+
[1, 1, 5, 5],
116+
[1, 5, 2, 2],
117+
),
118+
# The 3d cases with 4d input gives the error that it requires 5d input for both input and grid
119+
# The 5d input fails in the generation of the Grid Layer since the TensorRT layer requires 4d input
120+
# ("input_grid_interpolation_nearest_sample_fill_3d", (lambda x, grid: torch.ops.aten.grid_sampler_3d(x, grid, 0, 0, True)), [1, 1, 5, 5, 5], [1, 5, 5, 2, 2]),
121+
# ("input_grid_interpolation_nearest_sample_clamp_3d", (lambda x, grid: torch.ops.aten.grid_sampler_3d(x, grid, 0, 1, True)), [1, 1, 5, 5, 5], [1, 5, 5, 2, 2]),
122+
# ("input_grid_interpolation_nearest_sample_reflect_3d", (lambda x, grid: torch.ops.aten.grid_sampler_3d(x, grid, 0, 2, True)), [1, 1, 5, 5, 5], [1, 5, 5, 2, 2]),
123+
# ("input_grid_interpolation_linear_sample_fill_3d", (lambda x, grid: torch.ops.aten.grid_sampler_3d(x, grid, 0, 0, True)), [1, 1, 5, 5, 5], [1, 5, 5, 2, 2]),
124+
# ("input_grid_interpolation_linear_sample_clamp_3d", (lambda x, grid: torch.ops.aten.grid_sampler_3d(x, grid, 0, 1, True)), [1, 1, 5, 5, 5], [1, 5, 5, 2, 2]),
125+
# ("input_grid_interpolation_linear_sample_reflect_3d", (lambda x, grid: torch.ops.aten.grid_sampler_3d(x, grid, 0, 2, True)), [1, 1, 5, 5, 5], [1, 5, 5, 2, 2]),
126+
# ("input_grid_interpolation_cubic_sample_fill_3d", (lambda x, grid: torch.ops.aten.grid_sampler_3d(x, grid, 0, 0, True)), [1, 1, 5, 5, 5], [1, 5, 5, 2, 2]),
127+
# ("input_grid_interpolation_cubic_sample_clamp_3d", (lambda x, grid: torch.ops.aten.grid_sampler_3d(x, grid, 0, 1, True)), [1, 1, 5, 5, 5], [1, 5, 5, 2, 2]),
128+
# ("input_grid_interpolation_cubic_sample_reflect_3d", (lambda x, grid: torch.ops.aten.grid_sampler_3d(x, grid, 0, 2, True)), [1, 1, 5, 5, 5], [1, 5, 5, 2, 2]),
129+
]
130+
9131

10132
class TestGridConverter(DispatchTestCase):
11133
@parameterized.expand(
12134
[
13135
(
14-
"input_grid_interpolation_nearest_sample_fill",
15-
[1, 1, 5, 5],
16-
[1, 5, 2, 2],
17-
0,
18-
0,
19-
),
20-
(
21-
"input_grid_interpolation_nearest_sample_clamp",
22-
[1, 1, 5, 5],
23-
[1, 5, 2, 2],
24-
0,
25-
1,
26-
),
27-
(
28-
"input_grid_interpolation_nearest_sample_reflect",
29-
[1, 1, 5, 5],
30-
[1, 5, 2, 2],
31-
0,
32-
2,
33-
),
34-
(
35-
"input_grid_interpolation_linear_sample_fill",
36-
[1, 1, 5, 5],
37-
[1, 5, 2, 2],
38-
1,
39-
0,
40-
),
41-
(
42-
"input_grid_interpolation_linear_sample_clamp",
43-
[1, 1, 5, 5],
44-
[1, 5, 2, 2],
45-
1,
46-
1,
47-
),
48-
(
49-
"input_grid_interpolation_linear_sample_reflect",
50-
[1, 1, 5, 5],
51-
[1, 5, 2, 2],
52-
1,
53-
2,
54-
),
55-
(
56-
"input_grid_interpolation_cubic_sample_fill",
57-
[1, 1, 5, 5],
58-
[1, 5, 2, 2],
59-
2,
60-
0,
61-
),
62-
(
63-
"input_grid_interpolation_cubic_sample_clamp",
64-
[1, 1, 5, 5],
65-
[1, 5, 2, 2],
66-
2,
67-
1,
68-
),
69-
(
70-
"input_grid_interpolation_cubic_sample_reflect",
71-
[1, 1, 5, 5],
72-
[1, 5, 2, 2],
73-
2,
74-
2,
75-
),
136+
grid_sampler_op[0],
137+
grid_sampler_op[1],
138+
grid_sampler_op[2],
139+
grid_sampler_op[3],
140+
)
141+
for grid_sampler_op in grid_sampler_ops
76142
]
77143
)
78-
def test_grid(self, _, input_shape, dim_shape, interpolation, sample):
144+
def test_grid(self, _, op, input_shape, dim_shape):
79145
class TestModule(nn.Module):
146+
def __init__(self, grid_sampler_op):
147+
super().__init__()
148+
self.grid_sampler_op = grid_sampler_op
149+
80150
def forward(self, x):
81151
grid = torch.randint(-1, 1, dim_shape, dtype=torch.float32)
82-
return torch.ops.aten.grid_sampler(x, grid, interpolation, sample, True)
152+
return self.grid_sampler_op(x, grid)
83153

84154
inputs = [torch.randn(input_shape, dtype=torch.float32)]
85-
self.run_test(TestModule(), inputs)
155+
grid_model = TestModule(op)
156+
self.run_test(grid_model, inputs)
86157

87158

88159
if __name__ == "__main__":

0 commit comments

Comments
 (0)