#include <ATen/ATen.h>
#include <vector>
#include "NvInfer.h"
#include "core/conversion/converters/converters.h"
#include "core/util/prelude.h"
#include "torch/torch.h"

namespace trtorch {
namespace core {
namespace conversion {
namespace converters {
namespace impl {
namespace {

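// Converter for aten::constant_pad_nd: for every padded axis, a constant-filled slab is built
// with an IFillLayer and concatenated onto the input along that axis.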
auto constant_pad_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns().pattern(
    {"aten::constant_pad_nd(Tensor self, int[] pad, Scalar value=0) -> (Tensor)",
     [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
       auto in = args[0].ITensor();
       auto inDims = in->getDimensions();
       int64_t inRank = inDims.nbDims;
       auto padding = args[1].unwrapToIntList().vec();
       int64_t padSize = padding.size();
       auto value = args[2].unwrapToScalar().to<float>();

       TRTORCH_CHECK(padSize % 2 == 0, "Length of pad must be even but instead it equals " << padSize);

       int64_t l_pad = padSize / 2;
       TRTORCH_CHECK(
           inRank >= (int64_t)l_pad,
           "Length of pad should be no more than twice the number of "
           "dimensions of the input. Pad length is "
               << padSize << " while the input has " << inRank << " dimensions.");

       // TODO negative padding. When the pad is negative, we need to crop the image.

       std::vector<nvinfer1::ITensor*> tensors_vec;
       // input: (N, C, D_in, H_in, W_in).
       // padding: (padding_left, padding_right, padding_top, padding_bottom, padding_front, padding_back)
       // When axis is inRank - 1, making W_out = W_in + padding_left + padding_right.
       // When axis is inRank - 2, making H_out = H_in + padding_top + padding_bottom.
       // When axis is inRank - 3, making D_out = D_in + padding_front + padding_back.
       for (int64_t i = 0; i < l_pad; i++) {
         int64_t axis = inRank - (i + 1); // axis = {inRank - 1, inRank - 2, inRank - 3}
         int64_t padding_index = i * 2;

         if (padding[padding_index] > 0) { // left/top/front padding value
           tensors_vec.clear();
           if (ctx->input_is_dynamic) {
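             // Gather one slice of the input along the pad axis so the pad slab picks up the
             // runtime extents of every other dimension, then fill it with `value`.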
             at::Tensor left_indices = torch::tensor({0}, torch::kInt32);
             auto indicesTensor = tensor_to_const(ctx, left_indices);
             auto left_gather_layer = ctx->net->addGather(*in, *indicesTensor, axis);
             auto left_gather_out = left_gather_layer->getOutput(0);

             // fill the left_gather_out with value
             auto fill_layer = ctx->net->addFill(nvinfer1::Dims{1, {1}}, nvinfer1::FillOperation::kLINSPACE);
             auto shape_gather_out = ctx->net->addShape(*left_gather_out)->getOutput(0);
             fill_layer->setInput(0, *shape_gather_out);
             at::Tensor value_tensor = torch::tensor(value, torch::kFloat32);
             auto valueTensor = tensor_to_const(ctx, value_tensor);
             fill_layer->setInput(1, *valueTensor);
             at::Tensor delta_tensor = torch::zeros(inRank);
             auto deltaTensor = tensor_to_const(ctx, delta_tensor);
             fill_layer->setInput(2, *deltaTensor);
             auto padTensor = fill_layer->getOutput(0);

             for (int i = 0; i < padding[padding_index]; i++) {
               tensors_vec.push_back(padTensor);
             }
           } else {
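             // Static shapes: size the pad slab directly by setting the extent of the pad axis
             // and filling a tensor of that shape with `value`.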
             inDims.d[axis] = padding[padding_index];
             auto fill_layer = ctx->net->addFill(inDims, nvinfer1::FillOperation::kLINSPACE);
             at::Tensor value_tensor = torch::tensor(value, torch::kFloat32);
             auto valueTensor = tensor_to_const(ctx, value_tensor);
             fill_layer->setInput(1, *valueTensor);
             at::Tensor delta_tensor = torch::zeros(inRank);
             auto deltaTensor = tensor_to_const(ctx, delta_tensor);
             fill_layer->setInput(2, *deltaTensor);
             auto padTensor = fill_layer->getOutput(0);

             tensors_vec.push_back(padTensor);
           }

           tensors_vec.push_back(in);
           auto concat_layer = ctx->net->addConcatenation(tensors_vec.data(), tensors_vec.size());
           concat_layer->setAxis(axis);
           in = concat_layer->getOutput(0);
           inDims = in->getDimensions();
         }

         if (padding[padding_index + 1] > 0) { // right/bottom/back padding value
           tensors_vec.clear();
           tensors_vec.push_back(in);

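           // Index of the last slice along the pad axis: computed at runtime (shape[axis] - 1)
           // when the dimension is dynamic, otherwise taken from the static dimension.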
           nvinfer1::ITensor* indicesTensor = NULL;
           if (inDims.d[axis] == -1) {
             auto shapeTensor = ctx->net->addShape(*in)->getOutput(0);
             at::Tensor dimValue = torch::tensor({axis}, torch::kInt32);
             auto dimTensor = tensor_to_const(ctx, dimValue);
             indicesTensor = ctx->net->addGather(*shapeTensor, *dimTensor, 0)->getOutput(0);
             auto oneTensor = tensor_to_const(ctx, torch::tensor({1}, torch::kInt32));
             indicesTensor = ctx->net->addElementWise(*indicesTensor, *oneTensor, nvinfer1::ElementWiseOperation::kSUB)
                                 ->getOutput(0);
           } else {
             auto indices = torch::tensor({inDims.d[axis] - 1}, torch::kInt32);
             indicesTensor = tensor_to_const(ctx, indices);
           }
           auto right_gather_layer = ctx->net->addGather(*in, *indicesTensor, axis);
           auto right_gather_out = right_gather_layer->getOutput(0);

           if (ctx->input_is_dynamic) {
             // fill the right_gather_out with value
             auto fill_layer = ctx->net->addFill(nvinfer1::Dims{1, {1}}, nvinfer1::FillOperation::kLINSPACE);
             auto shape_gather_out = ctx->net->addShape(*right_gather_out)->getOutput(0);
             fill_layer->setInput(0, *shape_gather_out);
             at::Tensor value_tensor = torch::tensor(value, torch::kFloat32);
             auto valueTensor = tensor_to_const(ctx, value_tensor);
             fill_layer->setInput(1, *valueTensor);
             at::Tensor delta_tensor = torch::zeros(inRank);
             auto deltaTensor = tensor_to_const(ctx, delta_tensor);
             fill_layer->setInput(2, *deltaTensor);
             auto padTensor = fill_layer->getOutput(0);

             for (int i = 0; i < padding[padding_index + 1]; i++) {
               tensors_vec.push_back(padTensor);
             }
           } else {
             inDims.d[axis] = padding[padding_index + 1];
             auto fill_layer = ctx->net->addFill(inDims, nvinfer1::FillOperation::kLINSPACE);
             at::Tensor value_tensor = torch::tensor(value, torch::kFloat32);
             auto valueTensor = tensor_to_const(ctx, value_tensor);
             fill_layer->setInput(1, *valueTensor);
             at::Tensor delta_tensor = torch::zeros(inRank);
             auto deltaTensor = tensor_to_const(ctx, delta_tensor);
             fill_layer->setInput(2, *deltaTensor);
             auto padTensor = fill_layer->getOutput(0);

             tensors_vec.push_back(padTensor);
           }
           auto concat_layer = ctx->net->addConcatenation(tensors_vec.data(), tensors_vec.size());
           concat_layer->setAxis(axis);
           in = concat_layer->getOutput(0);
           inDims = in->getDimensions();
         }
       }

       auto out = ctx->AssociateValueAndTensor(n->outputs()[0], in);
       LOG_DEBUG("Output tensor shape: " << out->getDimensions());
       return true;
     }});

} // namespace
} // namespace impl
} // namespace converters
} // namespace conversion
} // namespace core
} // namespace trtorch