Skip to content

Adds support for the evaluation of conditionals #92

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jun 15, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 64 additions & 12 deletions core/conversion/conversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,8 @@ void AddParamsToCtxValueMap(ConversionCtx* ctx, GraphParams& params) {
}
}

void EvaluateLoopBlock(ConversionCtx* ctx, const torch::jit::Node* n);

void MapIValues(ConversionCtx* ctx, c10::ArrayRef<const torch::jit::Value*> in_list, c10::ArrayRef<const torch::jit::Value*> out_list, int64_t in_offset, int64_t out_offset) {
std::vector<std::pair<const torch::jit::Value*, const torch::jit::Value*>> input_output_pairs;
std::transform(in_list.begin() + in_offset, in_list.end(), out_list.begin() + out_offset,
Expand All @@ -199,11 +201,54 @@ void MapIValues(ConversionCtx* ctx, c10::ArrayRef<const torch::jit::Value*> in_l
});

for (auto p : input_output_pairs) {
auto input = ctx->evaluated_value_map[p.first];
ctx->evaluated_value_map[p.second] = torch::jit::IValue(input);
if (ctx->evaluated_value_map.find(p.first) != ctx->evaluated_value_map.end()) {
auto input = ctx->evaluated_value_map[p.first];
ctx->evaluated_value_map[p.second] = torch::jit::IValue(input);
} else if (ctx->value_tensor_map.find(p.first) != ctx->value_tensor_map.end()) {
auto input = ctx->value_tensor_map[p.first];
ctx->value_tensor_map[p.second] = input;
} else {
TRTORCH_THROW_ERROR("Cannot find Value " << p.first->debugName() << " either evaluated values or tensor maps (MapIValues)");
}
}
}

void EvaluateConditionalBlock(ConversionCtx* ctx, const torch::jit::Node* n, bool contained_in_loop = false) {
bool output_type_includes_tensor = false;
for (auto o : n->outputs()) {
if (o->type()->isSubtypeOf(c10::TensorType::get())) {
output_type_includes_tensor = true;
}
}
TRTORCH_CHECK(!(contained_in_loop && output_type_includes_tensor), "TRTorch currently cannot compile conditionals within loops");

auto condition = ctx->evaluated_value_map[n->input(0)].toBool();
LOG_DEBUG(ctx->logger, "(Conditional Evaluation) Evaluating block " << (int) condition);
auto b = condition ? n->blocks()[0] : n->blocks()[1];

for (const auto bn : b->nodes()) {
if (bn->kind() == torch::jit::prim::Loop) {
EvaluateLoopBlock(ctx, bn);
} else if (bn->kind() == torch::jit::prim::If) {
EvaluateConditionalBlock(ctx, bn, contained_in_loop);
} else if (evaluators::shouldEvalAtConversionTime(bn)) {
auto eval = EvaluateNode(ctx, bn);
if (!eval.value().isTensor()) {
LOG_DEBUG(ctx->logger, "(Conditional Evaluation) Found the value to be: " << eval.value());
} else {
LOG_DEBUG(ctx->logger, "(Conditional Evaluation) Found the value to be a tensor (shape " << eval.value().toTensor().sizes() << ')');
}
ctx->AssociateValueAndIValue(bn->output(0), eval.value());
} else if (converters::node_is_convertable(bn)) {
AddLayer(ctx, bn);
} else {
TRTORCH_THROW_ERROR("TRTorch is unable to compile this conditional, a converter or evaluator is not available for node " << *bn);
}
}

MapIValues(ctx, b->outputs(), n->outputs(), 0, 0);
}

