#include <algorithm>
+ #include <cuda_runtime.h>
#include "NvInfer.h"
#include "torch/csrc/jit/frontend/function_schema_parser.h"
@@ -15,20 +16,35 @@ std::string slugify(std::string s) {
  return s;
}

- TRTEngine::TRTEngine(std::string serialized_engine)
+ TRTEngine::TRTEngine(std::string serialized_engine, CudaDevice cuda_device)
    : logger(
          std::string("[] - "),
          util::logging::get_logger().get_reportable_severity(),
          util::logging::get_logger().get_is_colored_output_on()) {
  std::string _name = "deserialized_trt";
-   new (this) TRTEngine(_name, serialized_engine);
+   new (this) TRTEngine(_name, serialized_engine, cuda_device);
}

- TRTEngine::TRTEngine(std::string mod_name, std::string serialized_engine)
+ TRTEngine::TRTEngine(std::vector<std::string> serialized_info)
+     : logger(
+           std::string("[] = "),
+           util::logging::get_logger().get_reportable_severity(),
+           util::logging::get_logger().get_is_colored_output_on()) {
+   std::string _name = "deserialized_trt";
+   std::string engine_info = serialized_info[EngineIdx];
+
+   CudaDevice cuda_device = deserialize_device(serialized_info[DeviceIdx]);
+   new (this) TRTEngine(_name, engine_info, cuda_device);
+ }
+
+ TRTEngine::TRTEngine(std::string mod_name, std::string serialized_engine, CudaDevice cuda_device)
    : logger(
          std::string("[") + mod_name + std::string("_engine] - "),
          util::logging::get_logger().get_reportable_severity(),
          util::logging::get_logger().get_is_colored_output_on()) {
+   device_info = cuda_device;
+   set_cuda_device(device_info);
+
  rt = nvinfer1::createInferRuntime(logger);

  name = slugify(mod_name) + "_engine";
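For reference, a hedged usage sketch of the widened constructor above. The header path ("core/runtime/runtime.h") and the visibility of get_device_info() to external callers are assumptions, not something this diff establishes.

// Hypothetical caller of the new (mod_name, engine, device) constructor.
#include <string>
#include "core/runtime/runtime.h"   // assumed location of TRTEngine / CudaDevice / get_device_info
#include "torch/custom_class.h"

using namespace trtorch::core::runtime;

c10::intrusive_ptr<TRTEngine> build_engine(const std::string& mod_name, const std::string& engine_blob) {
  // Target GPU 0; get_device_info() fills in compute capability and device name.
  CudaDevice dev = get_device_info(0, nvinfer1::DeviceType::kGPU);
  return c10::make_intrusive<TRTEngine>(mod_name, engine_blob, dev);
}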
@@ -63,6 +79,7 @@ TRTEngine& TRTEngine::operator=(const TRTEngine& other) {
  id = other.id;
  rt = other.rt;
  cuda_engine = other.cuda_engine;
+   device_info = other.device_info;
  exec_ctx = other.exec_ctx;
  num_io = other.num_io;
  return (*this);
@@ -85,18 +102,144 @@ TRTEngine::~TRTEngine() {
namespace {
static auto TRTORCH_UNUSED TRTEngineTSRegistrtion =
    torch::class_<TRTEngine>("tensorrt", "Engine")
-         .def(torch::init<std::string>())
+         .def(torch::init<std::vector<std::string>>())
        // TODO: .def("__call__", &TRTEngine::Run)
        // TODO: .def("run", &TRTEngine::Run)
        .def_pickle(
-             [](const c10::intrusive_ptr<TRTEngine>& self) -> std::string {
-               auto serialized_engine = self->cuda_engine->serialize();
-               return std::string((const char*)serialized_engine->data(), serialized_engine->size());
+             [](const c10::intrusive_ptr<TRTEngine>& self) -> std::vector<std::string> {
+               // Serialize TensorRT engine
+               auto serialized_trt_engine = self->cuda_engine->serialize();
+
+               // Adding device info related meta data to the serialized file
+               auto trt_engine = std::string((const char*)serialized_trt_engine->data(), serialized_trt_engine->size());
+
+               std::vector<std::string> serialize_info;
+               serialize_info.push_back(serialize_device(self->device_info));
+               serialize_info.push_back(trt_engine);
+               return serialize_info;
            },
-             [](std::string seralized_engine) -> c10::intrusive_ptr<TRTEngine> {
-               return c10::make_intrusive<TRTEngine>(std::move(seralized_engine));
+             [](std::vector<std::string> seralized_info) -> c10::intrusive_ptr<TRTEngine> {
+               return c10::make_intrusive<TRTEngine>(std::move(seralized_info));
            });
} // namespace
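With this change the pickled state is a two-element std::vector<std::string> holding the serialized device record and the engine blob. A hedged sketch of rebuilding an engine from such a payload by hand (e.g. in a test): it assumes DeviceIdx == 0 and EngineIdx == 1, matching the push_back order above, and that those symbols plus serialize_device() are declared in the runtime header (path assumed).

// Minimal sketch: assemble the two-part payload and construct a TRTEngine from it,
// the same path the unpickle lambda takes.
#include <string>
#include <utility>
#include <vector>
#include "core/runtime/runtime.h"   // assumed location of TRTEngine / CudaDevice / serialize_device
#include "torch/custom_class.h"

c10::intrusive_ptr<trtorch::core::runtime::TRTEngine> rebuild(
    const std::string& engine_blob,
    trtorch::core::runtime::CudaDevice& dev) {
  std::vector<std::string> serialized_info(2);
  serialized_info[trtorch::core::runtime::DeviceIdx] = trtorch::core::runtime::serialize_device(dev);
  serialized_info[trtorch::core::runtime::EngineIdx] = engine_blob;
  return c10::make_intrusive<trtorch::core::runtime::TRTEngine>(std::move(serialized_info));
}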
+ void set_cuda_device(CudaDevice& cuda_device) {
+   TRTORCH_CHECK((cudaSetDevice(cuda_device.id) == cudaSuccess), "Unable to set device: " << cuda_device.id);
+ }
+
+ void get_cuda_device(CudaDevice& cuda_device) {
+   int device = 0;
+   TRTORCH_CHECK(
+       (cudaGetDevice(reinterpret_cast<int*>(&device)) == cudaSuccess),
+       "Unable to get current device: " << cuda_device.id);
+   cuda_device.id = static_cast<int64_t>(device);
+   cudaDeviceProp device_prop;
+   TRTORCH_CHECK(
+       (cudaGetDeviceProperties(&device_prop, cuda_device.id) == cudaSuccess),
+       "Unable to get CUDA properties from device:" << cuda_device.id);
+   cuda_device.set_major(device_prop.major);
+   cuda_device.set_minor(device_prop.minor);
+   std::string device_name(device_prop.name);
+   cuda_device.set_device_name(device_name);
+ }
+
+ std::string serialize_device(CudaDevice& cuda_device) {
+   void* buffer = new char[sizeof(cuda_device)];
+   void* ref_buf = buffer;
+
+   int64_t temp = cuda_device.get_id();
+   memcpy(buffer, reinterpret_cast<int64_t*>(&temp), sizeof(int64_t));
+   buffer = static_cast<char*>(buffer) + sizeof(int64_t);
+
+   temp = cuda_device.get_major();
+   memcpy(buffer, reinterpret_cast<int64_t*>(&temp), sizeof(int64_t));
+   buffer = static_cast<char*>(buffer) + sizeof(int64_t);
+
+   temp = cuda_device.get_minor();
+   memcpy(buffer, reinterpret_cast<int64_t*>(&temp), sizeof(int64_t));
+   buffer = static_cast<char*>(buffer) + sizeof(int64_t);
+
+   auto device_type = cuda_device.get_device_type();
+   memcpy(buffer, reinterpret_cast<char*>(&device_type), sizeof(nvinfer1::DeviceType));
+   buffer = static_cast<char*>(buffer) + sizeof(nvinfer1::DeviceType);
+
+   size_t device_name_len = cuda_device.get_device_name_len();
+   memcpy(buffer, reinterpret_cast<char*>(&device_name_len), sizeof(size_t));
+   buffer = static_cast<char*>(buffer) + sizeof(size_t);
+
+   auto device_name = cuda_device.get_device_name();
+   memcpy(buffer, reinterpret_cast<char*>(&device_name), device_name.size());
+   buffer = static_cast<char*>(buffer) + device_name.size();
+
+   return std::string((const char*)ref_buf, sizeof(int64_t) * 3 + sizeof(nvinfer1::DeviceType) + device_name.size());
+ }
+
+ CudaDevice deserialize_device(std::string device_info) {
+   CudaDevice ret;
+   char* buffer = new char[device_info.size() + 1];
+   std::copy(device_info.begin(), device_info.end(), buffer);
+   int64_t temp = 0;
+
+   memcpy(&temp, reinterpret_cast<char*>(buffer), sizeof(int64_t));
+   buffer += sizeof(int64_t);
+   ret.set_id(temp);
+
+   memcpy(&temp, reinterpret_cast<char*>(buffer), sizeof(int64_t));
+   buffer += sizeof(int64_t);
+   ret.set_major(temp);
+
+   memcpy(&temp, reinterpret_cast<char*>(buffer), sizeof(int64_t));
+   buffer += sizeof(int64_t);
+   ret.set_minor(temp);
+
+   nvinfer1::DeviceType device_type;
+   memcpy(&device_type, reinterpret_cast<char*>(buffer), sizeof(nvinfer1::DeviceType));
+   buffer += sizeof(nvinfer1::DeviceType);
+
+   size_t size;
+   memcpy(&size, reinterpret_cast<size_t*>(&buffer), sizeof(size_t));
+   buffer += sizeof(size_t);
+
+   ret.set_device_name_len(size);
+
+   std::string device_name;
+   memcpy(&device_name, reinterpret_cast<char*>(buffer), size * sizeof(char));
+   buffer += size * sizeof(char);
+
+   ret.set_device_name(device_name);
+
+   return ret;
+ }
+
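The pair of functions above write and read a fixed byte layout: three int64_t fields (id, compute-capability major, minor), the nvinfer1::DeviceType enum, a size_t name length, then the raw name characters. Below is a minimal, self-contained sketch of that pack/unpack pattern in plain standard C++ (no CUDA or TensorRT types; struct and field names are illustrative only). Like the code above, it assumes writer and reader share integer sizes and endianness.

#include <cstddef>
#include <cstdint>
#include <cstring>
#include <iostream>
#include <string>

// Illustrative stand-in for the device record: id, major, minor, type tag, name.
struct DeviceRecord {
  int64_t id, major, minor;
  int32_t type;  // stands in for the enum field
  std::string name;
};

std::string pack(const DeviceRecord& d) {
  std::size_t name_len = d.name.size();
  std::string out;
  out.append(reinterpret_cast<const char*>(&d.id), sizeof(int64_t));
  out.append(reinterpret_cast<const char*>(&d.major), sizeof(int64_t));
  out.append(reinterpret_cast<const char*>(&d.minor), sizeof(int64_t));
  out.append(reinterpret_cast<const char*>(&d.type), sizeof(int32_t));
  out.append(reinterpret_cast<const char*>(&name_len), sizeof(std::size_t));
  out.append(d.name.data(), name_len);  // copy the characters, not the std::string object
  return out;
}

DeviceRecord unpack(const std::string& buf) {
  DeviceRecord d;
  const char* p = buf.data();
  std::memcpy(&d.id, p, sizeof(int64_t));    p += sizeof(int64_t);
  std::memcpy(&d.major, p, sizeof(int64_t)); p += sizeof(int64_t);
  std::memcpy(&d.minor, p, sizeof(int64_t)); p += sizeof(int64_t);
  std::memcpy(&d.type, p, sizeof(int32_t));  p += sizeof(int32_t);
  std::size_t name_len = 0;
  std::memcpy(&name_len, p, sizeof(std::size_t)); p += sizeof(std::size_t);
  d.name.assign(p, name_len);
  return d;
}

int main() {
  DeviceRecord in{0, 7, 5, 0, "Tesla T4"};
  DeviceRecord out = unpack(pack(in));
  std::cout << out.name << " sm_" << out.major << out.minor << "\n";  // round-trips losslessly
}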
+ CudaDevice get_device_info(int64_t gpu_id, nvinfer1::DeviceType device_type) {
+   CudaDevice cuda_device;
+   cudaDeviceProp device_prop;
+
+   // Device ID
+   cuda_device.set_id(gpu_id);
+
+   // Get Device Properties
+   cudaGetDeviceProperties(&device_prop, gpu_id);
+
+   // Compute capability major version
+   cuda_device.set_major(device_prop.major);
+
+   // Compute capability minor version
+   cuda_device.set_minor(device_prop.minor);
+
+   std::string device_name(device_prop.name);
+
+   // Set Device name
+   cuda_device.set_device_name(device_name);
+
+   // Set Device name len for serialization/deserialization
+   cuda_device.set_device_name_len(device_name.size());
+
+   // Set Device Type
+   cuda_device.set_device_type(device_type);
+
+   return cuda_device;
+ }
+
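get_device_info() and the helpers above are thin wrappers over the CUDA runtime property queries. A standalone sketch of those calls, using plain error handling instead of the TRTORCH_CHECK macro used in this file:

#include <cstdio>
#include <cuda_runtime.h>

int main() {
  int device = 0;
  if (cudaGetDevice(&device) != cudaSuccess) {
    std::fprintf(stderr, "unable to query current device\n");
    return 1;
  }

  cudaDeviceProp prop;
  if (cudaGetDeviceProperties(&prop, device) != cudaSuccess) {
    std::fprintf(stderr, "unable to read properties for device %d\n", device);
    return 1;
  }

  // Same fields the runtime records: device name and compute capability major/minor.
  std::printf("device %d: %s (sm_%d%d)\n", device, prop.name, prop.major, prop.minor);

  // Make this device current for subsequent engine execution.
  return cudaSetDevice(device) == cudaSuccess ? 0 : 1;
}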
} // namespace runtime
} // namespace core
} // namespace trtorch