5
5
from torch .fx .node import Target
6
6
from torch_tensorrt .dynamo ._SourceIR import SourceIR
7
7
from torch_tensorrt .dynamo .conversion ._ConversionContext import ConversionContext
8
- from torch_tensorrt .dynamo .conversion .converter_utils import (
9
- GridSamplerInterpolation ,
10
- GridSamplerSampling ,
11
- cast_trt_tensor ,
12
- )
8
+ from torch_tensorrt .dynamo .conversion .converter_utils import cast_trt_tensor
13
9
from torch_tensorrt .fx .converters .converter_utils import set_layer_name
14
10
from torch_tensorrt .fx .types import TRTNetwork , TRTTensor
15
11
12
+ # nearest, linear, cubic
13
+ GridSamplerInterpolationMode = {
14
+ 0 : trt .InterpolationMode .NEAREST ,
15
+ 1 : trt .InterpolationMode .LINEAR ,
16
+ 2 : trt .InterpolationMode .CUBIC ,
17
+ }
18
+
19
+ # zeros, border, reflection
20
+ GridSamplerSampling = {
21
+ 0 : trt .SampleMode .FILL ,
22
+ 1 : trt .SampleMode .CLAMP ,
23
+ 2 : trt .SampleMode .REFLECT ,
24
+ }
25
+
16
26
17
27
def grid (
18
28
ctx : ConversionContext ,
@@ -27,18 +37,19 @@ def grid(
27
37
output_mask : Optional [Sequence [bool ]] = None ,
28
38
) -> TRTTensor :
29
39
grid_layer = ctx .net .add_grid_sample (input , grid )
30
- interpolation_mode_trt = GridSamplerInterpolation ()
31
- grid_layer .interpolation_mode = interpolation_mode_trt (interpolation_mode )
32
- sample_mode_trt = GridSamplerSampling ()
33
- grid_layer .sample_mode = sample_mode_trt (padding_mode )
40
+ assert interpolation_mode in GridSamplerInterpolationMode
41
+ grid_layer .interpolation_mode = GridSamplerInterpolationMode .get (
42
+ interpolation_mode , None
43
+ )
44
+ assert padding_mode in GridSamplerSampling
45
+ grid_layer .sample_mode = GridSamplerSampling .get (padding_mode , None )
34
46
grid_layer .align_corners = align_corners
35
47
set_layer_name (grid_layer , target , name + "_grid_layer" , source_ir )
36
48
if output_mask is None :
37
49
return grid_layer .get_output (0 )
50
+ elif output_mask [0 ] and output_mask [1 ]:
51
+ return (grid_layer .get_output (0 ), None )
52
+ elif output_mask [0 ]:
53
+ return grid_layer .get_output (0 )
38
54
else :
39
- if output_mask [0 ] and output_mask [1 ]:
40
- return (grid_layer .get_output (0 ), None )
41
- elif output_mask [0 ]:
42
- return grid_layer .get_output (0 )
43
- else :
44
- return None
55
+ return None
0 commit comments