@@ -42,11 +42,10 @@ TRTEngine::TRTEngine(std::string mod_name, std::string serialized_engine, std::s
42
42
util::logging::get_logger().get_reportable_severity(),
43
43
util::logging::get_logger().get_is_colored_output_on()) {
44
44
45
- CudaDevice cuda_device;
46
45
// Deserialize device meta data if device_info is non-empty
47
46
if (!serialized_device_info.empty ())
48
47
{
49
- cuda_device = deserialize_device (serialized_device_info);
48
+ auto cuda_device = deserialize_device (serialized_device_info);
50
49
// Set CUDA device as configured in serialized meta data
51
50
set_cuda_device (cuda_device);
52
51
}
@@ -119,6 +118,7 @@ static auto TRTORCH_UNUSED TRTEngineTSRegistrtion = torch::class_<TRTEngine>("te
119
118
auto trt_engine = std::string ((const char *)serialized_trt_engine->data (), serialized_trt_engine->size ());
120
119
121
120
CudaDevice cuda_device;
121
+ get_cuda_device (cuda_device);
122
122
std::vector<std::string> serialize_info;
123
123
serialize_info.push_back (serialize_device (cuda_device));
124
124
serialize_info.push_back (trt_engine);
@@ -155,13 +155,13 @@ void CudaDevice::set_minor(int minor) {
155
155
}
156
156
157
157
void set_cuda_device (CudaDevice& cuda_device) {
158
- TRTORCH_CHECK ((cudaSetDevice (cuda_device.id ) ! = cudaSuccess), " Unable to set device: " << cuda_device.id );
158
+ TRTORCH_CHECK ((cudaSetDevice (cuda_device.id ) = = cudaSuccess), " Unable to set device: " << cuda_device.id );
159
159
}
160
160
161
161
void get_cuda_device (CudaDevice& cuda_device) {
162
- TRTORCH_CHECK ((cudaGetDevice (&cuda_device.id ) ! = cudaSuccess), " Unable to get current device: " << cuda_device.id );
162
+ TRTORCH_CHECK ((cudaGetDevice (&cuda_device.id ) = = cudaSuccess), " Unable to get current device: " << cuda_device.id );
163
163
cudaDeviceProp device_prop;
164
- TRTORCH_CHECK ((cudaGetDeviceProperties (&device_prop, cuda_device.id ) ! = cudaSuccess), " Unable to get CUDA properties from device:" << cuda_device.id );
164
+ TRTORCH_CHECK ((cudaGetDeviceProperties (&device_prop, cuda_device.id ) = = cudaSuccess), " Unable to get CUDA properties from device:" << cuda_device.id );
165
165
cuda_device.set_major (device_prop.major );
166
166
cuda_device.set_minor (device_prop.minor );
167
167
}
0 commit comments