Skip to content

Commit 9a32619

Browse files
authored
Merge pull request #283 from NVIDIA/split_unpack
feat(//core/): Add support for Split converter and unpack evaluator
2 parents 2e570ae + 74f4a26 commit 9a32619

File tree

6 files changed

+245
-23
lines changed

6 files changed

+245
-23
lines changed

core/conversion/conversion.cpp

Lines changed: 48 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,9 @@
77
#include "core/conversion/var/Var.h"
88
#include "core/util/prelude.h"
99

10+
#include "c10/util/intrusive_ptr.h"
11+
#include "core/conversion/tensorcontainer/TensorContainer.h"
12+
1013
namespace trtorch {
1114
namespace core {
1215
namespace conversion {
@@ -173,18 +176,32 @@ void AddInputs(ConversionCtx* ctx, at::ArrayRef<const torch::jit::Value*> inputs
173176

174177
// Marks each TorchScript graph output as an output of the TensorRT network.
// An output value can come from two places:
//   1. ctx->value_tensor_map   - values converters produced directly as ITensors
//   2. ctx->evaluated_value_map - values produced by evaluators; the only
//      supported evaluated form here is a TensorContainer custom class
//      wrapping an ITensor (e.g. an element unpacked from a split output)
void MarkOutputs(ConversionCtx* ctx, at::ArrayRef<const torch::jit::Value*> outputs) {
  // Shared tail: name, mark, and log a single TRT output tensor.
  auto mark_tensor = [&](const torch::jit::Value* out, nvinfer1::ITensor* out_tensor) {
    std::string name = std::string("output_") + std::to_string(ctx->num_outputs);
    out_tensor->setName(name.c_str());
    ctx->net->markOutput(*out_tensor);
    LOG_INFO(
        ctx->logger, "Marking Output " << out->debugName() << " named " << name << " in engine (ctx.MarkOutput)");
    ctx->num_outputs += 1;
  };

  for (auto out : outputs) {
    auto it = ctx->value_tensor_map.find(out);
    if (it != ctx->value_tensor_map.end()) {
      mark_tensor(out, it->second);
    } else {
      auto eval_it = ctx->evaluated_value_map.find(out);
      // An output present in neither map was previously skipped silently,
      // which masks conversion failures; fail loudly instead (this restores
      // the check the tensor-only version of this function performed).
      TRTORCH_CHECK(
          eval_it != ctx->evaluated_value_map.end(),
          "No corresponding output TRT Tensor or evaluated value found for TorchScript output: " << out->debugName());
      auto out_ivalue = eval_it->second;
      if (out_ivalue.isCustomClass()) {
        auto output_container = out_ivalue.toCustomClass<TensorContainer>();
        mark_tensor(out, output_container.get()->tensor());
      } else {
        TRTORCH_THROW_ERROR("Unknown output type. Only a single tensor or a TensorList type is supported.");
      }
    }
  }
}
190207

@@ -337,12 +354,30 @@ void ConvertBlockToNetDef(
337354
} else if (to_eval) {
338355
auto eval = EvaluateNode(ctx, n);
339356
if (eval) {
340-
if (!eval.value().isTensor()) {
357+
if (n->outputs().size() > 1) { // For ListUnpack scenario
358+
if (eval.value().isTuple()) {
359+
auto eval_list = eval.value().toTuple();
360+
TRTORCH_CHECK(
361+
eval_list->elements().size() == n->outputs().size(),
362+
"Size of evaluated results: " << eval_list->elements().size()
363+
<< " and node outputs size: " << n->outputs().size() << " must match.");
364+
for (int i = 0; i < eval_list->elements().size(); i++) {
365+
auto eval_output = eval_list.get()->elements()[i];
366+
LOG_DEBUG(
367+
ctx->logger,
368+
"Found the evaluated value(s) to be " << eval_output << " for node: " << util::node_info(n));
369+
ctx->AssociateValueAndIValue(n->output(i), eval_output);
370+
}
371+
} else {
372+
TRTORCH_THROW_ERROR("Unsupported return type for evaluated node");
373+
}
374+
} else if (!eval.value().isTensor()) {
341375
LOG_DEBUG(ctx->logger, "Found the value to be: " << eval.value());
376+
ctx->AssociateValueAndIValue(n->output(0), eval.value());
342377
} else {
343378
LOG_DEBUG(ctx->logger, "Found the value to be a tensor (shape " << eval.value().toTensor().sizes() << ')');
379+
ctx->AssociateValueAndIValue(n->output(0), eval.value());
344380
}
345-
ctx->AssociateValueAndIValue(n->output(0), eval.value());
346381
}
347382
} else if (!ignored) {
348383
// Should error out if something fails

core/conversion/converters/impl/select.cpp

Lines changed: 71 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,66 @@
1+
#include <ATen/ATen.h>
2+
#include <vector>
13
#include "NvInfer.h"
4+
#include "c10/util/intrusive_ptr.h"
25
#include "core/conversion/converters/converters.h"
6+
#include "core/conversion/tensorcontainer/TensorContainer.h"
37
#include "core/util/prelude.h"
48
#include "torch/torch.h"
59

6-
#include <ATen/ATen.h>
7-
#include <vector>
8-
910
namespace trtorch {
1011
namespace core {
1112
namespace conversion {
1213
namespace converters {
1314
namespace impl {
1415
namespace {
1516

17+
// Converts aten::split / aten::split_with_sizes / aten::split.Tensor into a
// sequence of TensorRT gather layers, one per output chunk. The resulting
// ITensors are wrapped in TensorContainer custom-class IValues, collected into
// a GenericList, and associated with the node's Tensor[] output so a
// downstream prim::ListUnpack evaluator can consume them.
//
// Args:
//   ctx        - active conversion context owning the TRT network
//   n          - the aten::split* node being converted
//   args       - converter args: [0] input tensor, [1] chunk sizes or fixed
//                chunk size, [2] dimension to split along
//   split_list - true when args[1] is an int[] of per-chunk sizes, false when
//                it is a single fixed chunk size (aten::split.Tensor)
//
// Returns true on success (converter convention).
bool add_split(ConversionCtx* ctx, const torch::jit::Node* n, args& args, bool split_list) {
  auto in = args[0].ITensor();
  auto axis = args[2].unwrapToInt();
  auto inDimSize = in->getDimensions().d[axis];
  int64_t numOutputs = 1;
  std::vector<int64_t> sizes;

  if (split_list) {
    // Explicit per-chunk sizes were provided.
    sizes = args[1].unwrapToIntList().vec();
    numOutputs = sizes.size();
  } else {
    // Fixed chunk size: every chunk is split_size wide except a possibly
    // smaller final chunk when split_size does not evenly divide the
    // dimension (matches aten::split semantics). The previous code filled
    // sizes with 1s whenever numOutputs > 1 and dropped any remainder.
    auto split_size = args[1].unwrapToInt();
    numOutputs = (inDimSize + split_size - 1) / split_size;
    sizes = std::vector<int64_t>(numOutputs, split_size);
    sizes.back() = inDimSize - (numOutputs - 1) * split_size;
  }

  LOG_DEBUG("Number of split outputs: " << numOutputs);

  // Build a list typed like the node's Tensor[] output.
  c10::ListTypePtr lt = n->output()->type()->expect<c10::ListType>();
  c10::TypePtr elementType = lt->getElementType();
  auto list = c10::impl::GenericList(elementType);
  list.reserve(numOutputs);

  int64_t start_idx = 0;
  for (int64_t i = 0; i < numOutputs; i++) {
    // Gather the index range [start_idx, start_idx + sizes[i]) along `axis`.
    at::Tensor indices = torch::arange(start_idx, start_idx + sizes[i], 1).to(torch::kI32);
    auto indicesTensor = tensor_to_const(ctx, indices);

    auto gather_layer = ctx->net->addGather(*in, *indicesTensor, axis);
    TRTORCH_CHECK(gather_layer, "Unable to create gather layer from node: " << *n);
    auto gather_out = gather_layer->getOutput(0);

    // ITensor* is not an IValue; wrap it in a TensorContainer custom class
    // so it can travel through the evaluated-value map.
    auto tensor_holder = TensorContainer();
    tensor_holder.hold_tensor(gather_out);
    auto ival = c10::IValue(std::move(c10::make_intrusive<TensorContainer>(tensor_holder)));
    list.emplace_back(ival);

    start_idx += sizes[i];
  }

  auto split_output_ivalue = std::move(torch::jit::IValue(list));
  ctx->AssociateValueAndIValue(n->outputs()[0], split_output_ivalue);

  // The previous version fell off the end of this bool function, which is
  // undefined behavior; return success explicitly.
  return true;
}
63+
1664
auto select_registrations TRTORCH_UNUSED =
1765
RegisterNodeConversionPatterns()
1866
.pattern({"aten::select.int(Tensor(a) self, int dim, int index) -> (Tensor(a))",
@@ -172,11 +220,29 @@ auto select_registrations TRTORCH_UNUSED =
172220
LOG_DEBUG("Slice layer output shape: " << out->getDimensions());
173221

174222
return true;
175-
}});
223+
}})
224+
.pattern({"aten::split(Tensor self, int[] split_sizes, int dim=0) -> (Tensor[])",
225+
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
226+
add_split(ctx, n, args, true);
227+
LOG_DEBUG("Converted split op into a list of IValues");
228+
return true;
229+
}})
230+
.pattern({"aten::split.Tensor(Tensor(a) self, int split_size, int dim=0) -> (Tensor[])",
231+
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
232+
add_split(ctx, n, args, false);
233+
LOG_DEBUG("Converted split op into a list of IValues");
234+
return true;
235+
}})
236+
.pattern({"aten::split_with_sizes(Tensor(a) self, int[] split_sizes, int dim=0) -> (Tensor[])",
237+
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
238+
add_split(ctx, n, args, true);
239+
LOG_DEBUG("Converted split op into a list of IValues");
240+
return true;
241+
}});
176242

177243
} // namespace
178244
} // namespace impl
179245
} // namespace converters
180246
} // namespace conversion
181247
} // namespace core
182-
} // namespace trtorch
248+
} // namespace trtorch

