Skip to content

Commit 819c911

Browse files
committed
fix(//core/runtime): Support more delimiter variants
Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent 67e320c commit 819c911

File tree

2 files changed

+41
-1
lines changed

2 files changed

+41
-1
lines changed

core/runtime/TRTEngine.cpp

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,13 +53,20 @@ TRTEngine::TRTEngine(std::string mod_name, std::string serialized_engine, CudaDe
5353
TORCHTRT_CHECK((cuda_engine.get() != nullptr), "Unable to deserialize the TensorRT engine");
5454

5555
exec_ctx = make_trt(cuda_engine->createExecutionContext());
56+
TORCHTRT_CHECK((exec_ctx.get() != nullptr), "Unable to create TensorRT execution context");
5657

5758
uint64_t inputs = 0;
5859
uint64_t outputs = 0;
5960

6061
for (int64_t x = 0; x < cuda_engine->getNbBindings(); x++) {
6162
std::string bind_name = cuda_engine->getBindingName(x);
62-
std::string idx_s = bind_name.substr(bind_name.find("_") + 1);
63+
auto delim = bind_name.find(".");
64+
if (delim == std::string::npos) {
65+
delim = bind_name.find("_");
66+
TORCHTRT_CHECK(delim != std::string::npos, "Unable to determine binding index for input " << bind_name << "\nEnsure module was compile with Torch-TensorRT.ts");
67+
}
68+
69+
std::string idx_s = bind_name.substr(delim + 1);
6370
uint64_t idx = static_cast<uint64_t>(std::stoi(idx_s));
6471

6572
if (cuda_engine->bindingIsInput(x)) {
@@ -71,6 +78,8 @@ TRTEngine::TRTEngine(std::string mod_name, std::string serialized_engine, CudaDe
7178
}
7279
}
7380
num_io = std::make_pair(inputs, outputs);
81+
82+
LOG_DEBUG(*this);
7483
}
7584

7685
TRTEngine& TRTEngine::operator=(const TRTEngine& other) {
@@ -82,6 +91,34 @@ TRTEngine& TRTEngine::operator=(const TRTEngine& other) {
8291
return (*this);
8392
}
8493

94+
std::string TRTEngine::to_str() const {
95+
std::stringstream ss;
96+
ss << "Torch-TensorRT TensorRT Engine:" << std::endl;
97+
ss << " Name: " << name << std::endl;
98+
ss << " Inputs: [" << std::endl;
99+
for (uint64_t i = 0; i < num_io.first; i++) {
100+
ss << " id: " << i << std::endl;
101+
ss << " shape: " << exec_ctx->getBindingDimensions(i) << std::endl;
102+
ss << " dtype: " << util::TRTDataTypeToScalarType(exec_ctx->getEngine().getBindingDataType(i)) << std::endl;
103+
}
104+
ss << " ]" << std::endl;
105+
ss << " Outputs: [" << std::endl;
106+
for (uint64_t o = 0; o < num_io.second; o++) {
107+
ss << " id: " << o << std::endl;
108+
ss << " shape: " << exec_ctx->getBindingDimensions(o) << std::endl;
109+
ss << " dtype: " << util::TRTDataTypeToScalarType(exec_ctx->getEngine().getBindingDataType(o)) << std::endl;
110+
}
111+
ss << " ]" << std::endl;
112+
ss << " Device: " << device_info << std::endl;
113+
114+
return ss.str();
115+
}
116+
117+
std::ostream& operator<<(std::ostream& os, const TRTEngine& engine) {
118+
os << engine.to_str();
119+
return os;
120+
}
121+
85122
// TODO: Implement a call method
86123
// c10::List<at::Tensor> TRTEngine::Run(c10::List<at::Tensor> inputs) {
87124
// auto input_vec = inputs.vec();
@@ -96,6 +133,7 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion =
96133
.def(torch::init<std::vector<std::string>>())
97134
// TODO: .def("__call__", &TRTEngine::Run)
98135
// TODO: .def("run", &TRTEngine::Run)
136+
.def("__str__", &TRTEngine::to_str)
99137
.def_pickle(
100138
[](const c10::intrusive_ptr<TRTEngine>& self) -> std::vector<std::string> {
101139
// Serialize TensorRT engine

core/runtime/runtime.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ struct TRTEngine : torch::CustomClassHolder {
5959
TRTEngine(std::vector<std::string> serialized_info);
6060
TRTEngine(std::string mod_name, std::string serialized_engine, CudaDevice cuda_device);
6161
TRTEngine& operator=(const TRTEngine& other);
62+
std::string to_str() const;
63+
friend std::ostream& operator<<(std::ostream& os, const TRTEngine& engine);
6264
// TODO: Implement a call method
6365
// c10::List<at::Tensor> Run(c10::List<at::Tensor> inputs);
6466
};

0 commit comments

Comments
 (0)