Adding support for rsub, min, max, floor_divide #309

Merged · 4 commits · Feb 3, 2021

144 changes: 130 additions & 14 deletions core/conversion/converters/impl/element_wise.cpp
@@ -188,13 +188,14 @@ auto element_wise_registrations TRTORCH_UNUSED =
auto other = args[1].ITensorOrFreeze(ctx);

if (1 != scalar) {
auto scaleW = Weights(ctx, scalar);
auto unuse = Weights();
// IScaleLayer assert shift, scale and power to have
// the same dtype
auto scaleLayer = ctx->net->addScale(
*other, nvinfer1::ScaleMode::kUNIFORM, unuse.data, scaleW.data, unuse.data);
TRTORCH_CHECK(scaleLayer, "Unable to create scale layer from node: " << *n);
auto alphaTensor = tensor_to_const(ctx, torch::tensor({scalar}));
auto scaleLayer = add_elementwise(
ctx,
nvinfer1::ElementWiseOperation::kPROD,
other,
alphaTensor,
util::node_info(n) + std::string("_AlphaMultiplier"));
TRTORCH_CHECK(scaleLayer, "Unable to create alpha*input layer from node: " << *n);
other = scaleLayer->getOutput(0);
}

@@ -216,13 +217,14 @@ auto element_wise_registrations TRTORCH_UNUSED =
auto other = args[1].ITensorOrFreeze(ctx);

if (1 != scalar) {
auto scaleW = Weights(ctx, scalar);
auto unuse = Weights();
// IScaleLayer assert shift, scale and power to have
// the same dtype
auto scaleLayer = ctx->net->addScale(
*other, nvinfer1::ScaleMode::kUNIFORM, unuse.data, scaleW.data, unuse.data);
TRTORCH_CHECK(scaleLayer, "Unable to create scale layer from node: " << *n);
auto alphaTensor = tensor_to_const(ctx, torch::tensor({scalar}));
auto scaleLayer = add_elementwise(
ctx,
nvinfer1::ElementWiseOperation::kPROD,
other,
alphaTensor,
util::node_info(n) + std::string("_AlphaMultiplier"));
TRTORCH_CHECK(scaleLayer, "Unable to create alpha*input layer from node: " << *n);
other = scaleLayer->getOutput(0);
}

@@ -235,6 +237,63 @@ auto element_wise_registrations TRTORCH_UNUSED =
LOG_DEBUG("Output tensor shape: " << out->getDimensions());
return true;
}})
.pattern({"aten::rsub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> (Tensor)",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
// Should implement other - alpha * self
auto self = args[0].ITensorOrFreeze(ctx);
auto otherScalar = args[1].unwrapToScalar().to<float>();
auto other = tensor_to_const(ctx, torch::tensor({otherScalar}));
auto scalar = args[2].unwrapToScalar().to<float>();

if (1 != scalar) {
auto alphaTensor = tensor_to_const(ctx, torch::tensor({scalar}));
auto scaleLayer = add_elementwise(
ctx,
nvinfer1::ElementWiseOperation::kPROD,
self,
alphaTensor,
util::node_info(n) + std::string("_AlphaMultiplier"));
TRTORCH_CHECK(scaleLayer, "Unable to create alpha*input layer from node: " << *n);
self = scaleLayer->getOutput(0);
}

auto rsub =
add_elementwise(ctx, nvinfer1::ElementWiseOperation::kSUB, other, self, util::node_info(n));
TRTORCH_CHECK(rsub, "Unable to create rsub layer from node: " << *n);

rsub->setName(util::node_info(n).c_str());
auto out = ctx->AssociateValueAndTensor(n->outputs()[0], rsub->getOutput(0));
LOG_DEBUG("Output tensor shape: " << out->getDimensions());
return true;
}})
.pattern({"aten::rsub.Tensor(Tensor self, Tensor other, Scalar alpha=1) -> (Tensor)",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
// Should implement other - alpha * self
auto self = args[0].ITensorOrFreeze(ctx);
auto other = args[1].ITensorOrFreeze(ctx);
auto scalar = args[2].unwrapToScalar().to<float>();

if (1 != scalar) {
auto alphaTensor = tensor_to_const(ctx, torch::tensor({scalar}));
auto scaleLayer = add_elementwise(
ctx,
nvinfer1::ElementWiseOperation::kPROD,
self,
alphaTensor,
util::node_info(n) + std::string("_AlphaMultiplier"));
TRTORCH_CHECK(scaleLayer, "Unable to create alpha*input layer from node: " << *n);
self = scaleLayer->getOutput(0);
}

auto rsub =
add_elementwise(ctx, nvinfer1::ElementWiseOperation::kSUB, other, self, util::node_info(n));
TRTORCH_CHECK(rsub, "Unable to create rsub layer from node: " << *n);

rsub->setName(util::node_info(n).c_str());
auto out = ctx->AssociateValueAndTensor(n->outputs()[0], rsub->getOutput(0));
LOG_DEBUG("Output tensor shape: " << out->getDimensions());
return true;
}})
.pattern({"aten::div.Tensor(Tensor self, Tensor other) -> Tensor",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
// Should implement self / other
@@ -412,6 +471,63 @@ auto element_wise_registrations TRTORCH_UNUSED =
pow->setName(util::node_info(n).c_str());
auto out = ctx->AssociateValueAndTensor(n->outputs()[0], pow->getOutput(0));

LOG_DEBUG("Output tensor shape: " << out->getDimensions());
return true;
}})
.pattern({"aten::floor_divide(Tensor self, Tensor other) -> (Tensor)",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
// TODO: Remove with functionalization
auto self = args[0].ITensorOrFreeze(ctx);
auto other = args[1].ITensorOrFreeze(ctx);
auto floor_divide = add_elementwise(
ctx, nvinfer1::ElementWiseOperation::kFLOOR_DIV, self, other, util::node_info(n));
TRTORCH_CHECK(floor_divide, "Unable to create floor_divide layer from node: " << *n);

floor_divide->setName(util::node_info(n).c_str());
auto out = ctx->AssociateValueAndTensor(n->outputs()[0], floor_divide->getOutput(0));
LOG_DEBUG("Output tensor shape: " << out->getDimensions());
return true;
}})
.pattern({"aten::floor_divide.Scalar(Tensor self, Scalar other) -> (Tensor)",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
// TODO: Remove with functionalization
auto self = args[0].ITensorOrFreeze(ctx);
auto otherScalar = args[1].unwrapToScalar().to<float>();
auto other = tensor_to_const(ctx, torch::tensor({otherScalar}));
auto floor_divide = add_elementwise(
ctx, nvinfer1::ElementWiseOperation::kFLOOR_DIV, self, other, util::node_info(n));
TRTORCH_CHECK(floor_divide, "Unable to create floor_divide layer from node: " << *n);

floor_divide->setName(util::node_info(n).c_str());
auto out = ctx->AssociateValueAndTensor(n->outputs()[0], floor_divide->getOutput(0));
LOG_DEBUG("Output tensor shape: " << out->getDimensions());
return true;
}})
.pattern({"aten::max.other(Tensor self, Tensor other) -> (Tensor)",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
// TODO: Remove with functionalization
auto self = args[0].ITensorOrFreeze(ctx);
auto other = args[1].ITensorOrFreeze(ctx);
auto max =
add_elementwise(ctx, nvinfer1::ElementWiseOperation::kMAX, self, other, util::node_info(n));
TRTORCH_CHECK(max, "Unable to create max layer from node: " << *n);

max->setName(util::node_info(n).c_str());
auto out = ctx->AssociateValueAndTensor(n->outputs()[0], max->getOutput(0));
LOG_DEBUG("Output tensor shape: " << out->getDimensions());
return true;
}})
.pattern({"aten::min.other(Tensor self, Tensor other) -> (Tensor)",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
// TODO: Remove with functionalization
auto self = args[0].ITensorOrFreeze(ctx);
auto other = args[1].ITensorOrFreeze(ctx);
auto min =
add_elementwise(ctx, nvinfer1::ElementWiseOperation::kMIN, self, other, util::node_info(n));
TRTORCH_CHECK(min, "Unable to create min layer from node: " << *n);

min->setName(util::node_info(n).c_str());
auto out = ctx->AssociateValueAndTensor(n->outputs()[0], min->getOutput(0));
LOG_DEBUG("Output tensor shape: " << out->getDimensions());
return true;
}})
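
The two aten::rsub converters above implement other - alpha * self: when alpha != 1, self is first scaled by an elementwise kPROD against a constant alpha tensor (the same add_elementwise pattern that replaces the old IScaleLayer call in the earlier hunks, since IScaleLayer forces shift, scale and power onto a single dtype), and the scaled tensor is then subtracted from other with kSUB. Below is a minimal libtorch sketch, not part of this PR and independent of TensorRT, of the eager-mode values the converters are expected to reproduce; torch::sub and torch::rsub are the ATen entry points for the schemas being registered.

// Standalone reference-semantics check (assumes a libtorch build environment).
// This only documents the expected values; it does not exercise the converters.
#include <torch/torch.h>
#include <iostream>

int main() {
  auto self = torch::randn({3, 4});
  auto other = torch::randn({3, 4});
  double alpha = 2.0;

  // aten::sub(self, other, alpha)  ->  self - alpha * other
  std::cout << torch::allclose(torch::sub(self, other, alpha), self - alpha * other) << "\n";

  // aten::rsub(self, other, alpha) ->  other - alpha * self  (operands swapped)
  std::cout << torch::allclose(torch::rsub(self, other, alpha), other - alpha * self) << "\n";
  return 0;
}

Both checks are expected to print 1; the rsub converter reproduces exactly this swap of operands around the kSUB layer.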
3 changes: 1 addition & 2 deletions core/lowering/passes/BUILD
@@ -40,5 +40,4 @@ pkg_tar(
name = "include",
package_dir = "core/lowering/passes/",
srcs = ["passes.h"],
)

)
70 changes: 68 additions & 2 deletions tests/core/conversion/converters/test_element_wise.cpp
@@ -89,7 +89,7 @@ TEST(Converters, ATenAddWithScalarConvertsCorrectly) {
TEST(Converters, ATenSubConvertsCorrectly) {
const auto graph = R"IR(
graph(%0 : Tensor, %1 : Tensor):
%2 : int = prim::Constant[value=1]()
%2 : float = prim::Constant[value=2.3]()
%3 : Tensor = aten::sub(%0, %1, %2)
return (%3))IR";
pointwise_test_helper(graph, false);
@@ -170,7 +170,73 @@ TEST(Converters, ATenNeScalarConvertsCorrectly) {
%3 : Tensor = aten::ne(%x.1, %2)
return (%3))IR";
pointwise_test_helper(graph, true, false, {3, 4, 2});
;
}

TEST(Converters, ATenFloorDivideConvertsCorrectly) {
const auto graph = R"IR(
graph(%0 : Tensor, %1 : Tensor):
%2 : Tensor = aten::floor_divide(%0, %1)
return (%2))IR";
pointwise_test_helper(graph, false);
pointwise_test_helper(graph, false, false, {3, 4}, {4});
pointwise_test_helper(graph, false, false, {4}, {3, 4});
pointwise_test_helper(graph, false, true, {3, 4, 3}, {4, 3});
pointwise_test_helper(graph, false, true, {4, 3}, {3, 4, 3});
}

TEST(Converters, ATenFloorDivideWithScalarConvertsCorrectly) {
const auto graph = R"IR(
graph(%0 : Tensor):
%scalar : float = prim::Constant[value=2.4]()
%1 : Tensor = aten::floor_divide(%0, %scalar)
return (%1))IR";
pointwise_test_helper(graph, true);
}

TEST(Converters, ATenMaxConvertsCorrectly) {
const auto graph = R"IR(
graph(%0 : Tensor, %1 : Tensor):
%2 : Tensor = aten::max(%0, %1)
return (%2))IR";
pointwise_test_helper(graph, false);
pointwise_test_helper(graph, false, false, {3, 4}, {4});
pointwise_test_helper(graph, false, false, {4}, {3, 4});
pointwise_test_helper(graph, false, true, {3, 4, 3}, {4, 3});
pointwise_test_helper(graph, false, true, {4, 3}, {3, 4, 3});
}

TEST(Converters, ATenMinConvertsCorrectly) {
const auto graph = R"IR(
graph(%0 : Tensor, %1 : Tensor):
%2 : Tensor = aten::min(%0, %1)
return (%2))IR";
pointwise_test_helper(graph, false);
pointwise_test_helper(graph, false, false, {3, 4}, {4});
pointwise_test_helper(graph, false, false, {4}, {3, 4});
pointwise_test_helper(graph, false, true, {3, 4, 3}, {4, 3});
pointwise_test_helper(graph, false, true, {4, 3}, {3, 4, 3});
}

TEST(Converters, ATenRsubWithTensorConvertsCorrectly) {
const auto graph = R"IR(
graph(%0 : Tensor, %1 : Tensor):
%2 : int = prim::Constant[value=2]()
%3 : Tensor = aten::rsub(%0, %1, %2)
return (%3))IR";
pointwise_test_helper(graph, false, false, {3, 4}, {4});
pointwise_test_helper(graph, false, false, {4}, {3, 4});
pointwise_test_helper(graph, false, true, {4, 3, 3, 3}, {4, 3, 3, 3});
}

TEST(Converters, ATenRsubWithScalarConvertsCorrectly) {
const auto graph = R"IR(
graph(%0 : Tensor):
%2 : int = prim::Constant[value=2]()
%scalar : float = prim::Constant[value=2.4]()
%3 : Tensor = aten::rsub(%0, %scalar, %2)
return (%3))IR";
pointwise_test_helper(graph, true, false, {4, 3, 3, 3});
}

TEST(Converters, ATenClampMinConvertsCorrectly) {
const auto graph = R"IR(
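
The shape pairs in the floor_divide / max / min tests above exercise implicit broadcasting, which add_elementwise resolves before the kFLOOR_DIV / kMAX / kMIN layers are built. A minimal libtorch sketch, not part of this PR, of the eager-mode behaviour these converters are expected to match:

// Standalone broadcasting sketch (assumes a libtorch build environment);
// independent of TensorRT, it only shows the eager-mode reference behaviour.
#include <torch/torch.h>
#include <iostream>

int main() {
  auto self = torch::randn({3, 4});
  auto other = torch::rand({4}) + 1.0;  // broadcasts against {3, 4}; kept away from zero for the division

  auto fd = torch::floor_divide(self, other);  // converter maps this op to ElementWiseOperation::kFLOOR_DIV
  auto mx = torch::max(self, other);           // elementwise maximum -> kMAX
  auto mn = torch::min(self, other);           // elementwise minimum -> kMIN

  // All three results broadcast to shape [3, 4], matching the {3, 4} x {4} cases in the tests.
  std::cout << fd.sizes() << " " << mx.sizes() << " " << mn.sizes() << std::endl;
  return 0;
}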