core/conversion/evaluators/prim.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,13 @@ auto prim_registrations =
3232
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
3333
return at::scalar_to_tensor(args.at(n->output(0)).IValue()->toScalar());
3434
}})
35+
.evaluator({torch::jit::prim::ListUnpack,
36+
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
37+
// Outputs is an IValue which has list of tensors which can be found in ctx->evaluated_value_map
38+
const torch::jit::IValue* outputs = args.at(n->input()).IValue();
39+
auto outputVec = outputs->toList().vec();
40+
return std::move(c10::ivalue::Tuple::create(outputVec));
41+
}})
3542
.evaluator({torch::jit::prim::ListConstruct,
3643
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
3744
const auto num_inputs = n->inputs().size();

tests/core/conversion/converters/test_select.cpp

Lines changed: 84 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,4 +204,87 @@ TEST(Converters, ATenSliceNegEndIndexConvertsCorrectly) {
204204
auto trt = trt_results[0].reshape(jit_results[0].sizes());
205205

206206
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt, 2e-6));
207-
}
207+
}
208+
209+
// aten::split with an explicit size list [1, 2] along dim 1, followed by
// prim::ListUnpack — the pattern TorchScript scripting emits. Each unpacked
// chunk from the TRT engine must match the JIT reference.
TEST(Converters, ATenSplitSizesInScriptingConvertsCorrectly) {
  const auto graph = R"IR(
    graph(%x.1 : Tensor):
      %2 : int[] = prim::Constant[value=[1, 2]]()
      %3 : int = prim::Constant[value=1]()
      %4 : Tensor[] = aten::split(%x.1, %2, %3)
      %x1.1 : Tensor, %x2.1 : Tensor = prim::ListUnpack(%4)
      return (%x1.1, %x2.1))IR";

  auto g = std::make_shared<torch::jit::Graph>();

  torch::jit::parseIR(graph, &*g);

  auto in = at::randint(1, 10, {1, 3, 4, 4}, {at::kCUDA});

  auto jit_in = at::clone(in);
  auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
  auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in});

  auto trt_in = at::clone(in);
  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in});

  // size_t avoids the signed/unsigned comparison against size().
  for (size_t i = 0; i < jit_results.size(); i++) {
    auto trt = trt_results[i].reshape(jit_results[i].sizes());
    ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[i], trt, 2e-6));
  }
}
236+
237+
// aten::split_with_sizes with size list [1, 2] along dim 1, followed by
// prim::ListUnpack — the pattern TorchScript tracing emits. Each unpacked
// chunk from the TRT engine must match the JIT reference.
TEST(Converters, ATenSplitSizesinTracingConvertsCorrectly) {
  const auto graph = R"IR(
    graph(%argument_1.1 : Tensor):
      %2 : int[] = prim::Constant[value=[1, 2]]()
      %3 : int = prim::Constant[value=1]()
      %4 : Tensor[] = aten::split_with_sizes(%argument_1.1, %2, %3)
      %5 : Tensor, %6 : Tensor = prim::ListUnpack(%4)
      return (%5, %6))IR";

  auto g = std::make_shared<torch::jit::Graph>();

  torch::jit::parseIR(graph, &*g);

  auto in = at::randint(1, 10, {1, 3, 4, 4}, {at::kCUDA});

  auto jit_in = at::clone(in);
  auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
  auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in});

  auto trt_in = at::clone(in);
  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in});

  // size_t avoids the signed/unsigned comparison against size().
  for (size_t i = 0; i < jit_results.size(); i++) {
    auto trt = trt_results[i].reshape(jit_results[i].sizes());
    ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[i], trt, 2e-6));
  }
}
264+
265+
// aten::split with a fixed chunk size of 1 along dim 1 (3 chunks), followed
// by prim::ListUnpack. Each unpacked chunk from the TRT engine must match
// the JIT reference.
TEST(Converters, ATenSplitFixedConvertsCorrectly) {
  const auto graph = R"IR(
    graph(%argument_1.1 : Tensor):
      %2 : int = prim::Constant[value=1]()
      %3 : Tensor[] = aten::split(%argument_1.1, %2, %2)
      %4 : Tensor, %5 : Tensor, %6 : Tensor = prim::ListUnpack(%3)
      return (%4, %5, %6))IR";

  auto g = std::make_shared<torch::jit::Graph>();

  torch::jit::parseIR(graph, &*g);

  auto in = at::randint(1, 10, {1, 3, 4, 4}, {at::kCUDA});

  auto jit_in = at::clone(in);
  auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
  auto jit_results = trtorch::tests::util::RunGraph(g, params, {jit_in});

  auto trt_in = at::clone(in);
  auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {trt_in});

  // size_t avoids the signed/unsigned comparison against size().
  for (size_t i = 0; i < jit_results.size(); i++) {
    auto trt = trt_results[i].reshape(jit_results[i].sizes());
    ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[i], trt, 2e-6));
  }
}

