feat(//core/converters): Add expand layer, expand_as and repeat layer functionality #289


Merged: 1 commit, Feb 2, 2021

1 change: 1 addition & 0 deletions core/conversion/converters/BUILD
@@ -39,6 +39,7 @@ cc_library(
"impl/constant.cpp",
"impl/conv_deconv.cpp",
"impl/element_wise.cpp",
"impl/expand.cpp",
"impl/linear.cpp",
"impl/matrix_multiply.cpp",
"impl/pooling.cpp",
152 changes: 152 additions & 0 deletions core/conversion/converters/impl/expand.cpp
@@ -0,0 +1,152 @@
#include "NvInfer.h"
#include "core/conversion/converters/converters.h"
#include "core/conversion/tensorcontainer/TensorContainer.h"
#include "core/util/prelude.h"
#include "core/util/trt_util.h"
#include "torch/torch.h"

#include <ATen/ATen.h>
#include <vector>

namespace trtorch {
namespace core {
namespace conversion {
namespace converters {
namespace impl {
namespace {

bool add_expand(ConversionCtx* ctx, const torch::jit::Node* n, nvinfer1::ITensor* in, nvinfer1::Dims expandedDims) {
auto input_dims = in->getDimensions();
TRTORCH_CHECK(
input_dims.nbDims <= expandedDims.nbDims,
"Number of dimensions of the desired expansion must be greater than or equal to the number of input dimensions");

// Validate the expansion, e.g. an input of [3, 1] can be expanded to [1, 3, 4] but not to [3, 4, 1]
for (int i = expandedDims.nbDims - 1; i >= 0; --i) {
int64_t offset = expandedDims.nbDims - 1 - i;
int64_t dim = input_dims.nbDims - 1 - offset;
int64_t size = (dim >= 0) ? input_dims.d[dim] : 1;
int64_t targetSize = expandedDims.d[i];
if (size != targetSize) {
if (size != 1) {
TRTORCH_THROW_ERROR(
"The expanded size of tensor (" << targetSize << ")"
<< " must match the existing size (" << size << ")"
<< " at dimension " << i);
}
}
}

auto num_expand_dims = expandedDims.nbDims - input_dims.nbDims;
if (num_expand_dims > 0) {
nvinfer1::Dims reshape_dims;
reshape_dims.nbDims = expandedDims.nbDims;
for (int i = 0; i < num_expand_dims; i++) {
reshape_dims.d[i] = 1;
}
for (int i = 0; i < input_dims.nbDims; i++) {
reshape_dims.d[num_expand_dims + i] = input_dims.d[i];
}
// Add a reshape layer to expand dims
auto reshape_layer = ctx->net->addShuffle(*in);
reshape_layer->setReshapeDimensions(reshape_dims);
in = reshape_layer->getOutput(0);
LOG_DEBUG("Input reshaped to : " << in->getDimensions() << " from " << input_dims);
}

// Start the slicing from beginning of tensor since this is an expand layer
std::vector<int64_t> start_vec(expandedDims.nbDims, 0);
auto start_offset = util::toDims(c10::IntArrayRef(start_vec));

// Set the stride to 1 for non-singleton dimensions; singleton dimensions keep a stride of 0 so they are broadcast by the slice
std::vector<int64_t> strides_vec(expandedDims.nbDims, 0);
for (int i = 0; i < expandedDims.nbDims; i++) {
strides_vec[i] = (in->getDimensions().d[i] != 1);
}
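// Example: expanding a [3, 1] input to [3, 4] yields start = [0, 0],
// size = [3, 4] and stride = [1, 0]; the zero stride re-reads the single
// element along the broadcast dimension.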

auto strides = util::toDims(c10::IntArrayRef(strides_vec));
// Slice layer does the expansion in TRT. Desired output size is specified by expandedDims
auto slice_layer = ctx->net->addSlice(*in, start_offset, expandedDims, strides);
slice_layer->setName(util::node_info(n).c_str());

auto out = ctx->AssociateValueAndTensor(n->outputs()[0], slice_layer->getOutput(0));

LOG_DEBUG("Expand layer output tensor shape: " << out->getDimensions());

return true;
}

auto expand_registrations TRTORCH_UNUSED =
RegisterNodeConversionPatterns()
.pattern({"aten::expand(Tensor(a) self, int[] size, *, bool implicit=False) -> (Tensor(a))",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
auto in = args[0].ITensor();
auto input_dims = in->getDimensions();
auto expanded_size = args[1].unwrapToIntList();
auto expandedDims = util::toDims(expanded_size);
LOG_DEBUG("(expand layer) Expand input from " << input_dims << " to " << expandedDims);
return add_expand(ctx, n, in, expandedDims);
}})
.pattern({"aten::expand_as(Tensor(a) self, Tensor other) -> (Tensor(a))",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
// TODO: Currently expand supports static shapes. Need to explore if the same code can be extended
// to dynamic expansion.
auto in = args[0].ITensor();
auto input_dims = in->getDimensions();
auto targetTensor = args[1].ITensor();
auto targetDims = targetTensor->getDimensions();
LOG_DEBUG("(expand_as layer) Expand input from " << input_dims << " to " << targetDims);
return add_expand(ctx, n, in, targetDims);
}})
.pattern({"aten::repeat(Tensor self, int[] repeats) -> (Tensor)",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
auto in = args[0].ITensor();
auto input_dims = in->getDimensions();
auto repeats = args[1].unwrapToIntList().vec();
TRTORCH_CHECK(
repeats.size() >= input_dims.nbDims,
"Number of repeat dimensions cannot be smaller than number of input dimensions");
auto num_expand_dims = repeats.size() - input_dims.nbDims;
if (num_expand_dims > 0) {
nvinfer1::Dims reshape_dims;
reshape_dims.nbDims = repeats.size();
for (int i = 0; i < num_expand_dims; i++) {
reshape_dims.d[i] = 1;
}
for (int i = 0; i < input_dims.nbDims; i++) {
reshape_dims.d[num_expand_dims + i] = input_dims.d[i];
}
// Add a reshape layer to expand dims
auto reshape_layer = ctx->net->addShuffle(*in);
reshape_layer->setReshapeDimensions(reshape_dims);
in = reshape_layer->getOutput(0);
LOG_DEBUG("Input reshaped to : " << in->getDimensions() << " from " << input_dims);
}

LOG_DEBUG("Repeats: " << repeats);

// Concat across all repeat axes.
// TODO: Implementation might not be performant. Explore other strategies to improve performance.
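// Example: repeating a [1, 3] input with repeats = [4, 2] first concatenates
// 2 copies along axis 1 (giving [1, 6]) and then 4 copies along axis 0 (giving [4, 6]).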
for (int i = repeats.size() - 1; i >= 0; --i) {
std::vector<nvinfer1::ITensor*> tensors_vec;
for (int j = 0; j < repeats[i]; j++) {
tensors_vec.push_back(in);
}
auto concat_layer = ctx->net->addConcatenation(tensors_vec.data(), tensors_vec.size());
concat_layer->setAxis(i);
in = concat_layer->getOutput(0);
}

auto out = ctx->AssociateValueAndTensor(n->outputs()[0], in);

LOG_DEBUG("Repeat layer output tensor shape: " << in->getDimensions());

return true;
}});

} // namespace
} // namespace impl
} // namespace converters
} // namespace conversion
} // namespace core
} // namespace trtorch
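
For reference, a minimal LibTorch sketch of the eager-mode semantics these converters are expected to mirror (shapes chosen to match the tests below; this snippet only illustrates the aten ops and is not converter code):

#include <torch/torch.h>
#include <iostream>

int main() {
  auto x = torch::randint(1, 10, {3, 1});

  // aten::expand: broadcast the singleton dimension -> shape [3, 4]
  auto expanded = x.expand({3, 4});

  // aten::expand_as: broadcast x to the shape of another tensor -> shape [2, 3, 4]
  auto other = torch::zeros({2, 3, 4});
  auto expanded_as = x.expand_as(other);

  // aten::repeat: tile 4x along dim 0 and 2x along dim 1 -> shape [12, 2]
  auto repeated = x.repeat({4, 2});

  std::cout << expanded.sizes() << " " << expanded_as.sizes() << " "
            << repeated.sizes() << std::endl;
  return 0;
}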
5 changes: 5 additions & 0 deletions tests/core/conversion/converters/BUILD
@@ -27,6 +27,10 @@ converter_test(
name = "test_element_wise"
)

converter_test(
name = "test_expand"
)

converter_test(
name = "test_linear"
)
@@ -78,6 +82,7 @@ test_suite(
":test_batch_norm",
":test_conv_deconv",
":test_element_wise",
":test_expand",
":test_linear",
":test_matrix_multiply",
":test_pooling",
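With test_expand registered above, the new tests can be run through Bazel with a target along the lines of the following (assuming the converter_test macro defines a test named after its name attribute):

bazel test //tests/core/conversion/converters:test_expand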
202 changes: 202 additions & 0 deletions tests/core/conversion/converters/test_expand.cpp
@@ -0,0 +1,202 @@
#include <torch/torch.h>
#include <string>
#include "core/compiler.h"
#include "gtest/gtest.h"
#include "tests/util/util.h"
#include "torch/csrc/jit/ir/irparser.h"

TEST(Converters, ATenExpandSameDimConvertsCorrectly) {
const auto graph = R"IR(
graph(%x.1 : Tensor):
%2 : int[] = prim::Constant[value=[3, 4]]()
%3 : bool = prim::Constant[value=0]()
%4 : Tensor = aten::expand(%x.1, %2, %3)
return (%4))IR";

auto g = std::make_shared<torch::jit::Graph>();

torch::jit::parseIR(graph, &*g);

auto in = at::randint(1, 10, {3, 1}, {at::kCUDA});

auto jit_in = at::clone(in);
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in});

auto trt_in = at::clone(in);
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in});

auto trt = trt_results[0].reshape(jit_results[0].sizes());

ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}

TEST(Converters, ATenExpandTileConvertsCorrectly) {
const auto graph = R"IR(
graph(%x.1 : Tensor):
%2 : int[] = prim::Constant[value=[2, 3, 1]]()
%3 : bool = prim::Constant[value=0]()
%4 : Tensor = aten::expand(%x.1, %2, %3)
return (%4))IR";

auto g = std::make_shared<torch::jit::Graph>();

torch::jit::parseIR(graph, &*g);

auto in = at::randint(1, 10, {3, 1}, {at::kCUDA});

auto jit_in = at::clone(in);
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in});

auto trt_in = at::clone(in);
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in});

auto trt = trt_results[0].reshape(jit_results[0].sizes());

ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}

TEST(Converters, ATenExpandTileLastConvertsCorrectly) {
const auto graph = R"IR(
graph(%x.1 : Tensor):
%2 : int[] = prim::Constant[value=[1, 3, 4]]()
%3 : bool = prim::Constant[value=0]()
%4 : Tensor = aten::expand(%x.1, %2, %3)
return (%4))IR";

auto g = std::make_shared<torch::jit::Graph>();

torch::jit::parseIR(graph, &*g);

auto in = at::randint(1, 10, {3, 1}, {at::kCUDA});

auto jit_in = at::clone(in);
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in});

auto trt_in = at::clone(in);
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in});

auto trt = trt_results[0].reshape(jit_results[0].sizes());

ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}

/* The expand_as layer takes two inputs, but only the dimensions of the second
input are actually used, so TRT prunes the second input away, which results in
an internal TRT failure. To avoid this unrelated issue, the test adds a dummy
operation that returns second_input + 2 as a second output, keeping the second
input alive.
*/
TEST(Converters, ATenExpandASConvertsCorrectly) {
const auto graph = R"IR(
graph(%x.1 : Tensor,
%y.1 : Tensor):
%3 : int = prim::Constant[value=1]()
%4 : int = prim::Constant[value=2]()
%5 : Tensor = aten::expand_as(%x.1, %y.1)
%6 : Tensor = aten::add(%y.1, %4, %3)
return (%5, %6))IR";

auto g = std::make_shared<torch::jit::Graph>();

torch::jit::parseIR(graph, &*g);

auto in = at::randint(1, 10, {3, 1}, {at::kCUDA});
auto target_in = at::randint(1, 10, {2, 3, 1}, {at::kCUDA});

auto jit_in = at::clone(in);
auto jit_target_in = at::clone(target_in);
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in, jit_target_in});

auto trt_in = at::clone(jit_in);
auto trt_target_in = at::clone(jit_target_in);
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in, trt_target_in});

auto trt = trt_results[0].reshape(jit_results[0].sizes());

ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}

TEST(Converters, ATenRepeatConvertsCorrectly) {
const auto graph = R"IR(
graph(%x.1 : Tensor):
%2 : int[] = prim::Constant[value=[4, 2]]()
%3 : Tensor = aten::repeat(%x.1, %2)
return (%3))IR";

auto g = std::make_shared<torch::jit::Graph>();

torch::jit::parseIR(graph, &*g);

auto in = at::randint(1, 10, {1, 3}, {at::kCUDA});

auto jit_in = at::clone(in);
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in});

auto trt_in = at::clone(jit_in);
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in});

auto trt = trt_results[0].reshape(jit_results[0].sizes());

ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}

TEST(Converters, ATenRepeat3dConvertsCorrectly) {
const auto graph = R"IR(
graph(%x.1 : Tensor):
%2 : int[] = prim::Constant[value=[2, 2, 2]]()
%3 : Tensor = aten::repeat(%x.1, %2)
return (%3))IR";

auto g = std::make_shared<torch::jit::Graph>();

torch::jit::parseIR(graph, &*g);

auto in = at::randint(1, 10, {2, 3, 2}, {at::kCUDA});

auto jit_in = at::clone(in);
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in});

auto trt_in = at::clone(jit_in);
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in});

auto trt = trt_results[0].reshape(jit_results[0].sizes());

ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}

TEST(Converters, ATenRepeatExtraDimsConvertsCorrectly) {
const auto graph = R"IR(
graph(%x.1 : Tensor):
%2 : int[] = prim::Constant[value=[1, 3, 2]]()
%3 : Tensor = aten::repeat(%x.1, %2)
return (%3))IR";

auto g = std::make_shared<torch::jit::Graph>();

torch::jit::parseIR(graph, &*g);

auto in = at::randint(1, 10, {1, 3}, {at::kCUDA});

auto jit_in = at::clone(in);
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in});

auto trt_in = at::clone(jit_in);
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in});

auto trt = trt_results[0].reshape(jit_results[0].sizes());

ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}