Skip to content

Commit 65af9d1

Browse files
committed
refactor(//core/runtime): Updating the logging for runtime
deserialization NOTE: This does not fully address the deserialization issue as the root cause is TensorRT modifies the input binding names which is leading to these cases of stoi errors. Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent 49d367d commit 65af9d1

File tree

2 files changed

+7
-5
lines changed

2 files changed

+7
-5
lines changed

core/ir/ir.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ InputSpecMap pair_input_vals_with_specs(std::vector<const torch::jit::Value*> va
2121

2222
std::unordered_map<const torch::jit::Value*, core::ir::Input> a;
2323
for (size_t i = 0; i < vals.size(); i++) {
24-
LOG_DEBUG("Paring " << i << ": " << vals[i]->debugName() << " : " << specs[i]);
24+
LOG_DEBUG("Pairing " << i << ": " << vals[i]->debugName() << " : " << specs[i]);
2525
a.insert({vals[i], specs[i]});
2626
}
2727
return a;

core/runtime/TRTEngine.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,13 +60,15 @@ TRTEngine::TRTEngine(std::string mod_name, std::string serialized_engine, CudaDe
6060

6161
for (int64_t x = 0; x < cuda_engine->getNbBindings(); x++) {
6262
std::string bind_name = cuda_engine->getBindingName(x);
63+
LOG_DEBUG("Binding name: " << bind_name);
6364
auto delim = bind_name.find(".");
6465
if (delim == std::string::npos) {
6566
delim = bind_name.find("_");
6667
TORCHTRT_CHECK(
6768
delim != std::string::npos,
68-
"Unable to determine binding index for input " << bind_name
69-
<< "\nEnsure module was compile with Torch-TensorRT.ts");
69+
"Unable to determine binding index for input "
70+
<< bind_name
71+
<< "\nEnsure module was compiled with Torch-TensorRT.ts or follows Torch-TensorRT Runtime conventions");
7072
}
7173

7274
std::string idx_s = bind_name.substr(delim + 1);
@@ -108,8 +110,8 @@ std::string TRTEngine::to_str() const {
108110
ss << " Outputs: [" << std::endl;
109111
for (uint64_t o = 0; o < num_io.second; o++) {
110112
ss << " id: " << o << std::endl;
111-
ss << " shape: " << exec_ctx->getBindingDimensions(o) << std::endl;
112-
ss << " dtype: " << util::TRTDataTypeToScalarType(exec_ctx->getEngine().getBindingDataType(o)) << std::endl;
113+
ss << " shape: " << exec_ctx->getBindingDimensions(o) << std::endl;
114+
ss << " dtype: " << util::TRTDataTypeToScalarType(exec_ctx->getEngine().getBindingDataType(o)) << std::endl;
113115
}
114116
ss << " ]" << std::endl;
115117
ss << " Device: " << device_info << std::endl;

0 commit comments

Comments
 (0)