Skip to content

Commit ac4ac5e

Browse files
authored
Merge pull request #92 from NVIDIA/narendasan/conditionals
Adds support for the evaluation of conditionals
2 parents 1074277 + 9d1946e commit ac4ac5e

File tree

7 files changed

+588
-58
lines changed

7 files changed

+588
-58
lines changed

core/conversion/conversion.cpp

Lines changed: 64 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,8 @@ void AddParamsToCtxValueMap(ConversionCtx* ctx, GraphParams& params) {
190190
}
191191
}
192192

193+
void EvaluateLoopBlock(ConversionCtx* ctx, const torch::jit::Node* n);
194+
193195
void MapIValues(ConversionCtx* ctx, c10::ArrayRef<const torch::jit::Value*> in_list, c10::ArrayRef<const torch::jit::Value*> out_list, int64_t in_offset, int64_t out_offset) {
194196
std::vector<std::pair<const torch::jit::Value*, const torch::jit::Value*>> input_output_pairs;
195197
std::transform(in_list.begin() + in_offset, in_list.end(), out_list.begin() + out_offset,
@@ -199,11 +201,54 @@ void MapIValues(ConversionCtx* ctx, c10::ArrayRef<const torch::jit::Value*> in_l
199201
});
200202

201203
for (auto p : input_output_pairs) {
202-
auto input = ctx->evaluated_value_map[p.first];
203-
ctx->evaluated_value_map[p.second] = torch::jit::IValue(input);
204+
if (ctx->evaluated_value_map.find(p.first) != ctx->evaluated_value_map.end()) {
205+
auto input = ctx->evaluated_value_map[p.first];
206+
ctx->evaluated_value_map[p.second] = torch::jit::IValue(input);
207+
} else if (ctx->value_tensor_map.find(p.first) != ctx->value_tensor_map.end()) {
208+
auto input = ctx->value_tensor_map[p.first];
209+
ctx->value_tensor_map[p.second] = input;
210+
} else {
211+
TRTORCH_THROW_ERROR("Cannot find Value " << p.first->debugName() << " either evaluated values or tensor maps (MapIValues)");
212+
}
204213
}
205214
}
206215