// TODO: With functionalization pass we may be able to make this into a regular evaluator later
void EvaluateLoopBlock(ConversionCtx* ctx, const torch::jit::Node* n) {
auto max_trip_count = ctx->evaluated_value_map[n->input(0)];
Expand All @@ -213,16 +258,21 @@ void EvaluateLoopBlock(ConversionCtx* ctx, const torch::jit::Node* n) {

MapIValues(ctx, n->inputs(), n->outputs(), 2, 0);

LOG_DEBUG("(Loop Evaluation) Evaluating loop " << *n);
LOG_DEBUG("(Loop Evaluation) Max Trip Count: " << max_trip_count.toInt());
LOG_DEBUG("(Loop Evaluation) Start Condition: " << start_cond.toBool());
LOG_DEBUG("(Loop Evaluation) Current Trip Count: " << trip_count.toInt());
LOG_DEBUG(ctx->logger, "(Loop Evaluation) Evaluating loop " << *n);
LOG_DEBUG(ctx->logger, "(Loop Evaluation) Max Trip Count: " << max_trip_count.toInt());
LOG_DEBUG(ctx->logger, "(Loop Evaluation) Start Condition: " << start_cond.toBool());
LOG_DEBUG(ctx->logger, "(Loop Evaluation) Current Trip Count: " << trip_count.toInt());

while (start_cond.toBool() && trip_count.toInt() < max_trip_count.toInt()) {
MapIValues(ctx, n->outputs(), n->blocks()[0]->inputs(), 0, 1);
for (auto bn : n->blocks()[0]->nodes()) {
auto eval = EvaluateNode(ctx, bn);
if (eval) {
if (bn->kind() == torch::jit::prim::Loop) {
EvaluateLoopBlock(ctx, n);
} else if (bn->kind() == torch::jit::prim::If) {
EvaluateConditionalBlock(ctx, bn, true);
} else {
TRTORCH_CHECK(evaluators::shouldEvalAtConversionTime(bn), "TRTorch currently can only compile loops that are evaluatable at conversion time but node " << *bn << " cannot be evaluated.");
auto eval = EvaluateNode(ctx, bn);
if (!eval.value().isTensor()) {
LOG_DEBUG(ctx->logger, "(Loop Evaluation) Found the value to be: " << eval.value());
} else {
Expand All @@ -236,8 +286,8 @@ void EvaluateLoopBlock(ConversionCtx* ctx, const torch::jit::Node* n) {
start_cond = ctx->evaluated_value_map[n->blocks()[0]->outputs()[0]];
auto new_trip_count = torch::jit::IValue(trip_count.toInt() + 1);
trip_count.swap(new_trip_count);
LOG_DEBUG("(Loop Evaluation) Condition: " << start_cond.toBool());
LOG_DEBUG("(Loop Evaluation) Current Trip Count: " << trip_count.toInt());
LOG_DEBUG(ctx->logger, "(Loop Evaluation) Condition: " << start_cond.toBool());
LOG_DEBUG(ctx->logger, "(Loop Evaluation) Current Trip Count: " << trip_count.toInt());
}
}

Expand All @@ -255,6 +305,8 @@ void ConvertBlockToNetDef(ConversionCtx* ctx, const torch::jit::Block* b, Conver
bool blacklisted = isNodeConversionBlacklisted(n);
if (n->kind() == torch::jit::prim::Loop) {
EvaluateLoopBlock(ctx, n);
} else if (n->kind() == torch::jit::prim::If) {
EvaluateConditionalBlock(ctx, n);
} else if (to_eval) {
auto eval = EvaluateNode(ctx, n);
if (eval) {
Expand Down Expand Up @@ -303,10 +355,10 @@ std::string ConvertBlockToEngine(const torch::jit::Block* b, ConversionInfo buil
std::set<std::string> GetUnsupportedOpsInBlock(const torch::jit::Block* b ) {
std::set<std::string> unsupported_ops;
for (const auto n : b->nodes()) {
if (n->kind() != torch::jit::prim::Loop && !OpSupported(n)) {
if (n->kind() != torch::jit::prim::Loop && n->kind() != torch::jit::prim::If && !OpSupported(n)) {
auto schema = n->maybeSchema();
TRTORCH_CHECK(schema, "Unable to get schema for Node " << util::node_info(n) \
<< " (conversion.VerifyCoverterSupportForBlock");
<< " (conversion.VerifyCoverterSupportForBlock)");
std::stringstream ss;
ss << *schema;
unsupported_ops.insert(ss.str());
Expand Down
4 changes: 4 additions & 0 deletions core/conversion/converters/NodeConverterRegistry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ class NodeConverterRegistry {
bool RegisterConverter(torch::jit::FunctionSchema* signature, OpConverter& converter) {
LOG_DEBUG("Registering converter for " << canonical_schema_string(*signature));
auto name = signature->operator_name();
auto iter = converter_lut_.find(name);
if (iter != converter_lut_.end()) {
LOG_WARNING("Overriding already registered converter " << signature->name() << ", unexpected behavior may occur");
}
converter_lut_[name] = std::move(converter);
return true;
}
Expand Down
3 changes: 2 additions & 1 deletion core/conversion/evaluators/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ cc_library(
srcs = [
"NodeEvaluatorRegistry.cpp",
"prim.cpp",
"aten.cpp"
"aten.cpp",
"eval_macros.h"
],
deps = [
"//core/util:prelude",
Expand Down
4 changes: 4 additions & 0 deletions core/conversion/evaluators/NodeEvaluatorRegistry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ class NodeEvaluatorRegistry {
public:
void RegisterEvaluator(torch::jit::NodeKind node_kind, EvalRegistration eval_reg) {
LOG_DEBUG("Registering evaluator for " << node_kind.toQualString());
auto iter = evaluator_lut_.find(node_kind);
if (iter != evaluator_lut_.end()) {
TRTORCH_THROW_ERROR("Attempting to override already registered evaluator " << node_kind.toQualString() << ", merge implementations instead");
}
evaluator_lut_[node_kind] = std::move(eval_reg);
}

Expand Down
Loading