1
1
#include < algorithm>
2
2
3
+ #include < cuda_runtime.h>
3
4
#include " NvInfer.h"
4
5
#include " torch/csrc/jit/frontend/function_schema_parser.h"
5
6
@@ -15,20 +16,39 @@ std::string slugify(std::string s) {
15
16
return s;
16
17
}
17
18
18
- TRTEngine::TRTEngine (std::string serialized_engine)
19
+ TRTEngine::TRTEngine (std::string serialized_engine, CudaDevice device )
19
20
: logger(
20
21
std::string (" [] - " ),
21
22
util::logging::get_logger().get_reportable_severity(),
22
23
util::logging::get_logger().get_is_colored_output_on()) {
23
24
std::string _name = " deserialized_trt" ;
24
- new (this ) TRTEngine (_name, serialized_engine);
25
+ new (this ) TRTEngine (_name, serialized_engine, device );
25
26
}
26
27
27
- TRTEngine::TRTEngine (std::string mod_name, std::string serialized_engine)
28
+ TRTEngine::TRTEngine (std::vector<std::string> serialized_info)
29
+ : logger(
30
+ std::string (" [] = " ),
31
+ util::logging::get_logger().get_reportable_severity(),
32
+ util::logging::get_logger().get_is_colored_output_on()) {
33
+ std::string _name = " deserialized_trt" ;
34
+ std::string engine_info = serialized_info[EngineIdx];
35
+
36
+ CudaDevice cuda_device = deserialize_device (serialized_info[DeviceIdx]);
37
+
38
+ new (this ) TRTEngine (_name, engine_info, cuda_device);
39
+ }
40
+
41
+ TRTEngine::TRTEngine (
42
+ std::string mod_name,
43
+ std::string serialized_engine,
44
+ CudaDevice cuda_device)
28
45
: logger(
29
46
std::string (" [" ) + mod_name + std::string(" _engine] - " ),
30
47
util::logging::get_logger().get_reportable_severity(),
31
48
util::logging::get_logger().get_is_colored_output_on()) {
49
+
50
+ set_cuda_device (cuda_device);
51
+
32
52
rt = nvinfer1::createInferRuntime (logger);
33
53
34
54
name = slugify (mod_name) + " _engine" ;
@@ -63,6 +83,7 @@ TRTEngine& TRTEngine::operator=(const TRTEngine& other) {
63
83
id = other.id ;
64
84
rt = other.rt ;
65
85
cuda_engine = other.cuda_engine ;
86
+ device_info = other.device_info ;
66
87
exec_ctx = other.exec_ctx ;
67
88
num_io = other.num_io ;
68
89
return (*this );
@@ -82,21 +103,188 @@ TRTEngine::~TRTEngine() {
82
103
// return c10::List<at::Tensor>(output_vec);
83
104
// }
84
105
85
- namespace {
86
106
static auto TRTORCH_UNUSED TRTEngineTSRegistrtion =
87
107
torch::class_<TRTEngine>(" tensorrt" , " Engine" )
88
108
.def(torch::init<std::string>())
89
109
// TODO: .def("__call__", &TRTEngine::Run)
90
110
// TODO: .def("run", &TRTEngine::Run)
91
111
.def_pickle(
92
- [](const c10::intrusive_ptr<TRTEngine>& self) -> std::string {
93
- auto serialized_engine = self->cuda_engine ->serialize ();
94
- return std::string ((const char *)serialized_engine->data (), serialized_engine->size ());
112
+ [](const c10::intrusive_ptr<TRTEngine>& self) -> std::vector<std::string> {
113
+ // Serialize TensorRT engine
114
+ auto serialized_trt_engine = self->cuda_engine ->serialize ();
115
+
116
+ // Adding device info related meta data to the serialized file
117
+ auto trt_engine = std::string ((const char *)serialized_trt_engine->data (), serialized_trt_engine->size ());
118
+
119
+ std::vector<std::string> serialize_info;
120
+ serialize_info.push_back (serialize_device (self.cuda_device ));
121
+ serialize_info.push_back (trt_engine);
122
+ return serialize_info;
95
123
},
96
- [](std::string seralized_engine ) -> c10::intrusive_ptr<TRTEngine> {
97
- return c10::make_intrusive<TRTEngine>(std::move (seralized_engine ));
124
+ [](std::vector<std:: string> seralized_info ) -> c10::intrusive_ptr<TRTEngine> {
125
+ return c10::make_intrusive<TRTEngine>(std::move (seralized_info ));
98
126
});
99
- } // namespace
127
+
128
+ int64_t CudaDevice::get_id (void ) {
129
+ return this ->id ;
130
+ }
131
+
132
+ void CudaDevice::set_id (int64_t id) {
133
+ this ->id = id;
134
+ }
135
+
136
+ int64_t CudaDevice::get_major (void ) {
137
+ return this ->major ;
138
+ }
139
+
140
+ void CudaDevice::set_major (int64_t major) {
141
+ this ->major = major;
142
+ }
143
+
144
+ int64_t CudaDevice::get_minor (void ) {
145
+ return this ->minor ;
146
+ }
147
+
148
+ void CudaDevice::set_minor (int64_t minor) {
149
+ this ->minor = minor;
150
+ }
151
+
152
+ nvinfer1::DeviceType get_device_type (void ) {
153
+ return this ->device_type ;
154
+ }
155
+
156
+ void set_device_type (nvinfer1::DeviceType device_type) {
157
+ this ->device_type = device_type;
158
+ }
159
+
160
+ std::string get_device_name (void ) {
161
+ return this ->device_name ;
162
+ }
163
+
164
+ void set_device_name (std::string& name) {
165
+ this ->device_name = name;
166
+ }
167
+
168
+ size_t get_device_name_len (void ) {
169
+ return this ->device_name_len ;
170
+ }
171
+
172
+ void set_device_name_len (size_t size) {
173
+ this ->device_name_len = size;
174
+ }
175
+
176
+ void set_cuda_device (CudaDevice& cuda_device) {
177
+ TRTORCH_CHECK ((cudaSetDevice (cuda_device.id ) == cudaSuccess), " Unable to set device: " << cuda_device.id );
178
+ }
179
+
180
+ void get_cuda_device (CudaDevice& cuda_device) {
181
+ TRTORCH_CHECK ((cudaGetDevice (&cuda_device.id ) == cudaSuccess), " Unable to get current device: " << cuda_device.id );
182
+ cudaDeviceProp device_prop;
183
+ TRTORCH_CHECK (
184
+ (cudaGetDeviceProperties (&device_prop, cuda_device.id ) == cudaSuccess),
185
+ " Unable to get CUDA properties from device:" << cuda_device.id );
186
+ cuda_device.set_major (device_prop.major );
187
+ cuda_device.set_minor (device_prop.minor );
188
+ cuda_device.set_device_name (std::string (device_prop.name ));
189
+ }
190
+
191
+ std::string serialize_device (CudaDevice& cuda_device) {
192
+ void * buffer = new char [sizeof (cuda_device)];
193
+ void * ref_buf = buffer;
194
+
195
+ int64_t temp = cuda_device.get_id ();
196
+ memcpy (buffer, reinterpret_cast <int64_t *>(&temp), sizeof (int64_t ));
197
+ buffer = static_cast <char *>(buffer) + sizeof (int64_t );
198
+
199
+ temp = cuda_device.get_major ();
200
+ memcpy (buffer, reinterpret_cast <int64_t *>(&temp), sizeof (int64_t ));
201
+ buffer = static_cast <char *>(buffer) + sizeof (int64_t );
202
+
203
+ temp = cuda_device.get_minor ();
204
+ memcpy (buffer, reinterpret_cast <int64_t *>(&temp), sizeof (int64_t ));
205
+ buffer = static_cast <char *>(buffer) + sizeof (int64_t );
206
+
207
+ auto device_type = cuda_device.get_device_type ();
208
+ memcpy (buffer, reinterpret_cast <char *>(&device_type), sizeof (nvinfer1::DeviceType));
209
+ buffer = static_cast <char *>(buffer) + sizeof (nvinfer1::DeviceType);
210
+
211
+ size_t device_name_len = cuda_device.get_device_name_len ();
212
+ memcpy (buffer, reinterpret_cast <char *>(&device_name_len), sizeof (size_t ));
213
+ buffer = static_cast <char *>(buffer) + sizeof (size_t );
214
+
215
+ auto device_name = cuda_device.get_device_name ();
216
+ memcpy (buffer, reinterpret_cast <char *>(&device_name), device_name.size ());
217
+ buffer = static_cast <char *>(buffer) + device_name.size ();
218
+
219
+ return std::string ((const char *)ref_buf, sizeof (int64_t ) * 3 + sizeof (nvinfer1::DeviceType) + device_name.size ();
220
+ }
221
+
222
+ CudaDevice deserialize_device (std::string device_info) {
223
+ CudaDevice ret;
224
+ char * buffer = new char [device_info.size () + 1 ];
225
+ std::copy (device_info.begin (), device_info.end (), buffer);
226
+ int64_t temp = 0 ;
227
+
228
+ memcpy (&temp, reinterpret_cast <char *>(buffer), sizeof (int64_t ));
229
+ buffer += sizeof (int64_t );
230
+ ret.set_id (temp);
231
+
232
+ memcpy (&temp, reinterpret_cast <char *>(buffer), sizeof (int64_t ));
233
+ buffer += sizeof (int64_t );
234
+ ret.set_major (temp);
235
+
236
+ memcpy (&temp, reinterpret_cast <char *>(buffer), sizeof (int64_t ));
237
+ buffer += sizeof (int64_t );
238
+ ret.set_minor (temp);
239
+
240
+ nvinfer1::DeviceType device_type;
241
+ memcpy (&device_type, reinterpret_cast <char *>(buffer), sizeof (nvinfer1::DeviceType));
242
+ buffer += sizeof (nvinfer1::DeviceType);
243
+
244
+ size_t size;
245
+ memcpy (&size, reinterpret_cast <size_t *>(&buffer), sizeof (size_t ));
246
+ buffer += sizeof (size_t );
247
+
248
+ ret.set_device_name_len (size);
249
+
250
+ std::string device_name;
251
+ memcpy (&device_name, reinterpret_cast <char *>(buffer), size * sizeof (char ));
252
+ buffer += size * sizeof (char );
253
+
254
+ ret.set_device_name (device_name);
255
+
256
+ return ret;
257
+ }
258
+
259
+ CudaDevice spec_to_device (conversion::Device& spec) {
260
+ CudaDevice device;
261
+ cudaDeviceProp device_prop;
262
+
263
+ // Device ID
264
+ device.set_id (spec.gpu_id );
265
+
266
+ // Get Device Properties
267
+ cudaGetDeviceProperties (&device_prop, spec.gpu_id );
268
+
269
+ // Compute capability major version
270
+ device.set_major (device_prop.major );
271
+
272
+ // Compute capability minor version
273
+ device.set_minor (device_prop.minor );
274
+
275
+ std::string device_name = std::string (device_prop.name );
276
+
277
+ // Set Device name
278
+ device.set_device_name (device_name);
279
+
280
+ // Set Device name len for serialization/deserialization
281
+ device.set_device_name_len (device_nmae.size ());
282
+
283
+ // Set Device Type
284
+ device.set_device_type (spec.device_type );
285
+ return device;
286
+ }
287
+
100
288
} // namespace runtime
101
289
} // namespace core
102
290
} // namespace trtorch
0 commit comments