Skip to content

Commit 5f77f56

Browse files
authored
fix: Error caused by invalid binding name in TRTEngine.to_str() method (#1846)
1 parent 1d78f43 commit 5f77f56

File tree

1 file changed

+10
-6
lines changed

1 file changed

+10
-6
lines changed

core/runtime/TRTEngine.cpp

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -118,9 +118,11 @@ TRTEngine::TRTEngine(
118118
TORCHTRT_CHECK(
119119
(cuda_engine->getTensorIOMode(binding_name.c_str()) == nvinfer1::TensorIOMode::kINPUT),
120120
"Binding " << binding_name << " specified as input but found as output in TensorRT engine");
121-
LOG_DEBUG("Input binding name: " << binding_name << "pyt arg idx: " << pyt_idx << ")");
121+
LOG_DEBUG(
122+
"Input binding name: " << binding_name << " has TensorRT binding index: " << trt_idx
123+
<< ", Torch binding index: " << pyt_idx);
122124
in_binding_map[trt_idx] = pyt_idx;
123-
in_binding_names[pyt_idx] = _in_binding_names[pyt_idx];
125+
in_binding_names[pyt_idx] = binding_name;
124126
}
125127

126128
uint64_t outputs = _out_binding_names.size();
@@ -210,19 +212,21 @@ std::string TRTEngine::to_str() const {
210212
ss << " Inputs: [" << std::endl;
211213
for (uint64_t i = 0; i < num_io.first; i++) {
212214
ss << " id: " << i << std::endl;
213-
ss << " shape: " << exec_ctx->getTensorShape(std::string("input_" + str(i)).c_str()) << std::endl;
215+
ss << " name: " << in_binding_names[i].c_str() << std::endl;
216+
ss << " shape: " << exec_ctx->getTensorShape(in_binding_names[i].c_str()) << std::endl;
214217
ss << " dtype: "
215-
<< util::TRTDataTypeToScalarType(exec_ctx->getEngine().getTensorDataType(std::string("input_" + str(i)).c_str()))
218+
<< util::TRTDataTypeToScalarType(exec_ctx->getEngine().getTensorDataType(in_binding_names[i].c_str()))
216219
<< std::endl;
217220
}
218221
ss << " ]" << std::endl;
219222
ss << " Outputs: [" << std::endl;
220223
for (uint64_t o = 0; o < num_io.second; o++) {
221224
ss << " id: " << o << std::endl;
222-
ss << " shape: " << exec_ctx->getTensorShape(std::string("output_" + str(o)).c_str()) << std::endl;
225+
ss << " name: " << out_binding_names[o].c_str() << std::endl;
226+
ss << " shape: " << exec_ctx->getTensorShape(out_binding_names[o].c_str()) << std::endl;
223227
ss << " dtype: "
224228
<< util::TRTDataTypeToScalarType(
225-
exec_ctx->getEngine().getTensorDataType(std::string("output_" + str(o)).c_str()))
229+
exec_ctx->getEngine().getTensorDataType(out_binding_names[o].c_str()))
226230
<< std::endl;
227231
}
228232
ss << " }" << std::endl;

0 commit comments

Comments
 (0)