@@ -16,13 +16,13 @@ std::string slugify(std::string s) {
16
16
return s;
17
17
}
18
18
19
- TRTEngine::TRTEngine (std::string serialized_engine)
19
+ TRTEngine::TRTEngine (std::string serialized_engine, CudaDevice device )
20
20
: logger(
21
21
std::string (" [] - " ),
22
22
util::logging::get_logger().get_reportable_severity(),
23
23
util::logging::get_logger().get_is_colored_output_on()) {
24
24
std::string _name = " deserialized_trt" ;
25
- new (this ) TRTEngine (_name, serialized_engine, std::string () );
25
+ new (this ) TRTEngine (_name, serialized_engine, device );
26
26
}
27
27
28
28
TRTEngine::TRTEngine (std::vector<std::string> serialized_info)
@@ -31,27 +31,23 @@ TRTEngine::TRTEngine(std::vector<std::string> serialized_info)
31
31
util::logging::get_logger().get_reportable_severity(),
32
32
util::logging::get_logger().get_is_colored_output_on()) {
33
33
std::string _name = " deserialized_trt" ;
34
- std::string device_info = serialized_info[0 ];
35
- std::string engine_info = serialized_info[1 ];
34
+ std::string engine_info = serialized_info[EngineIdx];
36
35
37
- new (this ) TRTEngine (_name, engine_info, device_info);
36
+ CudaDevice cuda_device = deserialize_device (serialized_info[DeviceIdx]);
37
+
38
+ new (this ) TRTEngine (_name, engine_info, cuda_device);
38
39
}
39
40
40
41
TRTEngine::TRTEngine (
41
42
std::string mod_name,
42
43
std::string serialized_engine,
43
- std::string serialized_device_info = std::string() )
44
+ CudaDevice cuda_device )
44
45
: logger(
45
46
std::string (" [" ) + mod_name + std::string(" _engine] - " ),
46
47
util::logging::get_logger().get_reportable_severity(),
47
48
util::logging::get_logger().get_is_colored_output_on()) {
48
- CudaDevice cuda_device;
49
- // Deserialize device meta data if device_info is non-empty
50
- if (!serialized_device_info.empty ()) {
51
- cuda_device = deserialize_device (serialized_device_info);
52
- // Set CUDA device as configured in serialized meta data
53
- set_cuda_device (cuda_device);
54
- }
49
+
50
+ set_cuda_device (cuda_device);
55
51
56
52
rt = nvinfer1::createInferRuntime (logger);
57
53
@@ -120,41 +116,63 @@ static auto TRTORCH_UNUSED TRTEngineTSRegistrtion =
120
116
// Adding device info related meta data to the serialized file
121
117
auto trt_engine = std::string ((const char *)serialized_trt_engine->data (), serialized_trt_engine->size ());
122
118
123
- CudaDevice cuda_device;
124
- get_cuda_device (cuda_device);
125
119
std::vector<std::string> serialize_info;
126
- serialize_info.push_back (serialize_device (cuda_device));
120
+ serialize_info.push_back (serialize_device (self. cuda_device ));
127
121
serialize_info.push_back (trt_engine);
128
122
return serialize_info;
129
123
},
130
124
[](std::vector<std::string> seralized_info) -> c10::intrusive_ptr<TRTEngine> {
131
125
return c10::make_intrusive<TRTEngine>(std::move (seralized_info));
132
126
});
133
127
134
- int CudaDevice::get_id (void ) {
128
+ int64_t CudaDevice::get_id (void ) {
135
129
return this ->id ;
136
130
}
137
131
138
- void CudaDevice::set_id (int id) {
132
+ void CudaDevice::set_id (int64_t id) {
139
133
this ->id = id;
140
134
}
141
135
142
- int CudaDevice::get_major (void ) {
136
+ int64_t CudaDevice::get_major (void ) {
143
137
return this ->major ;
144
138
}
145
139
146
- void CudaDevice::set_major (int major) {
140
+ void CudaDevice::set_major (int64_t major) {
147
141
this ->major = major;
148
142
}
149
143
150
- int CudaDevice::get_minor (void ) {
144
+ int64_t CudaDevice::get_minor (void ) {
151
145
return this ->minor ;
152
146
}
153
147
154
- void CudaDevice::set_minor (int minor) {
148
+ void CudaDevice::set_minor (int64_t minor) {
155
149
this ->minor = minor;
156
150
}
157
151
152
+ nvinfer1::DeviceType get_device_type (void ) {
153
+ return this ->device_type ;
154
+ }
155
+
156
+ void set_device_type (nvinfer1::DeviceType device_type) {
157
+ this ->device_type = device_type;
158
+ }
159
+
160
+ std::string get_device_name (void ) {
161
+ return this ->device_name ;
162
+ }
163
+
164
+ void set_device_name (std::string& name) {
165
+ this ->device_name = name;
166
+ }
167
+
168
+ size_t get_device_name_len (void ) {
169
+ return this ->device_name_len ;
170
+ }
171
+
172
+ void set_device_name_len (size_t size) {
173
+ this ->device_name_len = size;
174
+ }
175
+
158
176
void set_cuda_device (CudaDevice& cuda_device) {
159
177
TRTORCH_CHECK ((cudaSetDevice (cuda_device.id ) == cudaSuccess), " Unable to set device: " << cuda_device.id );
160
178
}
@@ -167,48 +185,106 @@ void get_cuda_device(CudaDevice& cuda_device) {
167
185
" Unable to get CUDA properties from device:" << cuda_device.id );
168
186
cuda_device.set_major (device_prop.major );
169
187
cuda_device.set_minor (device_prop.minor );
188
+ cuda_device.set_device_name (std::string (device_prop.name ));
170
189
}
171
190
172
191
std::string serialize_device (CudaDevice& cuda_device) {
173
192
void * buffer = new char [sizeof (cuda_device)];
174
193
void * ref_buf = buffer;
175
194
176
- int temp = cuda_device.get_id ();
177
- memcpy (buffer, reinterpret_cast <int *>(&temp), sizeof (int ));
178
- buffer = static_cast <char *>(buffer) + sizeof (int );
195
+ int64_t temp = cuda_device.get_id ();
196
+ memcpy (buffer, reinterpret_cast <int64_t *>(&temp), sizeof (int64_t ));
197
+ buffer = static_cast <char *>(buffer) + sizeof (int64_t );
179
198
180
199
temp = cuda_device.get_major ();
181
- memcpy (buffer, reinterpret_cast <int *>(&temp), sizeof (int ));
182
- buffer = static_cast <char *>(buffer) + sizeof (int );
200
+ memcpy (buffer, reinterpret_cast <int64_t *>(&temp), sizeof (int64_t ));
201
+ buffer = static_cast <char *>(buffer) + sizeof (int64_t );
183
202
184
203
temp = cuda_device.get_minor ();
185
- memcpy (buffer, reinterpret_cast <int *>(&temp), sizeof (int ));
186
- buffer = static_cast <char *>(buffer) + sizeof (int );
204
+ memcpy (buffer, reinterpret_cast <int64_t *>(&temp), sizeof (int64_t ));
205
+ buffer = static_cast <char *>(buffer) + sizeof (int64_t );
206
+
207
+ auto device_type = cuda_device.get_device_type ();
208
+ memcpy (buffer, reinterpret_cast <char *>(&device_type), sizeof (nvinfer1::DeviceType));
209
+ buffer = static_cast <char *>(buffer) + sizeof (nvinfer1::DeviceType);
210
+
211
+ size_t device_name_len = cuda_device.get_device_name_len ();
212
+ memcpy (buffer, reinterpret_cast <char *>(&device_name_len), sizeof (size_t ));
213
+ buffer = static_cast <char *>(buffer) + sizeof (size_t );
187
214
188
- return std::string ((const char *)ref_buf, sizeof (int ) * 3 );
215
+ auto device_name = cuda_device.get_device_name ();
216
+ memcpy (buffer, reinterpret_cast <char *>(&device_name), device_name.size ());
217
+ buffer = static_cast <char *>(buffer) + device_name.size ();
218
+
219
+ return std::string ((const char *)ref_buf, sizeof (int64_t ) * 3 + sizeof (nvinfer1::DeviceType) + device_name.size ();
189
220
}
190
221
191
222
CudaDevice deserialize_device (std::string device_info) {
192
223
CudaDevice ret;
193
224
char * buffer = new char [device_info.size () + 1 ];
194
225
std::copy (device_info.begin (), device_info.end (), buffer);
195
- int temp = 0 ;
226
+ int64_t temp = 0 ;
196
227
197
- memcpy (&temp, reinterpret_cast <char *>(buffer), sizeof (int ));
198
- buffer += sizeof (int );
228
+ memcpy (&temp, reinterpret_cast <char *>(buffer), sizeof (int64_t ));
229
+ buffer += sizeof (int64_t );
199
230
ret.set_id (temp);
200
231
201
- memcpy (&temp, reinterpret_cast <char *>(buffer), sizeof (int ));
202
- buffer += sizeof (int );
232
+ memcpy (&temp, reinterpret_cast <char *>(buffer), sizeof (int64_t ));
233
+ buffer += sizeof (int64_t );
203
234
ret.set_major (temp);
204
235
205
- memcpy (&temp, reinterpret_cast <char *>(buffer), sizeof (int ));
206
- buffer += sizeof (int );
236
+ memcpy (&temp, reinterpret_cast <char *>(buffer), sizeof (int64_t ));
237
+ buffer += sizeof (int64_t );
207
238
ret.set_minor (temp);
208
239
240
+ nvinfer1::DeviceType device_type;
241
+ memcpy (&device_type, reinterpret_cast <char *>(buffer), sizeof (nvinfer1::DeviceType));
242
+ buffer += sizeof (nvinfer1::DeviceType);
243
+
244
+ size_t size;
245
+ memcpy (&size, reinterpret_cast <size_t *>(&buffer), sizeof (size_t ));
246
+ buffer += sizeof (size_t );
247
+
248
+ ret.set_device_name_len (size);
249
+
250
+ std::string device_name;
251
+ memcpy (&device_name, reinterpret_cast <char *>(buffer), size * sizeof (char ));
252
+ buffer += size * sizeof (char );
253
+
254
+ ret.set_device_name (device_name);
255
+
209
256
return ret;
210
257
}
211
258
259
+ CudaDevice spec_to_device (conversion::Device& spec) {
260
+ CudaDevice device;
261
+ cudaDeviceProp device_prop;
262
+
263
+ // Device ID
264
+ device.set_id (spec.gpu_id );
265
+
266
+ // Get Device Properties
267
+ cudaGetDeviceProperties (&device_prop, spec.gpu_id );
268
+
269
+ // Compute capability major version
270
+ device.set_major (device_prop.major );
271
+
272
+ // Compute capability minor version
273
+ device.set_minor (device_prop.minor );
274
+
275
+ std::string device_name = std::string (device_prop.name );
276
+
277
+ // Set Device name
278
+ device.set_device_name (device_name);
279
+
280
+ // Set Device name len for serialization/deserialization
281
+ device.set_device_name_len (device_nmae.size ());
282
+
283
+ // Set Device Type
284
+ device.set_device_type (spec.device_type );
285
+ return device;
286
+ }
287
+
212
288
} // namespace runtime
213
289
} // namespace core
214
290
} // namespace trtorch
0 commit comments