Skip to content

Narendasan/int8 mixed precision fix #94

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 44 additions & 10 deletions core/conversion/conversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,8 @@ void AddParamsToCtxValueMap(ConversionCtx* ctx, GraphParams& params) {
}
}

void EvaluateLoopBlock(ConversionCtx* ctx, const torch::jit::Node* n);

void MapIValues(ConversionCtx* ctx, c10::ArrayRef<const torch::jit::Value*> in_list, c10::ArrayRef<const torch::jit::Value*> out_list, int64_t in_offset, int64_t out_offset) {
std::vector<std::pair<const torch::jit::Value*, const torch::jit::Value*>> input_output_pairs;
std::transform(in_list.begin() + in_offset, in_list.end(), out_list.begin() + out_offset,
Expand All @@ -204,6 +206,31 @@ void MapIValues(ConversionCtx* ctx, c10::ArrayRef<const torch::jit::Value*> in_l
}
}

// Evaluates a prim::If node at conversion time.
//
// The condition is read from ctx->evaluated_value_map (keyed by the node's
// first input), the matching branch block is selected, and every node inside
// that block is evaluated in order. Nested prim::Loop / prim::If nodes are
// handled recursively; every other node must be evaluatable at conversion
// time. Finally the outputs of the selected block are mapped onto the outputs
// of the If node via MapIValues.
//
// NOTE(review): the condition lookup uses operator[] on evaluated_value_map,
// so this presumably relies on the condition having been evaluated earlier -
// confirm against callers.
//
// ctx: conversion context holding the logger and the evaluated value map.
// n:   the prim::If node to evaluate.
void EvaluateConditionalBlock(ConversionCtx* ctx, const torch::jit::Node* n) {
  auto condition = ctx->evaluated_value_map[n->input(0)].toBool();
  LOG_DEBUG(ctx->logger, "(Conditional Evaluation) Evaluating block " << static_cast<int>(condition));
  // blocks()[0] is used when the condition is true, blocks()[1] otherwise.
  auto b = condition ? n->blocks()[0] : n->blocks()[1];

  for (const auto* bn : b->nodes()) {
    if (bn->kind() == torch::jit::prim::Loop) {
      EvaluateLoopBlock(ctx, bn);
    } else if (bn->kind() == torch::jit::prim::If) {
      EvaluateConditionalBlock(ctx, bn);
    } else {
      TRTORCH_CHECK(evaluators::shouldEvalAtConversionTime(bn), "TRTorch currently can only compile conditionals that are evaluatable at conversion time but node " << *bn << " cannot be evaluated.");
      auto eval = EvaluateNode(ctx, bn);
      // Fail with a descriptive message instead of dereferencing an empty optional.
      TRTORCH_CHECK(eval, "(Conditional Evaluation) Evaluation of node " << *bn << " did not produce a value");
      const auto& value = eval.value();
      if (!value.isTensor()) {
        LOG_DEBUG(ctx->logger, "(Conditional Evaluation) Found the value to be: " << value);
      } else {
        LOG_DEBUG(ctx->logger, "(Conditional Evaluation) Found the value to be a tensor (shape " << value.toTensor().sizes() << ')');
      }
      ctx->AssociateValueAndIValue(bn->output(0), value);
    }
  }

  // Expose the selected branch's outputs as the outputs of the If node.
  MapIValues(ctx, b->outputs(), n->outputs(), 0, 0);
}

// TODO: With functionalization pass we may be able to make this into a regular evaluator later
void EvaluateLoopBlock(ConversionCtx* ctx, const torch::jit::Node* n) {
auto max_trip_count = ctx->evaluated_value_map[n->input(0)];
Expand All @@ -213,16 +240,21 @@ void EvaluateLoopBlock(ConversionCtx* ctx, const torch::jit::Node* n) {

MapIValues(ctx, n->inputs(), n->outputs(), 2, 0);

LOG_DEBUG("(Loop Evaluation) Evaluating loop " << *n);
LOG_DEBUG("(Loop Evaluation) Max Trip Count: " << max_trip_count.toInt());
LOG_DEBUG("(Loop Evaluation) Start Condition: " << start_cond.toBool());
LOG_DEBUG("(Loop Evaluation) Current Trip Count: " << trip_count.toInt());
LOG_DEBUG(ctx->logger, "(Loop Evaluation) Evaluating loop " << *n);
LOG_DEBUG(ctx->logger, "(Loop Evaluation) Max Trip Count: " << max_trip_count.toInt());
LOG_DEBUG(ctx->logger, "(Loop Evaluation) Start Condition: " << start_cond.toBool());
LOG_DEBUG(ctx->logger, "(Loop Evaluation) Current Trip Count: " << trip_count.toInt());

while (start_cond.toBool() && trip_count.toInt() < max_trip_count.toInt()) {
MapIValues(ctx, n->outputs(), n->blocks()[0]->inputs(), 0, 1);
for (auto bn : n->blocks()[0]->nodes()) {
auto eval = EvaluateNode(ctx, bn);
if (eval) {
if (bn->kind() == torch::jit::prim::Loop) {
EvaluateLoopBlock(ctx, n);
} else if (bn->kind() == torch::jit::prim::If) {
EvaluateConditionalBlock(ctx, bn);
} else {
TRTORCH_CHECK(evaluators::shouldEvalAtConversionTime(bn), "TRTorch currently can only compile loops that are evaluatable at conversion time but node " << *bn << " cannot be evaluated.");
auto eval = EvaluateNode(ctx, bn);
if (!eval.value().isTensor()) {
LOG_DEBUG(ctx->logger, "(Loop Evaluation) Found the value to be: " << eval.value());
} else {
Expand All @@ -236,8 +268,8 @@ void EvaluateLoopBlock(ConversionCtx* ctx, const torch::jit::Node* n) {
start_cond = ctx->evaluated_value_map[n->blocks()[0]->outputs()[0]];
auto new_trip_count = torch::jit::IValue(trip_count.toInt() + 1);
trip_count.swap(new_trip_count);
LOG_DEBUG("(Loop Evaluation) Condition: " << start_cond.toBool());
LOG_DEBUG("(Loop Evaluation) Current Trip Count: " << trip_count.toInt());
LOG_DEBUG(ctx->logger, "(Loop Evaluation) Condition: " << start_cond.toBool());
LOG_DEBUG(ctx->logger, "(Loop Evaluation) Current Trip Count: " << trip_count.toInt());
}
}

Expand All @@ -255,6 +287,8 @@ void ConvertBlockToNetDef(ConversionCtx* ctx, const torch::jit::Block* b, Conver
bool blacklisted = isNodeConversionBlacklisted(n);
if (n->kind() == torch::jit::prim::Loop) {
EvaluateLoopBlock(ctx, n);
} else if (n->kind() == torch::jit::prim::If) {
EvaluateConditionalBlock(ctx, n);
} else if (to_eval) {
auto eval = EvaluateNode(ctx, n);
if (eval) {
Expand Down Expand Up @@ -303,10 +337,10 @@ std::string ConvertBlockToEngine(const torch::jit::Block* b, ConversionInfo buil
std::set<std::string> GetUnsupportedOpsInBlock(const torch::jit::Block* b ) {
std::set<std::string> unsupported_ops;
for (const auto n : b->nodes()) {
if (n->kind() != torch::jit::prim::Loop && !OpSupported(n)) {
if (n->kind() != torch::jit::prim::Loop && n->kind() != torch::jit::prim::If && !OpSupported(n)) {
auto schema = n->maybeSchema();
TRTORCH_CHECK(schema, "Unable to get schema for Node " << util::node_info(n) \
<< " (conversion.VerifyCoverterSupportForBlock");
<< " (conversion.VerifyCoverterSupportForBlock)");
std::stringstream ss;
ss << *schema;
unsupported_ops.insert(ss.str());
Expand Down
5 changes: 4 additions & 1 deletion core/conversion/conversionctx/ConversionCtx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ std::ostream& operator<<(std::ostream& os, const BuilderSettings& s) {
<< "\n Operating Precision: " << s.op_precision \
<< "\n Make Refittable Engine: " << s.refit \
<< "\n Debuggable Engine: " << s.debug \
<< "\n Strict Type: " << s.strict_types \
<< "\n Strict Types: " << s.strict_types \
<< "\n Allow GPU Fallback (if running on DLA): " << s.allow_gpu_fallback \
<< "\n Min Timing Iterations: " << s.num_min_timing_iters \
<< "\n Avg Timing Iterations: " << s.num_avg_timing_iters \
Expand Down Expand Up @@ -51,6 +51,9 @@ ConversionCtx::ConversionCtx(BuilderSettings build_settings)
case nvinfer1::DataType::kINT8:
TRTORCH_CHECK(builder->platformHasFastInt8(), "Requested inference in INT8 but platform does support INT8");
cfg->setFlag(nvinfer1::BuilderFlag::kINT8);
if (!settings.strict_types) {
cfg->setFlag(nvinfer1::BuilderFlag::kFP16);
}
input_type = nvinfer1::DataType::kFLOAT;
TRTORCH_CHECK(settings.calibrator != nullptr, "Requested inference in INT8 but no calibrator provided, set the ptq_calibrator field in the ExtraInfo struct with your calibrator");
cfg->setInt8Calibrator(settings.calibrator);
Expand Down
4 changes: 4 additions & 0 deletions core/conversion/converters/NodeConverterRegistry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ class NodeConverterRegistry {
// Stores `converter` in converter_lut_ under the operator name taken from
// `signature`. If an entry with that operator name is already present a
// warning is logged and the entry is replaced. Always returns true.
bool RegisterConverter(torch::jit::FunctionSchema* signature, OpConverter& converter) {
  LOG_DEBUG("Registering converter for " << canonical_schema_string(*signature));
  const auto op_name = signature->operator_name();

  const bool already_registered = (converter_lut_.find(op_name) != converter_lut_.end());
  if (already_registered) {
    LOG_WARNING("Overriding already registered converter " << signature->name() << ", unexpected behavior may occur");
  }

  // The converter is moved out of the caller's object into the lookup table.
  converter_lut_[op_name] = std::move(converter);
  return true;
}
Expand Down
3 changes: 2 additions & 1 deletion core/conversion/evaluators/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ cc_library(
srcs = [
"NodeEvaluatorRegistry.cpp",
"prim.cpp",
"aten.cpp"
"aten.cpp",
"eval_macros.h"
],
deps = [
"//core/util:prelude",
Expand Down
4 changes: 4 additions & 0 deletions core/conversion/evaluators/NodeEvaluatorRegistry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ class NodeEvaluatorRegistry {
public:
// Stores `eval_reg` in evaluator_lut_ under `node_kind`. Registering a second
// evaluator for a kind that already has one is reported through
// TRTORCH_THROW_ERROR instead of replacing the existing entry.
void RegisterEvaluator(torch::jit::NodeKind node_kind, EvalRegistration eval_reg) {
  LOG_DEBUG("Registering evaluator for " << node_kind.toQualString());

  const bool already_registered = (evaluator_lut_.find(node_kind) != evaluator_lut_.end());
  if (already_registered) {
    TRTORCH_THROW_ERROR("Attempting to override already registered evaluator " << node_kind.toQualString() << ", merge implementations instead");
  }

  evaluator_lut_[node_kind] = std::move(eval_reg);
}

Expand Down
Loading