@@ -153,13 +153,14 @@ auto element_wise_registrations TRTORCH_UNUSED =
153
153
auto other = args[1 ].ITensorOrFreeze (ctx);
154
154
155
155
if (1 != scalar) {
156
- auto scaleW = Weights (ctx, scalar);
157
- auto unuse = Weights ();
158
- // IScaleLayer assert shift, scale and power to have
159
- // the same dtype
160
- auto scaleLayer = ctx->net ->addScale (
161
- *other, nvinfer1::ScaleMode::kUNIFORM , unuse.data , scaleW.data , unuse.data );
162
- TRTORCH_CHECK (scaleLayer, " Unable to create scale layer from node: " << *n);
156
+ auto alphaTensor = tensor_to_const (ctx, torch::tensor ({scalar}));
157
+ auto scaleLayer = add_elementwise (
158
+ ctx,
159
+ nvinfer1::ElementWiseOperation::kPROD ,
160
+ other,
161
+ alphaTensor,
162
+ util::node_info (n) + std::string (" _AlphaMultiplier" ));
163
+ TRTORCH_CHECK (scaleLayer, " Unable to create alpha*input layer from node: " << *n);
163
164
other = scaleLayer->getOutput (0 );
164
165
}
165
166
@@ -181,13 +182,14 @@ auto element_wise_registrations TRTORCH_UNUSED =
181
182
auto other = args[1 ].ITensorOrFreeze (ctx);
182
183
183
184
if (1 != scalar) {
184
- auto scaleW = Weights (ctx, scalar);
185
- auto unuse = Weights ();
186
- // IScaleLayer assert shift, scale and power to have
187
- // the same dtype
188
- auto scaleLayer = ctx->net ->addScale (
189
- *other, nvinfer1::ScaleMode::kUNIFORM , unuse.data , scaleW.data , unuse.data );
190
- TRTORCH_CHECK (scaleLayer, " Unable to create scale layer from node: " << *n);
185
+ auto alphaTensor = tensor_to_const (ctx, torch::tensor ({scalar}));
186
+ auto scaleLayer = add_elementwise (
187
+ ctx,
188
+ nvinfer1::ElementWiseOperation::kPROD ,
189
+ other,
190
+ alphaTensor,
191
+ util::node_info (n) + std::string (" _AlphaMultiplier" ));
192
+ TRTORCH_CHECK (scaleLayer, " Unable to create alpha*input layer from node: " << *n);
191
193
other = scaleLayer->getOutput (0 );
192
194
}
193
195
@@ -209,13 +211,14 @@ auto element_wise_registrations TRTORCH_UNUSED =
209
211
auto scalar = args[2 ].unwrapToScalar ().to <float >();
210
212
211
213
if (1 != scalar) {
212
- auto scaleW = Weights (ctx, scalar);
213
- auto unuse = Weights ();
214
- // IScaleLayer assert shift, scale and power to have
215
- // the same dtype
216
- auto scaleLayer =
217
- ctx->net ->addScale (*self, nvinfer1::ScaleMode::kUNIFORM , unuse.data , scaleW.data , unuse.data );
218
- TRTORCH_CHECK (scaleLayer, " Unable to create scale layer from node: " << *n);
214
+ auto alphaTensor = tensor_to_const (ctx, torch::tensor ({scalar}));
215
+ auto scaleLayer = add_elementwise (
216
+ ctx,
217
+ nvinfer1::ElementWiseOperation::kPROD ,
218
+ self,
219
+ alphaTensor,
220
+ util::node_info (n) + std::string (" _AlphaMultiplier" ));
221
+ TRTORCH_CHECK (scaleLayer, " Unable to create alpha*input layer from node: " << *n);
219
222
self = scaleLayer->getOutput (0 );
220
223
}
221
224
@@ -236,13 +239,14 @@ auto element_wise_registrations TRTORCH_UNUSED =
236
239
auto scalar = args[2 ].unwrapToScalar ().to <float >();
237
240
238
241
if (1 != scalar) {
239
- auto scaleW = Weights (ctx, scalar);
240
- auto unuse = Weights ();
241
- // IScaleLayer assert shift, scale and power to have
242
- // the same dtype
243
- auto scaleLayer =
244
- ctx->net ->addScale (*self, nvinfer1::ScaleMode::kUNIFORM , unuse.data , scaleW.data , unuse.data );
245
- TRTORCH_CHECK (scaleLayer, " Unable to create scale layer from node: " << *n);
242
+ auto alphaTensor = tensor_to_const (ctx, torch::tensor ({scalar}));
243
+ auto scaleLayer = add_elementwise (
244
+ ctx,
245
+ nvinfer1::ElementWiseOperation::kPROD ,
246
+ self,
247
+ alphaTensor,
248
+ util::node_info (n) + std::string (" _AlphaMultiplier" ));
249
+ TRTORCH_CHECK (scaleLayer, " Unable to create alpha*input layer from node: " << *n);
246
250
self = scaleLayer->getOutput (0 );
247
251
}
248
252
0 commit comments