216+
void EvaluateConditionalBlock(ConversionCtx* ctx, const torch::jit::Node* n, bool contained_in_loop = false) {
217+
bool output_type_includes_tensor = false;
218+
for (auto o : n->outputs()) {
219+
if (o->type()->isSubtypeOf(c10::TensorType::get())) {
220+
output_type_includes_tensor = true;
221+
}
222+
}
223+
TRTORCH_CHECK(!(contained_in_loop && output_type_includes_tensor), "TRTorch currently cannot compile conditionals within loops");
224+
225+
auto condition = ctx->evaluated_value_map[n->input(0)].toBool();
226+
LOG_DEBUG(ctx->logger, "(Conditional Evaluation) Evaluating block " << (int) condition);
227+
auto b = condition ? n->blocks()[0] : n->blocks()[1];
228+
229+
for (const auto bn : b->nodes()) {
230+
if (bn->kind() == torch::jit::prim::Loop) {
231+
EvaluateLoopBlock(ctx, bn);
232+
} else if (bn->kind() == torch::jit::prim::If) {
233+
EvaluateConditionalBlock(ctx, bn, contained_in_loop);
234+
} else if (evaluators::shouldEvalAtConversionTime(bn)) {
235+
auto eval = EvaluateNode(ctx, bn);
236+
if (!eval.value().isTensor()) {
237+
LOG_DEBUG(ctx->logger, "(Conditional Evaluation) Found the value to be: " << eval.value());
238+
} else {
239+
LOG_DEBUG(ctx->logger, "(Conditional Evaluation) Found the value to be a tensor (shape " << eval.value().toTensor().sizes() << ')');
240+
}
241+
ctx->AssociateValueAndIValue(bn->output(0), eval.value());
242+
} else if (converters::node_is_convertable(bn)) {
243+
AddLayer(ctx, bn);
244+
} else {
245+
TRTORCH_THROW_ERROR("TRTorch is unable to compile this conditional, a converter or evaluator is not available for node " << *bn);
246+
}
247+
}
248+
249+
MapIValues(ctx, b->outputs(), n->outputs(), 0, 0);
250+
}
251+
207252
// TODO: With functionalization pass we may be able to make this into a regular evaluator later
208253
void EvaluateLoopBlock(ConversionCtx* ctx, const torch::jit::Node* n) {
209254
auto max_trip_count = ctx->evaluated_value_map[n->input(0)];
@@ -213,16 +258,21 @@ void EvaluateLoopBlock(ConversionCtx* ctx, const torch::jit::Node* n) {
213258

214259
MapIValues(ctx, n->inputs(), n->outputs(), 2, 0);
215260

216-
LOG_DEBUG("(Loop Evaluation) Evaluating loop " << *n);
217-
LOG_DEBUG("(Loop Evaluation) Max Trip Count: " << max_trip_count.toInt());
218-
LOG_DEBUG("(Loop Evaluation) Start Condition: " << start_cond.toBool());
219-
LOG_DEBUG("(Loop Evaluation) Current Trip Count: " << trip_count.toInt());
261+
LOG_DEBUG(ctx->logger, "(Loop Evaluation) Evaluating loop " << *n);
262+
LOG_DEBUG(ctx->logger, "(Loop Evaluation) Max Trip Count: " << max_trip_count.toInt());
263+
LOG_DEBUG(ctx->logger, "(Loop Evaluation) Start Condition: " << start_cond.toBool());
264+
LOG_DEBUG(ctx->logger, "(Loop Evaluation) Current Trip Count: " << trip_count.toInt());
220265

221266
while (start_cond.toBool() && trip_count.toInt() < max_trip_count.toInt()) {
222267
MapIValues(ctx, n->outputs(), n->blocks()[0]->inputs(), 0, 1);
223268
for (auto bn : n->blocks()[0]->nodes()) {
224-
auto eval = EvaluateNode(ctx, bn);
225-
if (eval) {
269+
if (bn->kind() == torch::jit::prim::Loop) {
270+
EvaluateLoopBlock(ctx, n);
271+
} else if (bn->kind() == torch::jit::prim::If) {
272+
EvaluateConditionalBlock(ctx, bn, true);
273+
} else {
274+
TRTORCH_CHECK(evaluators::shouldEvalAtConversionTime(bn), "TRTorch currently can only compile loops that are evaluatable at conversion time but node " << *bn << " cannot be evaluated.");
275+
auto eval = EvaluateNode(ctx, bn);
226276
if (!eval.value().isTensor()) {
227277
LOG_DEBUG(ctx->logger, "(Loop Evaluation) Found the value to be: " << eval.value());
228278
} else {
@@ -236,8 +286,8 @@ void EvaluateLoopBlock(ConversionCtx* ctx, const torch::jit::Node* n) {
236286
start_cond = ctx->evaluated_value_map[n->blocks()[0]->outputs()[0]];
237287
auto new_trip_count = torch::jit::IValue(trip_count.toInt() + 1);
238288
trip_count.swap(new_trip_count);
239-
LOG_DEBUG("(Loop Evaluation) Condition: " << start_cond.toBool());
240-
LOG_DEBUG("(Loop Evaluation) Current Trip Count: " << trip_count.toInt());
289+
LOG_DEBUG(ctx->logger, "(Loop Evaluation) Condition: " << start_cond.toBool());
290+
LOG_DEBUG(ctx->logger, "(Loop Evaluation) Current Trip Count: " << trip_count.toInt());
241291
}
242292
}
243293

@@ -255,6 +305,8 @@ void ConvertBlockToNetDef(ConversionCtx* ctx, const torch::jit::Block* b, Conver
255305
bool blacklisted = isNodeConversionBlacklisted(n);
256306
if (n->kind() == torch::jit::prim::Loop) {
257307
EvaluateLoopBlock(ctx, n);
308+
} else if (n->kind() == torch::jit::prim::If) {
309+
EvaluateConditionalBlock(ctx, n);
258310
} else if (to_eval) {
259311
auto eval = EvaluateNode(ctx, n);
260312
if (eval) {
@@ -303,10 +355,10 @@ std::string ConvertBlockToEngine(const torch::jit::Block* b, ConversionInfo buil
303355
std::set<std::string> GetUnsupportedOpsInBlock(const torch::jit::Block* b ) {
304356
std::set<std::string> unsupported_ops;
305357
for (const auto n : b->nodes()) {
306-
if (n->kind() != torch::jit::prim::Loop && !OpSupported(n)) {
358+
if (n->kind() != torch::jit::prim::Loop && n->kind() != torch::jit::prim::If && !OpSupported(n)) {
307359
auto schema = n->maybeSchema();
308360
TRTORCH_CHECK(schema, "Unable to get schema for Node " << util::node_info(n) \
309-
<< " (conversion.VerifyCoverterSupportForBlock");
361+
<< " (conversion.VerifyCoverterSupportForBlock)");
310362
std::stringstream ss;
311363
ss << *schema;
312364
unsupported_ops.insert(ss.str());

core/conversion/converters/NodeConverterRegistry.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,10 @@ class NodeConverterRegistry {
4848
bool RegisterConverter(torch::jit::FunctionSchema* signature, OpConverter& converter) {
4949
LOG_DEBUG("Registering converter for " << canonical_schema_string(*signature));
5050
auto name = signature->operator_name();
51+
auto iter = converter_lut_.find(name);
52+
if (iter != converter_lut_.end()) {
53+
LOG_WARNING("Overriding already registered converter " << signature->name() << ", unexpected behavior may occur");
54+
}
5155
converter_lut_[name] = std::move(converter);
5256
return true;
5357
}

core/conversion/evaluators/BUILD

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@ cc_library(
1515
srcs = [
1616
"NodeEvaluatorRegistry.cpp",
1717
"prim.cpp",
18-
"aten.cpp"
18+
"aten.cpp",
19+
"eval_macros.h"
1920
],
2021
deps = [
2122
"//core/util:prelude",

core/conversion/evaluators/NodeEvaluatorRegistry.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@ class NodeEvaluatorRegistry {
3030
public:
3131
void RegisterEvaluator(torch::jit::NodeKind node_kind, EvalRegistration eval_reg) {
3232
LOG_DEBUG("Registering evaluator for " << node_kind.toQualString());
33+
auto iter = evaluator_lut_.find(node_kind);
34+
if (iter != evaluator_lut_.end()) {
35+
TRTORCH_THROW_ERROR("Attempting to override already registered evaluator " << node_kind.toQualString() << ", merge implementations instead");
36+
}
3337
evaluator_lut_[node_kind] = std::move(eval_reg);
3438
}
3539

0 commit comments

Comments
 (0)