Skip to content

Commit 5c6905d

Browse files
committed
Addressing review comments
1 parent 13319d8 commit 5c6905d

File tree

2 files changed

+26
-45
lines changed

2 files changed

+26
-45
lines changed

py/torch_tensorrt/dynamo/conversion/converter_utils.py

Lines changed: 0 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -25,36 +25,6 @@
2525
_LOGGER: logging.Logger = logging.getLogger(__name__)
2626

2727

28-
# nearest, linear, cubic
29-
class GridSamplerInterpolation:
30-
def __init__(self):
31-
self.interpolator_mode = None
32-
33-
def __call__(self, interpolator_int):
34-
if interpolator_int == 0:
35-
self.interpolator_mode = trt.InterpolationMode.NEAREST
36-
elif interpolator_int == 1:
37-
self.interpolator_mode = trt.InterpolationMode.LINEAR
38-
elif interpolator_int == 2:
39-
self.interpolator_mode = trt.InterpolationMode.CUBIC
40-
return self.interpolator_mode
41-
42-
43-
# zeros, border, reflection
44-
class GridSamplerSampling:
45-
def __init__(self):
46-
self.sample_mode = None
47-
48-
def __call__(self, sample_int):
49-
if sample_int == 0:
50-
self.sample_mode = trt.SampleMode.FILL
51-
elif sample_int == 1:
52-
self.sample_mode = trt.SampleMode.CLAMP
53-
elif sample_int == 2:
54-
self.sample_mode = trt.SampleMode.REFLECT
55-
return self.sample_mode
56-
57-
5828
def get_node_name(node: torch.fx.Node) -> str:
5929
# nn_module_stack preserves the call stack of pytorch nn.modules
6030
# The call stack contains a detailed name of the module

py/torch_tensorrt/dynamo/conversion/impl/grid.py

Lines changed: 26 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,24 @@
55
from torch.fx.node import Target
66
from torch_tensorrt.dynamo._SourceIR import SourceIR
77
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
8-
from torch_tensorrt.dynamo.conversion.converter_utils import (
9-
GridSamplerInterpolation,
10-
GridSamplerSampling,
11-
cast_trt_tensor,
12-
)
8+
from torch_tensorrt.dynamo.conversion.converter_utils import cast_trt_tensor
139
from torch_tensorrt.fx.converters.converter_utils import set_layer_name
1410
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
1511

12+
# nearest, linear, cubic
13+
GridSamplerInterpolationMode = {
14+
0: trt.InterpolationMode.NEAREST,
15+
1: trt.InterpolationMode.LINEAR,
16+
2: trt.InterpolationMode.CUBIC,
17+
}
18+
19+
# zeros, border, reflection
20+
GridSamplerSampling = {
21+
0: trt.SampleMode.FILL,
22+
1: trt.SampleMode.CLAMP,
23+
2: trt.SampleMode.REFLECT,
24+
}
25+
1626

1727
def grid(
1828
ctx: ConversionContext,
@@ -27,18 +37,19 @@ def grid(
2737
output_mask: Optional[Sequence[bool]] = None,
2838
) -> TRTTensor:
2939
grid_layer = ctx.net.add_grid_sample(input, grid)
30-
interpolation_mode_trt = GridSamplerInterpolation()
31-
grid_layer.interpolation_mode = interpolation_mode_trt(interpolation_mode)
32-
sample_mode_trt = GridSamplerSampling()
33-
grid_layer.sample_mode = sample_mode_trt(padding_mode)
40+
assert interpolation_mode in GridSamplerInterpolationMode
41+
grid_layer.interpolation_mode = GridSamplerInterpolationMode.get(
42+
interpolation_mode, None
43+
)
44+
assert padding_mode in GridSamplerSampling
45+
grid_layer.sample_mode = GridSamplerSampling.get(padding_mode, None)
3446
grid_layer.align_corners = align_corners
3547
set_layer_name(grid_layer, target, name + "_grid_layer", source_ir)
3648
if output_mask is None:
3749
return grid_layer.get_output(0)
50+
elif output_mask[0] and output_mask[1]:
51+
return (grid_layer.get_output(0), None)
52+
elif output_mask[0]:
53+
return grid_layer.get_output(0)
3854
else:
39-
if output_mask[0] and output_mask[1]:
40-
return (grid_layer.get_output(0), None)
41-
elif output_mask[0]:
42-
return grid_layer.get_output(0)
43-
else:
44-
return None
55+
return None

0 commit comments

Comments
 (0)