Skip to content

Commit 13319d8

Browse files
committed
Grid test changes
1 parent e39b69e commit 13319d8

File tree

4 files changed

+139
-49
lines changed

4 files changed

+139
-49
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -330,20 +330,37 @@ def aten_ops_fmod(
330330
return impl.elementwise.fmod(ctx, target, SourceIR.ATEN, name, args[0], args[1])
331331

332332

333-
@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler.out)
334-
@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler_backward.out)
333+
@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler)
335334
@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler_2d.out)
336335
@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler_2d_backward.out)
337336
@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler_3d.out)
338337
@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler_3d_backward.out)
338+
@enforce_tensor_types(
339+
{
340+
0: (TRTTensor,),
341+
1: (TRTTensor,),
342+
}
343+
) # type: ignore[misc]
339344
def aten_ops_grid(
340345
ctx: ConversionContext,
341346
target: Target,
342347
args: Tuple[Argument, ...],
343348
kwargs: Dict[str, Argument],
344349
name: str,
345350
) -> Union[TRTTensor, Sequence[TRTTensor]]:
346-
return impl.grid.grid(ctx, target, SourceIR.ATEN, name, args[0], args[1], args[2], args[3], args[4])
351+
return impl.grid.grid(
352+
ctx,
353+
target,
354+
SourceIR.ATEN,
355+
name,
356+
input=args[0],
357+
grid=args[1],
358+
interpolation_mode=args[2],
359+
padding_mode=args[3],
360+
align_corners=args_bounds_check(args, 4, True),
361+
output_mask=args_bounds_check(args, 5, None),
362+
363+
)
347364

348365

349366
@dynamo_tensorrt_converter(torch.ops.aten.relu.default)

py/torch_tensorrt/dynamo/conversion/converter_utils.py

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -24,32 +24,36 @@
2424

2525
_LOGGER: logging.Logger = logging.getLogger(__name__)
2626

27-
#nearesr, linear, cubc
27+
28+
# nearest, linear, cubic
2829
class GridSamplerInterpolation:
2930
def __init__(self):
3031
self.interpolator_mode = None
31-
def __call__(self, interpolator_int):
32-
if(interpolator_int == 0) :
32+
33+
def __call__(self, interpolator_int):
34+
if interpolator_int == 0:
3335
self.interpolator_mode = trt.InterpolationMode.NEAREST
34-
elif(interpolator_int == 1) :
36+
elif interpolator_int == 1:
3537
self.interpolator_mode = trt.InterpolationMode.LINEAR
36-
elif(interpolator_int == 2) :
38+
elif interpolator_int == 2:
3739
self.interpolator_mode = trt.InterpolationMode.CUBIC
3840
return self.interpolator_mode
39-
4041

41-
#zeros, border, reflection
42-
class GridSamplerPadding:
42+
43+
# zeros, border, reflection
44+
class GridSamplerSampling:
4345
def __init__(self):
44-
self.padding_mode = None
45-
def __call__(self, padding_int):
46-
if(padding_int == 0) :
47-
self.padding_mode = trt.SampleMode.kFILL
48-
elif(padding_int == 1) :
49-
self.padding_mode = trt.SampleMode.kCLAMP
50-
elif(padding_int == 2) :
51-
self.padding_mode = trt.SampleMode.kREFLECT
52-
return self.padding_mode
46+
self.sample_mode = None
47+
48+
def __call__(self, sample_int):
49+
if sample_int == 0:
50+
self.sample_mode = trt.SampleMode.FILL
51+
elif sample_int == 1:
52+
self.sample_mode = trt.SampleMode.CLAMP
53+
elif sample_int == 2:
54+
self.sample_mode = trt.SampleMode.REFLECT
55+
return self.sample_mode
56+
5357

5458
def get_node_name(node: torch.fx.Node) -> str:
5559
# nn_module_stack preserves the call stack of pytorch nn.modules
Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,21 @@
1-
from typing import Optional
1+
from typing import Optional, Sequence
22

3+
import tensorrt as trt
34
import torch
45
from torch.fx.node import Target
56
from torch_tensorrt.dynamo._SourceIR import SourceIR
6-
from torch_tensorrt.dynamo.conversion.converter_utils import GridSamplerInterpolation, GridSamplerPadding
7+
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
8+
from torch_tensorrt.dynamo.conversion.converter_utils import (
9+
GridSamplerInterpolation,
10+
GridSamplerSampling,
11+
cast_trt_tensor,
12+
)
713
from torch_tensorrt.fx.converters.converter_utils import set_layer_name
814
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
915

