Skip to content

Commit 968cb82

Browse files
author
Anurag Dixit
committed
Device metadata serialization and deserialization
Signed-off-by: Anurag Dixit <[email protected]>
1 parent e3dd820 commit 968cb82

17 files changed

+479
-40
lines changed

core/compiler.cpp

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
#include <cuda_runtime.h>
12
#include <iostream>
23
#include <memory>
34
#include <sstream>
@@ -46,8 +47,9 @@ c10::FunctionSchema GenerateGraphSchema(
4647
void AddEngineToGraph(
4748
torch::jit::script::Module mod,
4849
std::shared_ptr<torch::jit::Graph>& g,
49-
std::string& serialized_engine) {
50-
auto engine_ptr = c10::make_intrusive<runtime::TRTEngine>(mod._ivalue()->name(), serialized_engine);
50+
std::string& engine,
51+
runtime::CudaDevice& device_info) {
52+
auto engine_ptr = c10::make_intrusive<runtime::TRTEngine>(mod._ivalue()->name(), engine, device_info);
5153
// Get required metadata about the engine out
5254
auto num_io = engine_ptr->num_io;
5355
auto name = engine_ptr->name;
@@ -157,12 +159,16 @@ torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod, C
157159
// torch::jit::script::Module new_mod = mod.clone();
158160
torch::jit::script::Module new_mod(mod._ivalue()->name() + "_trt");
159161
std::vector<std::shared_ptr<torch::jit::Graph>> graphs;
162+
160163
for (const torch::jit::script::Method& method : mod.get_methods()) {
161164
// Don't convert hidden methods
162165
if (method.name().rfind("_", 0)) {
163166
auto engine = ConvertGraphToTRTEngine(mod, method.name(), cfg);
164167
auto new_g = std::make_shared<torch::jit::Graph>();
165-
AddEngineToGraph(new_mod, new_g, engine);
168+
169+
auto device_spec = cfg.convert_info.engine_settings.device;
170+
auto cuda_device = runtime::get_device_info(device_spec.gpu_id, device_spec.device_type);
171+
AddEngineToGraph(new_mod, new_g, engine, cuda_device);
166172
auto new_method = new_mod._ivalue()->compilation_unit()->create_function(method.name(), new_g);
167173
auto schema = GenerateGraphSchema(new_mod, new_method->name(), new_g);
168174
new_mod.type()->addMethod(new_method);
@@ -174,7 +180,7 @@ torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod, C
174180
}
175181

176182
// Makes gpu_id the active CUDA device for the calling thread.
// Aborts via TRTORCH_CHECK if the runtime rejects the device id.
void set_device(const int gpu_id) {
  auto status = cudaSetDevice(gpu_id);
  TRTORCH_CHECK((status == cudaSuccess), "Unable to set CUDA device: " << gpu_id);
}
179185

180186
} // namespace core

core/runtime/TRTEngine.cpp

Lines changed: 205 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include <algorithm>
22

3+
#include <cuda_runtime.h>
34
#include "NvInfer.h"
45
#include "torch/csrc/jit/frontend/function_schema_parser.h"
56

@@ -15,20 +16,43 @@ std::string slugify(std::string s) {
1516
return s;
1617
}
1718

19+
CudaDevice default_device = {0, 0, 0, nvinfer1::DeviceType::kGPU, 0, ""};
20+
1821
// Deserializing constructor for a raw engine blob that carries no device
// metadata: targets GPU 0 as a default.
// TODO: Need to add the option to configure target device at execution time
//
// Fix: the original initialized `logger`, then re-ran the three-argument
// constructor over *this via placement new, constructing `logger` a second
// time without destroying the first. A C++11 delegating constructor reaches
// the same final state without the double construction.
TRTEngine::TRTEngine(std::string serialized_engine)
    : TRTEngine(
          "deserialized_trt",
          std::move(serialized_engine),
          get_device_info(0, nvinfer1::DeviceType::kGPU)) {}
31+
32+
// Deserializing constructor used by torch::init / def_pickle: unpacks the
// {device metadata, engine blob} string vector (indexed by DeviceIdx /
// EngineIdx) and forwards to the full constructor.
//
// Fix: replaces the placement-`new (this)` self-delegation (which constructed
// `logger` twice and discarded the first instance) with a proper C++11
// delegating constructor; final object state is identical.
TRTEngine::TRTEngine(std::vector<std::string> serialized_info)
    : TRTEngine(
          "deserialized_trt",
          serialized_info[EngineIdx],
          deserialize_device(serialized_info[DeviceIdx])) {}
2644

27-
TRTEngine::TRTEngine(std::string mod_name, std::string serialized_engine)
45+
TRTEngine::TRTEngine(
46+
std::string mod_name,
47+
std::string serialized_engine,
48+
CudaDevice cuda_device)
2849
: logger(
2950
std::string("[") + mod_name + std::string("_engine] - "),
3051
util::logging::get_logger().get_reportable_severity(),
3152
util::logging::get_logger().get_is_colored_output_on()) {
53+
54+
set_cuda_device(cuda_device);
55+
3256
rt = nvinfer1::createInferRuntime(logger);
3357

3458
name = slugify(mod_name) + "_engine";
@@ -63,6 +87,7 @@ TRTEngine& TRTEngine::operator=(const TRTEngine& other) {
6387
id = other.id;
6488
rt = other.rt;
6589
cuda_engine = other.cuda_engine;
90+
device_info = other.device_info;
6691
exec_ctx = other.exec_ctx;
6792
num_io = other.num_io;
6893
return (*this);
@@ -82,21 +107,191 @@ TRTEngine::~TRTEngine() {
82107
// return c10::List<at::Tensor>(output_vec);
83108
// }
84109

85-
namespace {
86110
static auto TRTORCH_UNUSED TRTEngineTSRegistrtion =
87111
torch::class_<TRTEngine>("tensorrt", "Engine")
88-
.def(torch::init<std::string>())
112+
.def(torch::init<std::vector<std::string>>())
89113
// TODO: .def("__call__", &TRTEngine::Run)
90114
// TODO: .def("run", &TRTEngine::Run)
91115
.def_pickle(
92-
[](const c10::intrusive_ptr<TRTEngine>& self) -> std::string {
93-
auto serialized_engine = self->cuda_engine->serialize();
94-
return std::string((const char*)serialized_engine->data(), serialized_engine->size());
116+
[](const c10::intrusive_ptr<TRTEngine>& self) -> std::vector<std::string> {
117+
// Serialize TensorRT engine
118+
auto serialized_trt_engine = self->cuda_engine->serialize();
119+
120+
// Adding device info related meta data to the serialized file
121+
auto trt_engine = std::string((const char*)serialized_trt_engine->data(), serialized_trt_engine->size());
122+
123+
std::vector<std::string> serialize_info;
124+
serialize_info.push_back(serialize_device(self->device_info));
125+
serialize_info.push_back(trt_engine);
126+
return serialize_info;
95127
},
96-
[](std::string seralized_engine) -> c10::intrusive_ptr<TRTEngine> {
97-
return c10::make_intrusive<TRTEngine>(std::move(seralized_engine));
128+
[](std::vector<std::string> serial_info) -> c10::intrusive_ptr<TRTEngine> {
129+
return c10::make_intrusive<TRTEngine>(std::move(serial_info));
98130
});
99-
} // namespace
131+
132+
/*
133+
int64_t CudaDevice::get_id(void) {
134+
return this->id;
135+
}
136+
137+
void CudaDevice::set_id(int64_t id) {
138+
this->id = id;
139+
}
140+
141+
int64_t CudaDevice::get_major(void) {
142+
return this->major;
143+
}
144+
145+
void CudaDevice::set_major(int64_t major) {
146+
this->major = major;
147+
}
148+
149+
int64_t CudaDevice::get_minor(void) {
150+
return this->minor;
151+
}
152+
153+
void CudaDevice::set_minor(int64_t minor) {
154+
this->minor = minor;
155+
}
156+
157+
nvinfer1::DeviceType get_device_type(void) {
158+
return this->device_type;
159+
}
160+
161+
void set_device_type(nvinfer1::DeviceType device_type) {
162+
this->device_type = device_type;
163+
}
164+
165+
std::string get_device_name(void) {
166+
return this->device_name;
167+
}
168+
169+
void set_device_name(std::string& name) {
170+
this->device_name = name;
171+
}
172+
173+
size_t get_device_name_len(void) {
174+
return this->device_name_len;
175+
}
176+
177+
void set_device_name_len(size_t size) {
178+
this->device_name_len = size;
179+
}
180+
*/
181+
182+
// Makes cuda_device.id the active CUDA device for the calling thread.
void set_cuda_device(CudaDevice& cuda_device) {
  // cudaSetDevice takes an int; cast explicitly rather than relying on an
  // implicit narrowing conversion from the int64_t id field.
  TRTORCH_CHECK(
      (cudaSetDevice(static_cast<int>(cuda_device.id)) == cudaSuccess),
      "Unable to set device: " << cuda_device.id);
}
185+
186+
// Fills cuda_device with the currently active CUDA device's id, compute
// capability (major/minor) and name.
void get_cuda_device(CudaDevice& cuda_device) {
  // Fix: cudaGetDevice writes an int. The original passed
  // reinterpret_cast<int*>(&cuda_device.id) where id is an int64_t, which
  // only updates the low 4 bytes of the field (and is endianness-dependent).
  int device_id = 0;
  TRTORCH_CHECK(
      (cudaGetDevice(&device_id) == cudaSuccess),
      "Unable to get current device: " << cuda_device.id);
  cuda_device.set_id(device_id);

  cudaDeviceProp device_prop;
  TRTORCH_CHECK(
      (cudaGetDeviceProperties(&device_prop, device_id) == cudaSuccess),
      "Unable to get CUDA properties from device:" << device_id);
  cuda_device.set_major(device_prop.major);
  cuda_device.set_minor(device_prop.minor);

  std::string device_name(device_prop.name);
  cuda_device.set_device_name(device_name);
  // Keep the cached name length in sync: serialize_device reads it, and the
  // original left it stale here (get_device_info does set it).
  cuda_device.set_device_name_len(device_name.size());
}
197+
198+
// Serializes a CudaDevice into a flat byte string for storage alongside the
// TensorRT engine blob.
// Layout: id (int64) | major (int64) | minor (int64) |
//         device_type (nvinfer1::DeviceType) | name length (size_t) | name bytes
// NOTE(review): fields are raw host-representation bytes, so the payload is
// only portable between machines with matching endianness / type sizes.
//
// Fixes vs. original: the `new char[sizeof(cuda_device)]` scratch buffer was
// never freed (leak) and could be smaller than the payload; the device name
// was copied with memcpy(&device_name, ...) — the bytes of the std::string
// *object*, not its characters; and the stored length came from the cached
// device_name_len, which could disagree with the bytes actually written.
std::string serialize_device(CudaDevice& cuda_device) {
  std::string payload;
  auto append_bytes = [&payload](const void* src, size_t len) {
    payload.append(static_cast<const char*>(src), len);
  };

  int64_t id = cuda_device.get_id();
  append_bytes(&id, sizeof(int64_t));

  int64_t major = cuda_device.get_major();
  append_bytes(&major, sizeof(int64_t));

  int64_t minor = cuda_device.get_minor();
  append_bytes(&minor, sizeof(int64_t));

  auto device_type = cuda_device.get_device_type();
  append_bytes(&device_type, sizeof(nvinfer1::DeviceType));

  // Use the actual string size so the length field always matches the name
  // bytes that follow.
  const std::string device_name = cuda_device.get_device_name();
  size_t device_name_len = device_name.size();
  append_bytes(&device_name_len, sizeof(size_t));
  append_bytes(device_name.data(), device_name_len);

  return payload;
}
228+
229+
// Reconstructs a CudaDevice from the byte string produced by serialize_device.
// Field order and sizes must stay in sync with serialize_device.
//
// Fixes vs. original: the copied `new char[]` buffer was never freed (leak —
// read directly from the string instead); the name length was read with
// memcpy(&size, reinterpret_cast<size_t*>(&buffer), ...), i.e. from the
// address of the local *pointer* rather than the buffer contents; the device
// name was memcpy'd into a std::string object (UB — clobbers its internal
// representation); and the deserialized device_type was read but never stored.
CudaDevice deserialize_device(std::string device_info) {
  CudaDevice ret;
  const char* buffer = device_info.data();

  int64_t id = 0;
  memcpy(&id, buffer, sizeof(int64_t));
  buffer += sizeof(int64_t);
  ret.set_id(id);

  int64_t major = 0;
  memcpy(&major, buffer, sizeof(int64_t));
  buffer += sizeof(int64_t);
  ret.set_major(major);

  int64_t minor = 0;
  memcpy(&minor, buffer, sizeof(int64_t));
  buffer += sizeof(int64_t);
  ret.set_minor(minor);

  nvinfer1::DeviceType device_type;
  memcpy(&device_type, buffer, sizeof(nvinfer1::DeviceType));
  buffer += sizeof(nvinfer1::DeviceType);
  ret.set_device_type(device_type);

  size_t name_len = 0;
  memcpy(&name_len, buffer, sizeof(size_t));
  buffer += sizeof(size_t);
  ret.set_device_name_len(name_len);

  // Construct the string from the raw name bytes.
  std::string device_name(buffer, name_len);
  ret.set_device_name(device_name);

  return ret;
}
265+
266+
// Builds a CudaDevice record for gpu_id: queries the device's compute
// capability and name from the CUDA runtime and tags it with device_type.
CudaDevice get_device_info(int64_t gpu_id, nvinfer1::DeviceType device_type) {
  CudaDevice device;
  cudaDeviceProp device_prop;

  // Device ID
  device.set_id(gpu_id);

  // Get device properties; fail loudly instead of silently reading an
  // uninitialized cudaDeviceProp (the original ignored the return code).
  TRTORCH_CHECK(
      (cudaGetDeviceProperties(&device_prop, static_cast<int>(gpu_id)) == cudaSuccess),
      "Unable to get CUDA properties from device:" << gpu_id);

  // Compute capability major/minor versions
  device.set_major(device_prop.major);
  device.set_minor(device_prop.minor);

  // Device name, plus the cached length used by serialize_device
  std::string device_name(device_prop.name);
  device.set_device_name(device_name);
  device.set_device_name_len(device_name.size());

  // Device Type
  device.set_device_type(device_type);
  return device;
}
294+
100295
} // namespace runtime
101296
} // namespace core
102297
} // namespace trtorch

core/runtime/register_trt_op.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ namespace runtime {
1212

1313
std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intrusive_ptr<TRTEngine> compiled_engine) {
1414
LOG_DEBUG("Attempting to run engine (ID: " << compiled_engine->name << ")");
15+
LOG_DEBUG("Check device_info : " << compiled_engine->device_info.device_name);
1516
std::vector<void*> gpu_handles;
1617

1718
std::vector<at::Tensor> contig_inputs{};

0 commit comments

Comments
 (0)