1
1
#include < algorithm>
2
2
3
+ #include < cuda_runtime.h>
3
4
#include " NvInfer.h"
4
5
#include " torch/csrc/jit/frontend/function_schema_parser.h"
5
6
@@ -23,7 +24,7 @@ TRTEngine::TRTEngine(std::string serialized_engine)
23
24
util::logging::get_logger().get_reportable_severity(),
24
25
util::logging::get_logger().get_is_colored_output_on()) {
25
26
std::string _name = " deserialized_trt" ;
26
- new (this ) TRTEngine (_name, serialized_engine);
27
+ new (this ) TRTEngine (_name, serialized_engine, empty_string );
27
28
}
28
29
29
30
TRTEngine::TRTEngine (std::vector<std::string> serialized_info)
@@ -37,7 +38,8 @@ TRTEngine::TRTEngine(std::vector<std::string> serialized_info)
37
38
new (this ) TRTEngine (_name, engine_info, device_info);
38
39
}
39
40
40
- TRTEngine::TRTEngine (std::string mod_name, std::string serialized_engine)
41
+ TRTEngine::TRTEngine (std::string mod_name, std::string serialized_engine,
42
+ std::string serialized_device_info = empty_string)
41
43
: logger(
42
44
std::string (" [" ) + mod_name + std::string(" _engine] - " ),
43
45
util::logging::get_logger().get_reportable_severity(),
@@ -105,7 +107,6 @@ TRTEngine::~TRTEngine() {
105
107
// return c10::List<at::Tensor>(output_vec);
106
108
// }
107
109
108
- namespace {
109
110
static auto TRTORCH_UNUSED TRTEngineTSRegistrtion =
110
111
torch::class_<TRTEngine>(" tensorrt" , " Engine" )
111
112
.def(torch::init<std::string>())
@@ -120,13 +121,14 @@ static auto TRTORCH_UNUSED TRTEngineTSRegistrtion =
120
121
auto trt_engine = std::string ((const char *)serialized_trt_engine->data (), serialized_trt_engine->size ());
121
122
122
123
CudaDevice cuda_device;
124
+ get_cuda_device (cuda_device);
123
125
std::vector<std::string> serialize_info;
124
126
serialize_info.push_back (serialize_device (cuda_device));
125
127
serialize_info.push_back (trt_engine);
126
128
return serialize_info;
127
129
},
128
- [](std::string seralized_engine ) -> c10::intrusive_ptr<TRTEngine> {
129
- return c10::make_intrusive<TRTEngine>(std::move (seralized_engine ));
130
+ [](std::vector<std:: string> seralized_info ) -> c10::intrusive_ptr<TRTEngine> {
131
+ return c10::make_intrusive<TRTEngine>(std::move (seralized_info ));
130
132
});
131
133
132
134
int CudaDevice::get_id (void ) {
@@ -154,13 +156,13 @@ void CudaDevice::set_minor(int minor) {
154
156
}
155
157
156
158
void set_cuda_device (CudaDevice& cuda_device) {
157
- TRTORCH_CHECK ((cudaSetDevice (cuda_device.id ) ! = cudaSuccess), " Unable to set device: " << cuda_device.id );
159
+ TRTORCH_CHECK ((cudaSetDevice (cuda_device.id ) = = cudaSuccess), " Unable to set device: " << cuda_device.id );
158
160
}
159
161
160
162
void get_cuda_device (CudaDevice& cuda_device) {
161
- TRTORCH_CHECK ((cudaGetDevice (&cuda_device.id ) ! = cudaSuccess), " Unable to get current device: " << cuda_device.id );
163
+ TRTORCH_CHECK ((cudaGetDevice (&cuda_device.id ) = = cudaSuccess), " Unable to get current device: " << cuda_device.id );
162
164
cudaDeviceProp device_prop;
163
- TRTORCH_CHECK ((cudaGetDeviceProperties (&device_prop, cuda_device.id ) ! = cudaSuccess), " Unable to get CUDA properties from device:" << cuda_device.id );
165
+ TRTORCH_CHECK ((cudaGetDeviceProperties (&device_prop, cuda_device.id ) = = cudaSuccess), " Unable to get CUDA properties from device:" << cuda_device.id );
164
166
cuda_device.set_major (device_prop.major );
165
167
cuda_device.set_minor (device_prop.minor );
166
168
}
@@ -205,8 +207,6 @@ CudaDevice deserialize_device(std::string device_info) {
205
207
return ret;
206
208
}
207
209
208
-
209
- } // namespace
210
210
} // namespace runtime
211
211
} // namespace core
212
212
} // namespace trtorch
0 commit comments