@@ -188,13 +188,14 @@ auto element_wise_registrations TRTORCH_UNUSED =
                     auto other = args[1].ITensorOrFreeze(ctx);
 
                     if (1 != scalar) {
-                      auto scaleW = Weights(ctx, scalar);
-                      auto unuse = Weights();
-                      // IScaleLayer assert shift, scale and power to have
-                      // the same dtype
-                      auto scaleLayer = ctx->net->addScale(
-                          *other, nvinfer1::ScaleMode::kUNIFORM, unuse.data, scaleW.data, unuse.data);
-                      TRTORCH_CHECK(scaleLayer, "Unable to create scale layer from node: " << *n);
+                      auto alphaTensor = tensor_to_const(ctx, torch::tensor({scalar}));
+                      auto scaleLayer = add_elementwise(
+                          ctx,
+                          nvinfer1::ElementWiseOperation::kPROD,
+                          other,
+                          alphaTensor,
+                          util::node_info(n) + std::string("_AlphaMultiplier"));
+                      TRTORCH_CHECK(scaleLayer, "Unable to create alpha*input layer from node: " << *n);
                       other = scaleLayer->getOutput(0);
                     }
 
@@ -216,13 +217,14 @@ auto element_wise_registrations TRTORCH_UNUSED =
                     auto other = args[1].ITensorOrFreeze(ctx);
 
                     if (1 != scalar) {
-                      auto scaleW = Weights(ctx, scalar);
-                      auto unuse = Weights();
-                      // IScaleLayer assert shift, scale and power to have
-                      // the same dtype
-                      auto scaleLayer = ctx->net->addScale(
-                          *other, nvinfer1::ScaleMode::kUNIFORM, unuse.data, scaleW.data, unuse.data);
-                      TRTORCH_CHECK(scaleLayer, "Unable to create scale layer from node: " << *n);
+                      auto alphaTensor = tensor_to_const(ctx, torch::tensor({scalar}));
+                      auto scaleLayer = add_elementwise(
+                          ctx,
+                          nvinfer1::ElementWiseOperation::kPROD,
+                          other,
+                          alphaTensor,
+                          util::node_info(n) + std::string("_AlphaMultiplier"));
+                      TRTORCH_CHECK(scaleLayer, "Unable to create alpha*input layer from node: " << *n);
                       other = scaleLayer->getOutput(0);
                     }
 
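Both hunks above replace the old IScaleLayer-based alpha scaling (which, per the removed comment, expects shift, scale, and power weights to share a dtype) with a one-element constant tensor multiplied in via an elementwise kPROD. In eager-mode terms the rewritten step is just a broadcast multiply; the standalone libtorch sketch below is illustration only (not part of the patch, with made-up values), mirroring what tensor_to_const plus kPROD computes:

    // Standalone illustration: one-element "alpha" tensor broadcast-multiplied
    // against the operand, analogous to tensor_to_const(...) + ElementWiseOperation::kPROD.
    #include <torch/torch.h>
    #include <iostream>

    int main() {
      float scalar = 2.5f;                              // alpha from the op schema
      auto other = torch::tensor({1.0f, 2.0f, 4.0f});   // operand being scaled
      auto alphaTensor = torch::tensor({scalar});       // one-element constant
      auto scaled = other * alphaTensor;                // broadcasts to {2.5, 5.0, 10.0}
      std::cout << scaled << std::endl;
      return 0;
    }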
@@ -235,6 +237,63 @@ auto element_wise_registrations TRTORCH_UNUSED =
                     LOG_DEBUG("Output tensor shape: " << out->getDimensions());
                     return true;
                   }})
+        .pattern({"aten::rsub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> (Tensor)",
+                  [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+                    // Should implement other - alpha * self
+                    auto self = args[0].ITensorOrFreeze(ctx);
+                    auto otherScalar = args[1].unwrapToScalar().to<float>();
+                    auto other = tensor_to_const(ctx, torch::tensor({otherScalar}));
+                    auto scalar = args[2].unwrapToScalar().to<float>();
+
+                    if (1 != scalar) {
+                      auto alphaTensor = tensor_to_const(ctx, torch::tensor({scalar}));
+                      auto scaleLayer = add_elementwise(
+                          ctx,
+                          nvinfer1::ElementWiseOperation::kPROD,
+                          self,
+                          alphaTensor,
+                          util::node_info(n) + std::string("_AlphaMultiplier"));
+                      TRTORCH_CHECK(scaleLayer, "Unable to create alpha*input layer from node: " << *n);
+                      self = scaleLayer->getOutput(0);
+                    }
+
+                    auto rsub =
+                        add_elementwise(ctx, nvinfer1::ElementWiseOperation::kSUB, other, self, util::node_info(n));
+                    TRTORCH_CHECK(rsub, "Unable to create rsub layer from node: " << *n);
+
+                    rsub->setName(util::node_info(n).c_str());
+                    auto out = ctx->AssociateValueAndTensor(n->outputs()[0], rsub->getOutput(0));
+                    LOG_DEBUG("Output tensor shape: " << out->getDimensions());
+                    return true;
+                  }})
+        .pattern({"aten::rsub.Tensor(Tensor self, Tensor other, Scalar alpha=1) -> (Tensor)",
+                  [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+                    // Should implement other - alpha * self
+                    auto self = args[0].ITensorOrFreeze(ctx);
+                    auto other = args[1].ITensorOrFreeze(ctx);
+                    auto scalar = args[2].unwrapToScalar().to<float>();
+
+                    if (1 != scalar) {
+                      auto alphaTensor = tensor_to_const(ctx, torch::tensor({scalar}));
+                      auto scaleLayer = add_elementwise(
+                          ctx,
+                          nvinfer1::ElementWiseOperation::kPROD,
+                          self,
+                          alphaTensor,
+                          util::node_info(n) + std::string("_AlphaMultiplier"));
+                      TRTORCH_CHECK(scaleLayer, "Unable to create alpha*input layer from node: " << *n);
+                      self = scaleLayer->getOutput(0);
+                    }
+
+                    auto rsub =
+                        add_elementwise(ctx, nvinfer1::ElementWiseOperation::kSUB, other, self, util::node_info(n));
+                    TRTORCH_CHECK(rsub, "Unable to create rsub layer from node: " << *n);
+
+                    rsub->setName(util::node_info(n).c_str());
+                    auto out = ctx->AssociateValueAndTensor(n->outputs()[0], rsub->getOutput(0));
+                    LOG_DEBUG("Output tensor shape: " << out->getDimensions());
+                    return true;
+                  }})
         .pattern({"aten::div.Tensor(Tensor self, Tensor other) -> Tensor",
                   [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
                     // Should implement self / other
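The hunk above adds the two aten::rsub converters, which (per the in-code comment) are meant to reproduce other - alpha * self. As a sanity check, the same formula can be evaluated in eager-mode libtorch; the sketch below is a standalone illustration with made-up values, not part of the patch:

    // Standalone illustration: aten::rsub(self, other, alpha) == other - alpha * self.
    #include <torch/torch.h>
    #include <iostream>

    int main() {
      auto self = torch::tensor({1.0f, 2.0f, 3.0f});
      auto other = torch::tensor({10.0f, 20.0f, 30.0f});
      float alpha = 2.0f;

      auto reference = other - alpha * self;            // formula from the converter comment
      auto via_aten = torch::rsub(self, other, alpha);  // ATen op being converted

      std::cout << reference << std::endl;  // 8, 16, 24
      std::cout << via_aten << std::endl;   // should match
      return 0;
    }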
@@ -412,6 +471,63 @@ auto element_wise_registrations TRTORCH_UNUSED =
                     pow->setName(util::node_info(n).c_str());
                     auto out = ctx->AssociateValueAndTensor(n->outputs()[0], pow->getOutput(0));
 
+                    LOG_DEBUG("Output tensor shape: " << out->getDimensions());
+                    return true;
+                  }})
+        .pattern({"aten::floor_divide(Tensor self, Tensor other) -> (Tensor)",
+                  [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+                    // TODO: Remove with functionalization
+                    auto self = args[0].ITensorOrFreeze(ctx);
+                    auto other = args[1].ITensorOrFreeze(ctx);
+                    auto floor_divide = add_elementwise(
+                        ctx, nvinfer1::ElementWiseOperation::kFLOOR_DIV, self, other, util::node_info(n));
+                    TRTORCH_CHECK(floor_divide, "Unable to create floor_divide layer from node: " << *n);
+
+                    floor_divide->setName(util::node_info(n).c_str());
+                    auto out = ctx->AssociateValueAndTensor(n->outputs()[0], floor_divide->getOutput(0));
+                    LOG_DEBUG("Output tensor shape: " << out->getDimensions());
+                    return true;
+                  }})
+        .pattern({"aten::floor_divide.Scalar(Tensor self, Scalar other) -> (Tensor)",
+                  [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+                    // TODO: Remove with functionalization
+                    auto self = args[0].ITensorOrFreeze(ctx);
+                    auto otherScalar = args[1].unwrapToScalar().to<float>();
+                    auto other = tensor_to_const(ctx, torch::tensor({otherScalar}));
+                    auto floor_divide = add_elementwise(
+                        ctx, nvinfer1::ElementWiseOperation::kFLOOR_DIV, self, other, util::node_info(n));
+                    TRTORCH_CHECK(floor_divide, "Unable to create floor_divide layer from node: " << *n);
+
+                    floor_divide->setName(util::node_info(n).c_str());
+                    auto out = ctx->AssociateValueAndTensor(n->outputs()[0], floor_divide->getOutput(0));
+                    LOG_DEBUG("Output tensor shape: " << out->getDimensions());
+                    return true;
+                  }})
+        .pattern({"aten::max.other(Tensor self, Tensor other) -> (Tensor)",
+                  [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+                    // TODO: Remove with functionalization
+                    auto self = args[0].ITensorOrFreeze(ctx);
+                    auto other = args[1].ITensorOrFreeze(ctx);
+                    auto max =
+                        add_elementwise(ctx, nvinfer1::ElementWiseOperation::kMAX, self, other, util::node_info(n));
+                    TRTORCH_CHECK(max, "Unable to create max layer from node: " << *n);
+
+                    max->setName(util::node_info(n).c_str());
+                    auto out = ctx->AssociateValueAndTensor(n->outputs()[0], max->getOutput(0));
+                    LOG_DEBUG("Output tensor shape: " << out->getDimensions());
+                    return true;
+                  }})
+        .pattern({"aten::min.other(Tensor self, Tensor other) -> (Tensor)",
+                  [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+                    // TODO: Remove with functionalization
+                    auto self = args[0].ITensorOrFreeze(ctx);
+                    auto other = args[1].ITensorOrFreeze(ctx);
+                    auto min =
+                        add_elementwise(ctx, nvinfer1::ElementWiseOperation::kMIN, self, other, util::node_info(n));
+                    TRTORCH_CHECK(min, "Unable to create min layer from node: " << *n);
+
+                    min->setName(util::node_info(n).c_str());
+                    auto out = ctx->AssociateValueAndTensor(n->outputs()[0], min->getOutput(0));
                     LOG_DEBUG("Output tensor shape: " << out->getDimensions());
                     return true;
                   }})
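The hunk above registers converters that map aten::floor_divide, aten::max.other, and aten::min.other onto TensorRT's kFLOOR_DIV, kMAX, and kMIN elementwise operations. The eager-mode equivalents can be checked with libtorch; the sketch below is a standalone illustration with made-up values (positive operands chosen so floor and truncating division agree), not part of the patch:

    // Standalone illustration of the eager-mode ops the new converters target.
    #include <torch/torch.h>
    #include <iostream>

    int main() {
      auto self = torch::tensor({7.0f, 5.0f, 9.0f});
      auto other = torch::tensor({2.0f, 6.0f, 9.0f});

      auto fd = torch::floor_divide(self, other);  // 3, 0, 1  -> ElementWiseOperation::kFLOOR_DIV
      auto mx = torch::max(self, other);           // 7, 6, 9  -> ElementWiseOperation::kMAX
      auto mn = torch::min(self, other);           // 2, 5, 9  -> ElementWiseOperation::kMIN

      std::cout << fd << mx << mn << std::endl;
      return 0;
    }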