Commit f05a550

Merge pull request #309 from NVIDIA/inocsin/elementwise_ops
Adding support for rsub, min, max, floor_divide
2 parents bad7d63 + f9d29d0 commit f05a550

3 files changed: 199 additions, 18 deletions

core/conversion/converters/impl/element_wise.cpp

Lines changed: 130 additions & 14 deletions
@@ -188,13 +188,14 @@ auto element_wise_registrations TRTORCH_UNUSED =
                     auto other = args[1].ITensorOrFreeze(ctx);

                     if (1 != scalar) {
-                      auto scaleW = Weights(ctx, scalar);
-                      auto unuse = Weights();
-                      // IScaleLayer assert shift, scale and power to have
-                      // the same dtype
-                      auto scaleLayer = ctx->net->addScale(
-                          *other, nvinfer1::ScaleMode::kUNIFORM, unuse.data, scaleW.data, unuse.data);
-                      TRTORCH_CHECK(scaleLayer, "Unable to create scale layer from node: " << *n);
+                      auto alphaTensor = tensor_to_const(ctx, torch::tensor({scalar}));
+                      auto scaleLayer = add_elementwise(
+                          ctx,
+                          nvinfer1::ElementWiseOperation::kPROD,
+                          other,
+                          alphaTensor,
+                          util::node_info(n) + std::string("_AlphaMultiplier"));
+                      TRTORCH_CHECK(scaleLayer, "Unable to create alpha*input layer from node: " << *n);
                       other = scaleLayer->getOutput(0);
                     }

@@ -216,13 +217,14 @@ auto element_wise_registrations TRTORCH_UNUSED =
                     auto other = args[1].ITensorOrFreeze(ctx);

                     if (1 != scalar) {
-                      auto scaleW = Weights(ctx, scalar);
-                      auto unuse = Weights();
-                      // IScaleLayer assert shift, scale and power to have
-                      // the same dtype
-                      auto scaleLayer = ctx->net->addScale(
-                          *other, nvinfer1::ScaleMode::kUNIFORM, unuse.data, scaleW.data, unuse.data);
-                      TRTORCH_CHECK(scaleLayer, "Unable to create scale layer from node: " << *n);
+                      auto alphaTensor = tensor_to_const(ctx, torch::tensor({scalar}));
+                      auto scaleLayer = add_elementwise(
+                          ctx,
+                          nvinfer1::ElementWiseOperation::kPROD,
+                          other,
+                          alphaTensor,
+                          util::node_info(n) + std::string("_AlphaMultiplier"));
+                      TRTORCH_CHECK(scaleLayer, "Unable to create alpha*input layer from node: " << *n);
                       other = scaleLayer->getOutput(0);
                     }

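A note on the two hunks above: the previous code applied alpha through IScaleLayer, which (as the deleted comment says) requires the shift, scale, and power weights to share a dtype; the replacement multiplies by a one-element constant with an elementwise kPROD layer instead, leaving the numerics unchanged. A minimal plain-C++ sketch of that contract, with floats standing in for ITensor values (scale_by_alpha is a hypothetical name, not repo code):

    // Sketch of the alpha handling both hunks implement; plain floats stand
    // in for nvinfer1::ITensor values.
    #include <cassert>

    float scale_by_alpha(float other, float alpha) {
      // Mirrors the converters: insert a kPROD layer only when alpha != 1.
      return (alpha != 1.0f) ? alpha * other : other;
    }

    int main() {
      assert(scale_by_alpha(4.0f, 1.0f) == 4.0f);   // alpha == 1: layer skipped
      assert(scale_by_alpha(4.0f, 2.5f) == 10.0f);  // alpha != 1: kPROD applied
      return 0;
    }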
@@ -235,6 +237,63 @@ auto element_wise_registrations TRTORCH_UNUSED =
                     LOG_DEBUG("Output tensor shape: " << out->getDimensions());
                     return true;
                   }})
+        .pattern({"aten::rsub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> (Tensor)",
+                  [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+                    // Should implement other - alpha * self
+                    auto self = args[0].ITensorOrFreeze(ctx);
+                    auto otherScalar = args[1].unwrapToScalar().to<float>();
+                    auto other = tensor_to_const(ctx, torch::tensor({otherScalar}));
+                    auto scalar = args[2].unwrapToScalar().to<float>();
+
+                    if (1 != scalar) {
+                      auto alphaTensor = tensor_to_const(ctx, torch::tensor({scalar}));
+                      auto scaleLayer = add_elementwise(
+                          ctx,
+                          nvinfer1::ElementWiseOperation::kPROD,
+                          self,
+                          alphaTensor,
+                          util::node_info(n) + std::string("_AlphaMultiplier"));
+                      TRTORCH_CHECK(scaleLayer, "Unable to create alpha*input layer from node: " << *n);
+                      self = scaleLayer->getOutput(0);
+                    }
+
+                    auto rsub =
+                        add_elementwise(ctx, nvinfer1::ElementWiseOperation::kSUB, other, self, util::node_info(n));
+                    TRTORCH_CHECK(rsub, "Unable to create rsub layer from node: " << *n);
+
+                    rsub->setName(util::node_info(n).c_str());
+                    auto out = ctx->AssociateValueAndTensor(n->outputs()[0], rsub->getOutput(0));
+                    LOG_DEBUG("Output tensor shape: " << out->getDimensions());
+                    return true;
+                  }})
+        .pattern({"aten::rsub.Tensor(Tensor self, Tensor other, Scalar alpha=1) -> (Tensor)",
+                  [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+                    // Should implement other - alpha * self
+                    auto self = args[0].ITensorOrFreeze(ctx);
+                    auto other = args[1].ITensorOrFreeze(ctx);
+                    auto scalar = args[2].unwrapToScalar().to<float>();
+
+                    if (1 != scalar) {
+                      auto alphaTensor = tensor_to_const(ctx, torch::tensor({scalar}));
+                      auto scaleLayer = add_elementwise(
+                          ctx,
+                          nvinfer1::ElementWiseOperation::kPROD,
+                          self,
+                          alphaTensor,
+                          util::node_info(n) + std::string("_AlphaMultiplier"));
+                      TRTORCH_CHECK(scaleLayer, "Unable to create alpha*input layer from node: " << *n);
+                      self = scaleLayer->getOutput(0);
+                    }
+
+                    auto rsub =
+                        add_elementwise(ctx, nvinfer1::ElementWiseOperation::kSUB, other, self, util::node_info(n));
+                    TRTORCH_CHECK(rsub, "Unable to create rsub layer from node: " << *n);
+
+                    rsub->setName(util::node_info(n).c_str());
+                    auto out = ctx->AssociateValueAndTensor(n->outputs()[0], rsub->getOutput(0));
+                    LOG_DEBUG("Output tensor shape: " << out->getDimensions());
+                    return true;
+                  }})
         .pattern({"aten::div.Tensor(Tensor self, Tensor other) -> Tensor",
                   [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
                     // Should implement self / other
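The two rsub converters added in the hunk above (Scalar and Tensor variants) share one recipe: scale self by alpha through the same kPROD pattern when alpha != 1, then emit kSUB with the operands reversed, so the result is other - alpha * self. A minimal plain-C++ sketch of that contract (hypothetical free function, floats standing in for tensors):

    // Sketch of what the rsub converters compute: other - alpha * self.
    #include <cassert>

    float rsub(float self, float other, float alpha = 1.0f) {
      if (alpha != 1.0f) {
        self = alpha * self;  // the kPROD "_AlphaMultiplier" step
      }
      return other - self;    // kSUB with reversed operand order
    }

    int main() {
      assert(rsub(2.0f, 10.0f) == 8.0f);        // 10 - 1 * 2
      assert(rsub(2.0f, 10.0f, 3.0f) == 4.0f);  // 10 - 3 * 2
      return 0;
    }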
@@ -412,6 +471,63 @@ auto element_wise_registrations TRTORCH_UNUSED =
                     pow->setName(util::node_info(n).c_str());
                     auto out = ctx->AssociateValueAndTensor(n->outputs()[0], pow->getOutput(0));

+                    LOG_DEBUG("Output tensor shape: " << out->getDimensions());
+                    return true;
+                  }})
+        .pattern({"aten::floor_divide(Tensor self, Tensor other) -> (Tensor)",
+                  [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+                    // TODO: Remove with functionalization
+                    auto self = args[0].ITensorOrFreeze(ctx);
+                    auto other = args[1].ITensorOrFreeze(ctx);
+                    auto floor_divide = add_elementwise(
+                        ctx, nvinfer1::ElementWiseOperation::kFLOOR_DIV, self, other, util::node_info(n));
+                    TRTORCH_CHECK(floor_divide, "Unable to create floor_divide layer from node: " << *n);
+
+                    floor_divide->setName(util::node_info(n).c_str());
+                    auto out = ctx->AssociateValueAndTensor(n->outputs()[0], floor_divide->getOutput(0));
+                    LOG_DEBUG("Output tensor shape: " << out->getDimensions());
+                    return true;
+                  }})
+        .pattern({"aten::floor_divide.Scalar(Tensor self, Scalar other) -> (Tensor)",
+                  [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+                    // TODO: Remove with functionalization
+                    auto self = args[0].ITensorOrFreeze(ctx);
+                    auto otherScalar = args[1].unwrapToScalar().to<float>();
+                    auto other = tensor_to_const(ctx, torch::tensor({otherScalar}));
+                    auto floor_divide = add_elementwise(
+                        ctx, nvinfer1::ElementWiseOperation::kFLOOR_DIV, self, other, util::node_info(n));
+                    TRTORCH_CHECK(floor_divide, "Unable to create floor_divide layer from node: " << *n);
+
+                    floor_divide->setName(util::node_info(n).c_str());
+                    auto out = ctx->AssociateValueAndTensor(n->outputs()[0], floor_divide->getOutput(0));
+                    LOG_DEBUG("Output tensor shape: " << out->getDimensions());
+                    return true;
+                  }})
+        .pattern({"aten::max.other(Tensor self, Tensor other) -> (Tensor)",
+                  [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+                    // TODO: Remove with functionalization
+                    auto self = args[0].ITensorOrFreeze(ctx);
+                    auto other = args[1].ITensorOrFreeze(ctx);
+                    auto max =
+                        add_elementwise(ctx, nvinfer1::ElementWiseOperation::kMAX, self, other, util::node_info(n));
+                    TRTORCH_CHECK(max, "Unable to create max layer from node: " << *n);
+
+                    max->setName(util::node_info(n).c_str());
+                    auto out = ctx->AssociateValueAndTensor(n->outputs()[0], max->getOutput(0));
+                    LOG_DEBUG("Output tensor shape: " << out->getDimensions());
+                    return true;
+                  }})
+        .pattern({"aten::min.other(Tensor self, Tensor other) -> (Tensor)",
+                  [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+                    // TODO: Remove with functionalization
+                    auto self = args[0].ITensorOrFreeze(ctx);
+                    auto other = args[1].ITensorOrFreeze(ctx);
+                    auto min =
+                        add_elementwise(ctx, nvinfer1::ElementWiseOperation::kMIN, self, other, util::node_info(n));
+                    TRTORCH_CHECK(min, "Unable to create min layer from node: " << *n);
+
+                    min->setName(util::node_info(n).c_str());
+                    auto out = ctx->AssociateValueAndTensor(n->outputs()[0], min->getOutput(0));
                     LOG_DEBUG("Output tensor shape: " << out->getDimensions());
                     return true;
                   }})
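The remaining converters in this file each map one ATen op onto a single TensorRT elementwise layer: aten::floor_divide to kFLOOR_DIV, aten::max.other to kMAX, and aten::min.other to kMIN; the floor_divide.Scalar overload first materializes the scalar operand as a one-element constant via tensor_to_const. A scalar sketch of the expected results (assuming kFLOOR_DIV rounds toward negative infinity, matching std::floor):

    // Scalar sketch of the three new single-layer ops.
    #include <algorithm>
    #include <cassert>
    #include <cmath>

    int main() {
      assert(std::floor(7.0f / 2.0f) == 3.0f);    // kFLOOR_DIV
      assert(std::floor(-7.0f / 2.0f) == -4.0f);  // floor, not truncation
      assert(std::max(7.0f, 2.0f) == 7.0f);       // kMAX
      assert(std::min(7.0f, 2.0f) == 2.0f);       // kMIN
      return 0;
    }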

core/lowering/passes/BUILD

Lines changed: 1 addition & 2 deletions
@@ -40,5 +40,4 @@ pkg_tar(
     name = "include",
     package_dir = "core/lowering/passes/",
     srcs = ["passes.h"],
-)
-
+)

tests/core/conversion/converters/test_element_wise.cpp

Lines changed: 68 additions & 2 deletions
@@ -89,7 +89,7 @@ TEST(Converters, ATenAddWithScalarConvertsCorrectly) {
 TEST(Converters, ATenSubConvertsCorrectly) {
   const auto graph = R"IR(
     graph(%0 : Tensor, %1 : Tensor):
-      %2 : int = prim::Constant[value=1]()
+      %2 : int = prim::Constant[value=2.3]()
       %3 : Tensor = aten::sub(%0, %1, %2)
       return (%3))IR";
   pointwise_test_helper(graph, false);
@@ -170,7 +170,73 @@ TEST(Converters, ATenNeScalarConvertsCorrectly) {
       %3 : Tensor = aten::ne(%x.1, %2)
       return (%3))IR";
   pointwise_test_helper(graph, true, false, {3, 4, 2});
-;
+}
+
+TEST(Converters, ATenFloorDivideConvertsCorrectly) {
+  const auto graph = R"IR(
+    graph(%0 : Tensor, %1 : Tensor):
+      %2 : Tensor = aten::floor_divide(%0, %1)
+      return (%2))IR";
+  pointwise_test_helper(graph, false);
+  pointwise_test_helper(graph, false, false, {3, 4}, {4});
+  pointwise_test_helper(graph, false, false, {4}, {3, 4});
+  pointwise_test_helper(graph, false, true, {3, 4, 3}, {4, 3});
+  pointwise_test_helper(graph, false, true, {4, 3}, {3, 4, 3});
+}
+
+TEST(Converters, ATenFloorDivideWithScalarConvertsCorrectly) {
+  const auto graph = R"IR(
+    graph(%0 : Tensor):
+      %scalar : float = prim::Constant[value=2.4]()
+      %1 : Tensor = aten::floor_divide(%0, %scalar)
+      return (%1))IR";
+  pointwise_test_helper(graph, true);
+}
+
+TEST(Converters, ATenMaxConvertsCorrectly) {
+  const auto graph = R"IR(
+    graph(%0 : Tensor, %1 : Tensor):
+      %2 : Tensor = aten::max(%0, %1)
+      return (%2))IR";
+  pointwise_test_helper(graph, false);
+  pointwise_test_helper(graph, false, false, {3, 4}, {4});
+  pointwise_test_helper(graph, false, false, {4}, {3, 4});
+  pointwise_test_helper(graph, false, true, {3, 4, 3}, {4, 3});
+  pointwise_test_helper(graph, false, true, {4, 3}, {3, 4, 3});
+}
+
+TEST(Converters, ATenMinConvertsCorrectly) {
+  const auto graph = R"IR(
+    graph(%0 : Tensor, %1 : Tensor):
+      %2 : Tensor = aten::min(%0, %1)
+      return (%2))IR";
+  pointwise_test_helper(graph, false);
+  pointwise_test_helper(graph, false, false, {3, 4}, {4});
+  pointwise_test_helper(graph, false, false, {4}, {3, 4});
+  pointwise_test_helper(graph, false, true, {3, 4, 3}, {4, 3});
+  pointwise_test_helper(graph, false, true, {4, 3}, {3, 4, 3});
+}
+
+TEST(Converters, ATenRsubWithTensorConvertsCorrectly) {
+  const auto graph = R"IR(
+    graph(%0 : Tensor, %1 : Tensor):
+      %2 : int = prim::Constant[value=2]()
+      %3 : Tensor = aten::rsub(%0, %1, %2)
+      return (%3))IR";
+  pointwise_test_helper(graph, false, false, {3, 4}, {4});
+  pointwise_test_helper(graph, false, false, {4}, {3, 4});
+  pointwise_test_helper(graph, false, true, {4, 3, 3, 3}, {4, 3, 3, 3});
+}
+
+TEST(Converters, ATenRsubWithScalarConvertsCorrectly) {
+  const auto graph = R"IR(
+    graph(%0 : Tensor):
+      %2 : int = prim::Constant[value=2]()
+      %scalar : float = prim::Constant[value=2.4]()
+      %3 : Tensor = aten::rsub(%0, %scalar, %2)
+      return (%3))IR";
+  pointwise_test_helper(graph, true, false, {4, 3, 3, 3});
+}

 TEST(Converters, ATenClampMinConvertsCorrectly) {
   const auto graph = R"IR(
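Judging from the call sites above, pointwise_test_helper takes the graph IR, a single-input flag, a dynamic-shape flag, and up to two input shapes; the mismatched pairs such as {3, 4} against {4} exercise implicit broadcasting. A sketch of the right-aligned broadcasting rule those pairs rely on (broadcast_shape is a hypothetical illustration, not the test helper; the rule stated matches NumPy/PyTorch convention):

    // Sketch: right-aligned shape broadcasting as used by the mismatched-shape
    // test cases above, e.g. {3, 4, 3} with {4, 3} -> {3, 4, 3}.
    #include <algorithm>
    #include <cassert>
    #include <cstdint>
    #include <vector>

    std::vector<int64_t> broadcast_shape(std::vector<int64_t> a, std::vector<int64_t> b) {
      if (a.size() < b.size()) std::swap(a, b);  // make `a` the longer shape
      std::vector<int64_t> out = a;
      size_t offset = a.size() - b.size();
      for (size_t i = 0; i < b.size(); ++i) {
        // Trailing dimensions must match, or one of them must be 1.
        assert(a[offset + i] == b[i] || a[offset + i] == 1 || b[i] == 1);
        out[offset + i] = std::max(a[offset + i], b[i]);
      }
      return out;
    }

    int main() {
      assert(broadcast_shape({3, 4}, {4}) == (std::vector<int64_t>{3, 4}));
      assert(broadcast_shape({3, 4, 3}, {4, 3}) == (std::vector<int64_t>{3, 4, 3}));
      return 0;
    }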