@@ -53,13 +53,20 @@ TRTEngine::TRTEngine(std::string mod_name, std::string serialized_engine, CudaDe
53
53
TORCHTRT_CHECK ((cuda_engine.get () != nullptr ), " Unable to deserialize the TensorRT engine" );
54
54
55
55
exec_ctx = make_trt (cuda_engine->createExecutionContext ());
56
+ TORCHTRT_CHECK ((exec_ctx.get () != nullptr ), " Unable to create TensorRT execution context" );
56
57
57
58
uint64_t inputs = 0 ;
58
59
uint64_t outputs = 0 ;
59
60
60
61
for (int64_t x = 0 ; x < cuda_engine->getNbBindings (); x++) {
61
62
std::string bind_name = cuda_engine->getBindingName (x);
62
- std::string idx_s = bind_name.substr (bind_name.find (" _" ) + 1 );
63
+ auto delim = bind_name.find (" ." );
64
+ if (delim == std::string::npos) {
65
+ delim = bind_name.find (" _" );
66
+ TORCHTRT_CHECK (delim != std::string::npos, " Unable to determine binding index for input " << bind_name << " \n Ensure module was compile with Torch-TensorRT.ts" );
67
+ }
68
+
69
+ std::string idx_s = bind_name.substr (delim + 1 );
63
70
uint64_t idx = static_cast <uint64_t >(std::stoi (idx_s));
64
71
65
72
if (cuda_engine->bindingIsInput (x)) {
@@ -71,6 +78,8 @@ TRTEngine::TRTEngine(std::string mod_name, std::string serialized_engine, CudaDe
71
78
}
72
79
}
73
80
num_io = std::make_pair (inputs, outputs);
81
+
82
+ LOG_DEBUG (*this );
74
83
}
75
84
76
85
TRTEngine& TRTEngine::operator =(const TRTEngine& other) {
@@ -82,6 +91,34 @@ TRTEngine& TRTEngine::operator=(const TRTEngine& other) {
82
91
return (*this );
83
92
}
84
93
94
+ std::string TRTEngine::to_str () const {
95
+ std::stringstream ss;
96
+ ss << " Torch-TensorRT TensorRT Engine:" << std::endl;
97
+ ss << " Name: " << name << std::endl;
98
+ ss << " Inputs: [" << std::endl;
99
+ for (uint64_t i = 0 ; i < num_io.first ; i++) {
100
+ ss << " id: " << i << std::endl;
101
+ ss << " shape: " << exec_ctx->getBindingDimensions (i) << std::endl;
102
+ ss << " dtype: " << util::TRTDataTypeToScalarType (exec_ctx->getEngine ().getBindingDataType (i)) << std::endl;
103
+ }
104
+ ss << " ]" << std::endl;
105
+ ss << " Outputs: [" << std::endl;
106
+ for (uint64_t o = 0 ; o < num_io.second ; o++) {
107
+ ss << " id: " << o << std::endl;
108
+ ss << " shape: " << exec_ctx->getBindingDimensions (o) << std::endl;
109
+ ss << " dtype: " << util::TRTDataTypeToScalarType (exec_ctx->getEngine ().getBindingDataType (o)) << std::endl;
110
+ }
111
+ ss << " ]" << std::endl;
112
+ ss << " Device: " << device_info << std::endl;
113
+
114
+ return ss.str ();
115
+ }
116
+
117
+ std::ostream& operator <<(std::ostream& os, const TRTEngine& engine) {
118
+ os << engine.to_str ();
119
+ return os;
120
+ }
121
+
85
122
// TODO: Implement a call method
86
123
// c10::List<at::Tensor> TRTEngine::Run(c10::List<at::Tensor> inputs) {
87
124
// auto input_vec = inputs.vec();
@@ -96,6 +133,7 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion =
96
133
.def(torch::init<std::vector<std::string>>())
97
134
// TODO: .def("__call__", &TRTEngine::Run)
98
135
// TODO: .def("run", &TRTEngine::Run)
136
+ .def(" __str__" , &TRTEngine::to_str)
99
137
.def_pickle(
100
138
[](const c10::intrusive_ptr<TRTEngine>& self) -> std::vector<std::string> {
101
139
// Serialize TensorRT engine
0 commit comments