16+
1017
def grid(
11-
network: TRTNetwork,
18+
ctx: ConversionContext,
1219
target: Target,
1320
source_ir: Optional[SourceIR],
1421
name: str,
@@ -17,10 +24,21 @@ def grid(
1724
interpolation_mode: int,
1825
padding_mode: int,
1926
align_corners: bool,
27+
output_mask: Optional[Sequence[bool]] = None,
2028
) -> TRTTensor:
21-
grid_layer = network.add_grid_sample(input, grid)
22-
grid_layer.interpolation_mode = GridSamplerInterpolation(interpolation_mode)
23-
grid_layer.padding_mode = GridSamplerPadding(padding_mode)
29+
grid_layer = ctx.net.add_grid_sample(input, grid)
30+
interpolation_mode_trt = GridSamplerInterpolation()
31+
grid_layer.interpolation_mode = interpolation_mode_trt(interpolation_mode)
32+
sample_mode_trt = GridSamplerSampling()
33+
grid_layer.sample_mode = sample_mode_trt(padding_mode)
2434
grid_layer.align_corners = align_corners
2535
set_layer_name(grid_layer, target, name + "_grid_layer", source_ir)
26-
return grid_layer.get_output(0)
36+
if output_mask is None:
37+
return grid_layer.get_output(0)
38+
else:
39+
if output_mask[0] and output_mask[1]:
40+
return (grid_layer.get_output(0), None)
41+
elif output_mask[0]:
42+
return grid_layer.get_output(0)
43+
else:
44+
return None
Lines changed: 73 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,89 @@
11
import pytest
22
import torch
33
import torch.nn as nn
4+
from .harness import DispatchTestCase
5+
from parameterized import parameterized
46
from torch.testing._internal.common_utils import run_tests
57
from torch_tensorrt import Input
6-
from parameterized import parameterized
7-
from .harness import DispatchTestCase
8+
89

910
class TestGridConverter(DispatchTestCase):
1011
@parameterized.expand(
1112
[
12-
("input_grid_interpolation_nearest_sample_fill", [5,5], [5,2], 0, 0),
13-
("input_grid_interpolation_nearest_sample_clamp", [5,5], [5,2], 0, 1),
14-
("input_grid_interpolation_nearest_sample_reflect", [5,5], [5,2], 0, 2),
15-
("input_grid_interpolation_linear_sample_fill", [5,5], [5,2], 1, 0),
16-
("input_grid_interpolation_linear_sample_clamp", [5,5], [5,2], 1, 1),
17-
("input_grid_interpolation_linear_sample_reflect", [5,5], [5,2], 1, 2),
18-
("input_grid_interpolation_cubic_sample_fill", [5,5], [5,2], 2, 0),
19-
("input_grid_interpolation_cubic_sample_clamp", [5,5], [5,2], 2, 1),
20-
("input_grid_interpolation_cubic_sample_reflect", [5,5], [5,2], 2, 2),
13+
(
14+
"input_grid_interpolation_nearest_sample_fill",
15+
[1, 1, 5, 5],
16+
[1, 5, 2, 2],
17+
0,
18+
0,
19+
),
20+
(
21+
"input_grid_interpolation_nearest_sample_clamp",
22+
[1, 1, 5, 5],
23+
[1, 5, 2, 2],
24+
0,
25+
1,
26+
),
27+
(
28+
"input_grid_interpolation_nearest_sample_reflect",
29+
[1, 1, 5, 5],
30+
[1, 5, 2, 2],
31+
0,
32+
2,
33+
),
34+
(
35+
"input_grid_interpolation_linear_sample_fill",
36+
[1, 1, 5, 5],
37+
[1, 5, 2, 2],
38+
1,
39+
0,
40+
),
41+
(
42+
"input_grid_interpolation_linear_sample_clamp",
43+
[1, 1, 5, 5],
44+
[1, 5, 2, 2],
45+
1,
46+
1,
47+
),
48+
(
49+
"input_grid_interpolation_linear_sample_reflect",
50+
[1, 1, 5, 5],
51+
[1, 5, 2, 2],
52+
1,
53+
2,
54+
),
55+
(
56+
"input_grid_interpolation_cubic_sample_fill",
57+
[1, 1, 5, 5],
58+
[1, 5, 2, 2],
59+
2,
60+
0,
61+
),
62+
(
63+
"input_grid_interpolation_cubic_sample_clamp",
64+
[1, 1, 5, 5],
65+
[1, 5, 2, 2],
66+
2,
67+
1,
68+
),
69+
(
70+
"input_grid_interpolation_cubic_sample_reflect",
71+
[1, 1, 5, 5],
72+
[1, 5, 2, 2],
73+
2,
74+
2,
75+
),
2176
]
2277
)
23-
def test_grid(self,_, input_shape, dim_shape, interpolation, sample):
78+
def test_grid(self, _, input_shape, dim_shape, interpolation, sample):
2479
class TestModule(nn.Module):
2580
def forward(self, x):
26-
input = torch.randn(10).reshape(input_shape)
27-
grid = torch.randint(-1, 1, dim_shape)
28-
return nn.functional.grid(input, grid, interpolation, sample)
29-
30-
inputs = [torch.randn(1, 10)]
31-
self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.grid_sampler.out})
32-
33-
81+
grid = torch.randint(-1, 1, dim_shape, dtype=torch.float32)
82+
return torch.ops.aten.grid_sampler(x, grid, interpolation, sample, True)
3483

84+
inputs = [torch.randn(input_shape, dtype=torch.float32)]
85+
self.run_test(TestModule(), inputs)
3586

36-
3787

38-
88+
if __name__ == "__main__":
89+
run_tests()

0 commit comments

Comments
 (0)