tests/core/conversion/evaluators/test_prim_evaluators.cpp

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,23 @@ TEST(Evaluators, PrimConstantEvaluatesCorrectly) {
1717
auto trt_results = trtorch::tests::util::EvaluateGraph(g->block(), {});
1818

1919
ASSERT_TRUE(jit_results[0] == trt_results[0]);
20-
}
20+
}
21+
22+
// prim::ListUnpack over a prim::ListConstruct of two int constants: the JIT
// interpreter and the TRTorch evaluator must agree on both unpacked values.
TEST(Evaluators, PrimListUnpackEvaluatesCorrectly) {
  const auto graph = R"IR(
    graph():
      %1 : int = prim::Constant[value=3]()
      %2 : int = prim::Constant[value=4]()
      %lc : int[] = prim::ListConstruct(%1, %2)
      %lu.1 : int, %lu.2 : int = prim::ListUnpack(%lc)
      return (%lu.1, %lu.2))IR";

  auto parsed_g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, parsed_g.get());

  auto jit_results = trtorch::tests::util::EvaluateGraphJIT(parsed_g, {});
  auto trt_results = trtorch::tests::util::EvaluateGraph(parsed_g->block(), {});

  // Compare both unpacked outputs against the JIT reference.
  ASSERT_TRUE(jit_results[0] == trt_results[0]);
  ASSERT_TRUE(jit_results[1] == trt_results[1]);
}

