#include <algorithm>

+#include <cuda_runtime.h>
#include "NvInfer.h"
#include "torch/csrc/jit/frontend/function_schema_parser.h"

@@ -15,20 +16,43 @@ std::string slugify(std::string s) {
  return s;
}

+CudaDevice default_device = {0, 0, 0, nvinfer1::DeviceType::kGPU, 0, ""};
+
TRTEngine::TRTEngine(std::string serialized_engine)
    : logger(
          std::string("[] - "),
          util::logging::get_logger().get_reportable_severity(),
          util::logging::get_logger().get_is_colored_output_on()) {
  std::string _name = "deserialized_trt";
-  new (this) TRTEngine(_name, serialized_engine);
+  // TODO: Need to add the option to configure target device at execution time
+  auto device = get_device_info(0, nvinfer1::DeviceType::kGPU);
+  new (this) TRTEngine(_name, serialized_engine, device);
+}
+
+TRTEngine::TRTEngine(std::vector<std::string> serialized_info)
+    : logger(
+          std::string("[] = "),
+          util::logging::get_logger().get_reportable_severity(),
+          util::logging::get_logger().get_is_colored_output_on()) {
+  std::string _name = "deserialized_trt";
+  std::string engine_info = serialized_info[EngineIdx];
+
+  CudaDevice cuda_device = deserialize_device(serialized_info[DeviceIdx]);
+
+  new (this) TRTEngine(_name, engine_info, cuda_device);
}

-TRTEngine::TRTEngine(std::string mod_name, std::string serialized_engine)
+TRTEngine::TRTEngine(
+    std::string mod_name,
+    std::string serialized_engine,
+    CudaDevice cuda_device)
    : logger(
          std::string("[") + mod_name + std::string("_engine] - "),
          util::logging::get_logger().get_reportable_severity(),
          util::logging::get_logger().get_is_colored_output_on()) {
+
+  set_cuda_device(cuda_device);
+
  rt = nvinfer1::createInferRuntime(logger);

  name = slugify(mod_name) + "_engine";
@@ -63,6 +87,7 @@ TRTEngine& TRTEngine::operator=(const TRTEngine& other) {
  id = other.id;
  rt = other.rt;
  cuda_engine = other.cuda_engine;
+  device_info = other.device_info;
  exec_ctx = other.exec_ctx;
  num_io = other.num_io;
  return (*this);
@@ -82,21 +107,191 @@ TRTEngine::~TRTEngine() {
//   return c10::List<at::Tensor>(output_vec);
// }

-namespace {
static auto TRTORCH_UNUSED TRTEngineTSRegistrtion =
    torch::class_<TRTEngine>("tensorrt", "Engine")
-        .def(torch::init<std::string>())
+        .def(torch::init<std::vector<std::string>>())
        // TODO: .def("__call__", &TRTEngine::Run)
        // TODO: .def("run", &TRTEngine::Run)
        .def_pickle(
-            [](const c10::intrusive_ptr<TRTEngine>& self) -> std::string {
-              auto serialized_engine = self->cuda_engine->serialize();
-              return std::string((const char*)serialized_engine->data(), serialized_engine->size());
+            [](const c10::intrusive_ptr<TRTEngine>& self) -> std::vector<std::string> {
+              // Serialize TensorRT engine
+              auto serialized_trt_engine = self->cuda_engine->serialize();
+
+              // Adding device info related meta data to the serialized file
+              auto trt_engine = std::string((const char*)serialized_trt_engine->data(), serialized_trt_engine->size());
+
+              std::vector<std::string> serialize_info;
+              serialize_info.push_back(serialize_device(self->device_info));
+              serialize_info.push_back(trt_engine);
+              return serialize_info;
            },
-            [](std::string seralized_engine) -> c10::intrusive_ptr<TRTEngine> {
-              return c10::make_intrusive<TRTEngine>(std::move(seralized_engine));
+            [](std::vector<std::string> serial_info) -> c10::intrusive_ptr<TRTEngine> {
+              return c10::make_intrusive<TRTEngine>(std::move(serial_info));
            });
+
+/*
+int64_t CudaDevice::get_id(void) {
+  return this->id;
+}
+
+void CudaDevice::set_id(int64_t id) {
+  this->id = id;
+}
+
+int64_t CudaDevice::get_major(void) {
+  return this->major;
+}
+
+void CudaDevice::set_major(int64_t major) {
+  this->major = major;
+}
+
+int64_t CudaDevice::get_minor(void) {
+  return this->minor;
+}
+
+void CudaDevice::set_minor(int64_t minor) {
+  this->minor = minor;
+}
+
+nvinfer1::DeviceType get_device_type(void) {
+  return this->device_type;
+}
+
+void set_device_type(nvinfer1::DeviceType device_type) {
+  this->device_type = device_type;
+}
+
+std::string get_device_name(void) {
+  return this->device_name;
+}
+
+void set_device_name(std::string& name) {
+  this->device_name = name;
+}
+
+size_t get_device_name_len(void) {
+  return this->device_name_len;
+}
+
+void set_device_name_len(size_t size) {
+  this->device_name_len = size;
+}
+*/
+
+void set_cuda_device(CudaDevice& cuda_device) {
+  TRTORCH_CHECK((cudaSetDevice(cuda_device.id) == cudaSuccess), "Unable to set device: " << cuda_device.id);
+}
+
+void get_cuda_device(CudaDevice& cuda_device) {
+  TRTORCH_CHECK((cudaGetDevice(reinterpret_cast<int*>(&cuda_device.id)) == cudaSuccess), "Unable to get current device: " << cuda_device.id);
+  cudaDeviceProp device_prop;
+  TRTORCH_CHECK(
+      (cudaGetDeviceProperties(&device_prop, cuda_device.id) == cudaSuccess),
+      "Unable to get CUDA properties from device:" << cuda_device.id);
+  cuda_device.set_major(device_prop.major);
+  cuda_device.set_minor(device_prop.minor);
+  std::string device_name(device_prop.name);
+  cuda_device.set_device_name(device_name);
+}
+
+std::string serialize_device(CudaDevice& cuda_device) {
+  void* buffer = new char[sizeof(cuda_device)];
+  void* ref_buf = buffer;
+
+  int64_t temp = cuda_device.get_id();
+  memcpy(buffer, reinterpret_cast<int64_t*>(&temp), sizeof(int64_t));
+  buffer = static_cast<char*>(buffer) + sizeof(int64_t);
+
+  temp = cuda_device.get_major();
+  memcpy(buffer, reinterpret_cast<int64_t*>(&temp), sizeof(int64_t));
+  buffer = static_cast<char*>(buffer) + sizeof(int64_t);
+
+  temp = cuda_device.get_minor();
+  memcpy(buffer, reinterpret_cast<int64_t*>(&temp), sizeof(int64_t));
+  buffer = static_cast<char*>(buffer) + sizeof(int64_t);
+
+  auto device_type = cuda_device.get_device_type();
+  memcpy(buffer, reinterpret_cast<char*>(&device_type), sizeof(nvinfer1::DeviceType));
+  buffer = static_cast<char*>(buffer) + sizeof(nvinfer1::DeviceType);
+
+  size_t device_name_len = cuda_device.get_device_name_len();
+  memcpy(buffer, reinterpret_cast<char*>(&device_name_len), sizeof(size_t));
+  buffer = static_cast<char*>(buffer) + sizeof(size_t);
+
+  auto device_name = cuda_device.get_device_name();
+  memcpy(buffer, reinterpret_cast<char*>(&device_name), device_name.size());
+  buffer = static_cast<char*>(buffer) + device_name.size();
+
+  return std::string((const char*)ref_buf, sizeof(int64_t) * 3 + sizeof(nvinfer1::DeviceType) + device_name.size());
+}
+
+CudaDevice deserialize_device(std::string device_info) {
+  CudaDevice ret;
+  char* buffer = new char[device_info.size() + 1];
+  std::copy(device_info.begin(), device_info.end(), buffer);
+  int64_t temp = 0;
+
+  memcpy(&temp, reinterpret_cast<char*>(buffer), sizeof(int64_t));
+  buffer += sizeof(int64_t);
+  ret.set_id(temp);
+
+  memcpy(&temp, reinterpret_cast<char*>(buffer), sizeof(int64_t));
+  buffer += sizeof(int64_t);
+  ret.set_major(temp);
+
+  memcpy(&temp, reinterpret_cast<char*>(buffer), sizeof(int64_t));
+  buffer += sizeof(int64_t);
+  ret.set_minor(temp);
+
+  nvinfer1::DeviceType device_type;
+  memcpy(&device_type, reinterpret_cast<char*>(buffer), sizeof(nvinfer1::DeviceType));
+  buffer += sizeof(nvinfer1::DeviceType);
+
+  size_t size;
+  memcpy(&size, reinterpret_cast<size_t*>(&buffer), sizeof(size_t));
+  buffer += sizeof(size_t);
+
+  ret.set_device_name_len(size);
+
+  std::string device_name;
+  memcpy(&device_name, reinterpret_cast<char*>(buffer), size * sizeof(char));
+  buffer += size * sizeof(char);
+
+  ret.set_device_name(device_name);
+
+  return ret;
+}
+
+CudaDevice get_device_info(int64_t gpu_id, nvinfer1::DeviceType device_type) {
+  CudaDevice device;
+  cudaDeviceProp device_prop;
+
+  // Device ID
+  device.set_id(gpu_id);
+
+  // Get Device Properties
+  cudaGetDeviceProperties(&device_prop, gpu_id);
+
+  // Compute capability major version
+  device.set_major(device_prop.major);
+
+  // Compute capability minor version
+  device.set_minor(device_prop.minor);
+
+  std::string device_name(device_prop.name);
+
+  // Set Device name
+  device.set_device_name(device_name);
+
+  // Set Device name len for serialization/deserialization
+  device.set_device_name_len(device_name.size());
+
+  // Set Device Type
+  device.set_device_type(device_type);
+  return device;
+}
+
} // namespace runtime
} // namespace core
} // namespace trtorch
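
Note on the serialization layout introduced above: `serialize_device()` packs the `CudaDevice` fields into a flat byte string in a fixed order (id, major, minor, device type, name length, name bytes), `deserialize_device()` reads them back in the same order, and the pickled module then carries a two-element `std::vector<std::string>` holding the device blob and the raw TensorRT engine blob. A minimal, self-contained sketch of that byte-packing round trip is below; the `Device` struct, the `serialize`/`deserialize` helpers, and the values used in `main()` are stand-ins for illustration, not the runtime's actual types.

```cpp
#include <cstdint>
#include <cstring>
#include <iostream>
#include <string>

// Stand-in for CudaDevice; the real struct lives in the runtime headers.
struct Device {
  int64_t id;
  int64_t major;
  int64_t minor;
  int32_t device_type;  // placeholder for nvinfer1::DeviceType
  std::string name;
};

// Pack fields in the same order the patch uses:
// id, major, minor, device type, name length, name bytes.
std::string serialize(const Device& d) {
  std::string out;
  auto append = [&out](const void* p, size_t n) { out.append(static_cast<const char*>(p), n); };
  append(&d.id, sizeof(d.id));
  append(&d.major, sizeof(d.major));
  append(&d.minor, sizeof(d.minor));
  append(&d.device_type, sizeof(d.device_type));
  size_t len = d.name.size();
  append(&len, sizeof(len));
  append(d.name.data(), len);
  return out;
}

// Read the fields back in the same order they were written.
Device deserialize(const std::string& s) {
  Device d;
  const char* p = s.data();
  auto read = [&p](void* dst, size_t n) { std::memcpy(dst, p, n); p += n; };
  read(&d.id, sizeof(d.id));
  read(&d.major, sizeof(d.major));
  read(&d.minor, sizeof(d.minor));
  read(&d.device_type, sizeof(d.device_type));
  size_t len = 0;
  read(&len, sizeof(len));
  d.name.assign(p, len);
  return d;
}

int main() {
  Device in{0, 7, 5, /*kGPU stand-in*/ 0, "example-gpu"};
  Device out = deserialize(serialize(in));
  std::cout << out.id << " " << out.major << "." << out.minor << " " << out.name << "\n";
  return 0;
}
```

Keeping the write and read sequences identical is the property the registration code relies on when it later splits `serialized_info` at `DeviceIdx` and `EngineIdx` during unpickling.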