tests/util/evaluate_graph.cpp

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include "core/conversion/converters/converters.h"
66
#include "core/conversion/evaluators/evaluators.h"
77
#include "core/conversion/var/Var.h"
8+
#include "core/util/jit_util.h"
89
#include "core/util/prelude.h"
910

1011
namespace trtorch {
@@ -20,20 +21,31 @@ std::vector<torch::jit::IValue> EvaluateGraph(const torch::jit::Block* b, std::v
2021
for (size_t i = 0; i < inputs.size(); i++) {
2122
ctx->AssociateValueAndIValue(b->inputs()[i], inputs[i]);
2223
}
23-
24+
LOG_DEBUG("Checking nodes");
2425
for (const auto n : b->nodes()) {
2526
TRTORCH_CHECK(
2627
core::conversion::evaluators::shouldEvalAtConversionTime(n),
2728
"Test graph contains non evaluatable nodes: " << *n);
2829
auto eval = core::conversion::EvaluateNode(ctx, n);
2930
if (eval) {
30-
if (!eval.value().isTensor()) {
31+
if (eval.value().isTuple()) {
32+
auto eval_list = eval.value().toTuple();
33+
for (int i = 0; i < eval_list->elements().size(); i++) {
34+
auto eval_output = eval_list.get()->elements()[i];
35+
LOG_DEBUG(
36+
ctx->logger,
37+
"Found the evaluated value(s) to be " << eval_output
38+
<< " for node: " << trtorch::core::util::node_info(n));
39+
ctx->AssociateValueAndIValue(n->output(i), eval_output);
40+
}
41+
} else if (!eval.value().isTensor()) {
3142
LOG_DEBUG("Found the value to be: " << eval.value());
43+
ctx->AssociateValueAndIValue(n->output(0), eval.value());
3244
} else {
3345
LOG_DEBUG("Found the value to be a tensor (shape " << eval.value().toTensor().sizes() << ')');
46+
ctx->AssociateValueAndIValue(n->output(0), eval.value());
3447
}
3548
}
36-
ctx->AssociateValueAndIValue(n->output(0), eval.value());
3749
}
3850

3951
std::vector<torch::jit::IValue> outputs;

0 commit comments

Comments
 (0)