Skip to content

Commit 574d77e

Browse files
author
Anurag Dixit
committed
Device metadata serialization deserialization
Signed-off-by: Anurag Dixit <[email protected]>
1 parent e3dd820 commit 574d77e

16 files changed

+423
-23
lines changed

core/compiler.cpp

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
#include <cuda_runtime.h>
12
#include <iostream>
23
#include <memory>
34
#include <sstream>
@@ -46,8 +47,9 @@ c10::FunctionSchema GenerateGraphSchema(
4647
void AddEngineToGraph(
4748
torch::jit::script::Module mod,
4849
std::shared_ptr<torch::jit::Graph>& g,
49-
std::string& serialized_engine) {
50-
auto engine_ptr = c10::make_intrusive<runtime::TRTEngine>(mod._ivalue()->name(), serialized_engine);
50+
std::string& engine,
51+
CudaDevice& device_info) {
52+
auto engine_ptr = c10::make_intrusive<runtime::TRTEngine>(mod._ivalue()->name(), engine, device_info);
5153
// Get required metadata about the engine out
5254
auto num_io = engine_ptr->num_io;
5355
auto name = engine_ptr->name;
@@ -157,12 +159,15 @@ torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod, C
157159
// torch::jit::script::Module new_mod = mod.clone();
158160
torch::jit::script::Module new_mod(mod._ivalue()->name() + "_trt");
159161
std::vector<std::shared_ptr<torch::jit::Graph>> graphs;
162+
160163
for (const torch::jit::script::Method& method : mod.get_methods()) {
161164
// Don't convert hidden methods
162165
if (method.name().rfind("_", 0)) {
163166
auto engine = ConvertGraphToTRTEngine(mod, method.name(), cfg);
164167
auto new_g = std::make_shared<torch::jit::Graph>();
165-
AddEngineToGraph(new_mod, new_g, engine);
168+
169+
auto cuda_device = runtime::spec_to_device(cfg->convert_info.engine_settings.device);
170+
AddEngineToGraph(new_mod, new_g, engine, cuda_device);
166171
auto new_method = new_mod._ivalue()->compilation_unit()->create_function(method.name(), new_g);
167172
auto schema = GenerateGraphSchema(new_mod, new_method->name(), new_g);
168173
new_mod.type()->addMethod(new_method);
@@ -174,7 +179,7 @@ torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod, C
174179
}
175180

176181
void set_device(const int gpu_id) {
177-
TRTORCH_ASSERT(cudaSetDevice(gpu_id) == cudaSuccess, "Unable to set CUDA device: " << gpu_id);
182+
TRTORCH_CHECK((cudaSetDevice(gpu_id) == cudaSuccess), "Unable to set CUDA device: " << gpu_id);
178183
}
179184

180185
} // namespace core

core/runtime/TRTEngine.cpp

Lines changed: 198 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include <algorithm>
22

3+
#include <cuda_runtime.h>
34
#include "NvInfer.h"
45
#include "torch/csrc/jit/frontend/function_schema_parser.h"
56

@@ -15,20 +16,39 @@ std::string slugify(std::string s) {
1516
return s;
1617
}
1718

18-
TRTEngine::TRTEngine(std::string serialized_engine)
19+
TRTEngine::TRTEngine(std::string serialized_engine, CudaDevice device)
1920
: logger(
2021
std::string("[] - "),
2122
util::logging::get_logger().get_reportable_severity(),
2223
util::logging::get_logger().get_is_colored_output_on()) {
2324
std::string _name = "deserialized_trt";
24-
new (this) TRTEngine(_name, serialized_engine);
25+
new (this) TRTEngine(_name, serialized_engine, device);
2526
}
2627

27-
TRTEngine::TRTEngine(std::string mod_name, std::string serialized_engine)
28+
// Rebuilds an engine from its pickled form: a vector of strings holding the
// serialized device metadata at index DeviceIdx and the serialized TensorRT
// engine at index EngineIdx.
//
// This delegates to the (name, engine, device) constructor rather than
// placement-new'ing over `this`. Re-running a constructor on an object whose
// base class and `logger` member are already constructed overwrites them
// without running their destructors first.
//
// `at()` is used so that a vector with fewer entries than expected raises
// std::out_of_range instead of reading past the end.
TRTEngine::TRTEngine(std::vector<std::string> serialized_info)
    : TRTEngine(
          std::string("deserialized_trt"),
          serialized_info.at(EngineIdx),
          deserialize_device(serialized_info.at(DeviceIdx))) {}
40+
41+
TRTEngine::TRTEngine(
42+
std::string mod_name,
43+
std::string serialized_engine,
44+
CudaDevice cuda_device)
2845
: logger(
2946
std::string("[") + mod_name + std::string("_engine] - "),
3047
util::logging::get_logger().get_reportable_severity(),
3148
util::logging::get_logger().get_is_colored_output_on()) {
49+
50+
set_cuda_device(cuda_device);
51+
3252
rt = nvinfer1::createInferRuntime(logger);
3353

3454
name = slugify(mod_name) + "_engine";
@@ -63,6 +83,7 @@ TRTEngine& TRTEngine::operator=(const TRTEngine& other) {
6383
id = other.id;
6484
rt = other.rt;
6585
cuda_engine = other.cuda_engine;
86+
device_info = other.device_info;
6687
exec_ctx = other.exec_ctx;
6788
num_io = other.num_io;
6889
return (*this);
@@ -82,21 +103,188 @@ TRTEngine::~TRTEngine() {
82103
// return c10::List<at::Tensor>(output_vec);
83104
// }
84105

85-
namespace {
86106
static auto TRTORCH_UNUSED TRTEngineTSRegistrtion =
87107
torch::class_<TRTEngine>("tensorrt", "Engine")
88108
.def(torch::init<std::string>())
89109
// TODO: .def("__call__", &TRTEngine::Run)
90110
// TODO: .def("run", &TRTEngine::Run)
91111
.def_pickle(
92-
[](const c10::intrusive_ptr<TRTEngine>& self) -> std::string {
93-
auto serialized_engine = self->cuda_engine->serialize();
94-
return std::string((const char*)serialized_engine->data(), serialized_engine->size());
112+
[](const c10::intrusive_ptr<TRTEngine>& self) -> std::vector<std::string> {
113+
// Serialize TensorRT engine
114+
auto serialized_trt_engine = self->cuda_engine->serialize();
115+
116+
// Adding device info related meta data to the serialized file
117+
auto trt_engine = std::string((const char*)serialized_trt_engine->data(), serialized_trt_engine->size());
118+
119+
std::vector<std::string> serialize_info;
120+
serialize_info.push_back(serialize_device(self.cuda_device));
121+
serialize_info.push_back(trt_engine);
122+
return serialize_info;
95123
},
96-
[](std::string seralized_engine) -> c10::intrusive_ptr<TRTEngine> {
97-
return c10::make_intrusive<TRTEngine>(std::move(seralized_engine));
124+
[](std::vector<std::string> seralized_info) -> c10::intrusive_ptr<TRTEngine> {
125+
return c10::make_intrusive<TRTEngine>(std::move(seralized_info));
98126
});
99-
} // namespace
127+
128+
int64_t CudaDevice::get_id(void) {
129+
return this->id;
130+
}
131+
132+
// Stores the CUDA device id.
void CudaDevice::set_id(int64_t id) {
  // The parameter shadows the member, so the member is named through *this.
  (*this).id = id;
}
135+
136+
int64_t CudaDevice::get_major(void) {
137+
return this->major;
138+
}
139+
140+
// Stores the CUDA compute capability major version.
void CudaDevice::set_major(int64_t major) {
  // The parameter shadows the member, so the member is named through *this.
  (*this).major = major;
}
143+
144+
int64_t CudaDevice::get_minor(void) {
145+
return this->minor;
146+
}
147+
148+
// Stores the CUDA compute capability minor version.
void CudaDevice::set_minor(int64_t minor) {
  // The parameter shadows the member, so the member is named through *this.
  (*this).minor = minor;
}
151+
152+
// Returns the stored TensorRT device type.
//
// Defined as a CudaDevice member: the unqualified form declared a free
// function, in which `this` does not exist.
nvinfer1::DeviceType CudaDevice::get_device_type(void) {
  return this->device_type;
}
155+
156+
void set_device_type(nvinfer1::DeviceType device_type) {
157+
this->device_type = device_type;
158+
}
159+
160+
std::string get_device_name(void) {
161+
return this->device_name;
162+
}
163+
164+
void set_device_name(std::string& name) {
165+
this->device_name = name;
166+
}
167+
168+
size_t get_device_name_len(void) {
169+
return this->device_name_len;
170+
}
171+
172+
void set_device_name_len(size_t size) {
173+
this->device_name_len = size;
174+
}
175+
176+
// Makes the device identified by `cuda_device.id` the current CUDA device.
// The check macro reports a failure when cudaSetDevice does not return
// cudaSuccess.
void set_cuda_device(CudaDevice& cuda_device) {
  // The condition is kept inline in the macro call, in case the macro uses
  // its source text when reporting.
  TRTORCH_CHECK(
      (cudaSetDevice(cuda_device.id) == cudaSuccess),
      "Unable to set device: " << cuda_device.id);
}
179+
180+
// Fills `cuda_device` with the id, compute capability and name of the
// currently selected CUDA device. The device type is left untouched, since
// the CUDA runtime calls used here do not provide it.
//
// A failure of either CUDA runtime call is reported through TRTORCH_CHECK.
void get_cuda_device(CudaDevice& cuda_device) {
  // cudaGetDevice writes through an int*, so query into a local int instead
  // of aliasing the int64_t member.
  int device_id = 0;
  TRTORCH_CHECK((cudaGetDevice(&device_id) == cudaSuccess), "Unable to get current device");
  cuda_device.set_id(device_id);

  cudaDeviceProp device_prop;
  TRTORCH_CHECK(
      (cudaGetDeviceProperties(&device_prop, device_id) == cudaSuccess),
      "Unable to get CUDA properties from device:" << device_id);
  cuda_device.set_major(device_prop.major);
  cuda_device.set_minor(device_prop.minor);

  // set_device_name takes a non-const lvalue reference, so a named string is
  // required; a temporary cannot bind to it.
  std::string device_name(device_prop.name);
  cuda_device.set_device_name(device_name);
  cuda_device.set_device_name_len(device_name.size());
}
190+
191+
// Packs the device metadata into a byte string.
//
// Layout, in order, using the host's native representation of each field:
//   int64_t id | int64_t major | int64_t minor | nvinfer1::DeviceType |
//   size_t name length | name bytes (no terminator)
//
// The output is built in a std::string, so nothing is manually allocated and
// it grows to fit any name length. The length field is taken from the actual
// name so the two cannot disagree.
std::string serialize_device(CudaDevice& cuda_device) {
  const std::string device_name = cuda_device.get_device_name();
  const size_t device_name_len = device_name.size();

  std::string serialized;
  serialized.reserve(sizeof(int64_t) * 3 + sizeof(nvinfer1::DeviceType) + sizeof(size_t) + device_name_len);

  // Appends the raw bytes of one fixed-size field.
  auto append_raw = [&serialized](const void* src, size_t num_bytes) {
    serialized.append(static_cast<const char*>(src), num_bytes);
  };

  const int64_t device_id = cuda_device.get_id();
  append_raw(&device_id, sizeof(device_id));

  const int64_t major_version = cuda_device.get_major();
  append_raw(&major_version, sizeof(major_version));

  const int64_t minor_version = cuda_device.get_minor();
  append_raw(&minor_version, sizeof(minor_version));

  const nvinfer1::DeviceType device_type = cuda_device.get_device_type();
  append_raw(&device_type, sizeof(device_type));

  append_raw(&device_name_len, sizeof(device_name_len));

  // Copy the characters of the name, not the bytes of the std::string object.
  serialized.append(device_name);

  return serialized;
}
221+
222+
// Parses a byte string laid out as
//   int64_t id | int64_t major | int64_t minor | nvinfer1::DeviceType |
//   size_t name length | name bytes
// into a CudaDevice.
//
// Every read is bounds-checked against the remaining input, and a truncated
// or inconsistent buffer is reported through TRTORCH_CHECK. The input is
// read in place, so no temporary buffer is allocated.
CudaDevice deserialize_device(std::string device_info) {
  const char* cursor = device_info.data();
  size_t remaining = device_info.size();

  // Copies one fixed-size field out of the input and advances past it.
  auto read_raw = [&cursor, &remaining](void* dst, size_t num_bytes) {
    TRTORCH_CHECK((num_bytes <= remaining), "Serialized device info is truncated");
    memcpy(dst, cursor, num_bytes);
    cursor += num_bytes;
    remaining -= num_bytes;
  };

  CudaDevice ret;
  int64_t field = 0;

  read_raw(&field, sizeof(field));
  ret.set_id(field);

  read_raw(&field, sizeof(field));
  ret.set_major(field);

  read_raw(&field, sizeof(field));
  ret.set_minor(field);

  nvinfer1::DeviceType device_type;
  read_raw(&device_type, sizeof(device_type));
  ret.set_device_type(device_type);

  size_t name_len = 0;
  read_raw(&name_len, sizeof(name_len));
  TRTORCH_CHECK((name_len <= remaining), "Serialized device name length exceeds the available data");
  ret.set_device_name_len(name_len);

  // Build the name from the character range; never memcpy into a std::string
  // object itself.
  std::string device_name(cursor, name_len);
  ret.set_device_name(device_name);

  return ret;
}
258+
259+
// Builds a CudaDevice from a conversion device spec. The id and device type
// come from `spec`; the compute capability and name come from the CUDA
// runtime properties of `spec.gpu_id`.
//
// A failed property query is reported through TRTORCH_CHECK, because
// `device_prop` would otherwise be read uninitialized.
CudaDevice spec_to_device(conversion::Device& spec) {
  cudaDeviceProp device_prop;
  TRTORCH_CHECK(
      (cudaGetDeviceProperties(&device_prop, spec.gpu_id) == cudaSuccess),
      "Unable to get CUDA properties from device:" << spec.gpu_id);

  CudaDevice device;

  // Device ID
  device.set_id(spec.gpu_id);

  // Compute capability major / minor version
  device.set_major(device_prop.major);
  device.set_minor(device_prop.minor);

  // Device name and its length (the length is used by
  // serialization/deserialization)
  std::string device_name(device_prop.name);
  device.set_device_name(device_name);
  device.set_device_name_len(device_name.size());

  // Device type
  device.set_device_type(spec.device_type);

  return device;
}
287+
100288
} // namespace runtime
101289
} // namespace core
102290
} // namespace trtorch

core/runtime/runtime.h

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,46 @@ namespace runtime {
1111

1212
using EngineID = int64_t;
1313

14+
// Positions of the entries in the std::vector<std::string> that represents a
// serialized engine. It stays an unscoped enum so the enumerators convert
// implicitly to vector indices.
enum SerializedInfoIndex {
  DeviceIdx = 0, // entry passed to deserialize_device
  EngineIdx = 1, // entry holding the serialized engine
};
18+
19+
struct CudaDevice {
20+
int64_t id; // CUDA device id
21+
int64_t major; // CUDA compute major version
22+
int64_t minor; // CUDA compute minor version
23+
nvinfer1::DeviceType device_type;
24+
size_t device_name_len;
25+
std::string device_name;
26+
27+
nvinfer1::DeviceType get_device_type(void);
28+
void set_device_type(nvinfer1::DeviceType dev_type);
29+
30+
size_t get_device_name_len(void);
31+
void set_device_name_len(size_t size);
32+
33+
std::string get_device_name(void);
34+
void set_device_name(std::string& name);
35+
36+
int64_t get_id(void);
37+
void set_id(int64_t id);
38+
39+
int64_t get_major(void);
40+
void set_major(int64_t major);
41+
42+
int64_t get_minor(void);
43+
void set_minor(int64_t minor);
44+
};
45+
46+
void set_cuda_device(CudaDevice& cuda_device);
47+
void get_cuda_device(CudaDevice& cuda_device);
48+
49+
std::string serialize_device(CudaDevice& cuda_device);
50+
CudaDevice deserialize_device(std::string device_info);
51+
52+
CudaDevice spec_to_device(conversion::Device& spec);
53+
1454
struct TRTEngine : torch::CustomClassHolder {
1555
// Each engine needs its own runtime object
1656
nvinfer1::IRuntime* rt;
@@ -19,14 +59,16 @@ struct TRTEngine : torch::CustomClassHolder {
1959
std::pair<uint64_t, uint64_t> num_io;
2060
EngineID id;
2161
std::string name;
62+
CudaDevice device_info;
2263
util::logging::TRTorchLogger logger;
2364

2465
std::unordered_map<uint64_t, uint64_t> in_binding_map;
2566
std::unordered_map<uint64_t, uint64_t> out_binding_map;
2667

2768
~TRTEngine();
2869
TRTEngine(std::string serialized_engine);
29-
TRTEngine(std::string mod_name, std::string serialized_engine);
70+
TRTEngine(std::vector<std::string> serialized_info);
71+
TRTEngine(std::string mod_name, std::string serialized_engine, std::string device_info);
3072
TRTEngine& operator=(const TRTEngine& other);
3173
// TODO: Implement a call method
3274
// c10::List<at::Tensor> Run(c10::List<at::Tensor> inputs);

docs/_notebooks/Resnet50-example.html

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1935,4 +1935,4 @@ <h3 id="What's-next">
19351935
app.initialize({version: "1.0.4", url: {base: ".."}})
19361936
</script>
19371937
</body>
1938-
</html>
1938+
</html>

docs/_notebooks/lenet-getting-started.html

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1535,4 +1535,4 @@ <h3 id="What's-next">
15351535
app.initialize({version: "1.0.4", url: {base: ".."}})
15361536
</script>
15371537
</body>
1538-
</html>
1538+
</html>

docs/_notebooks/ssd-object-detection-demo.html

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1784,4 +1784,4 @@ <h3 id="References">
17841784
app.initialize({version: "1.0.4", url: {base: ".."}})
17851785
</script>
17861786
</body>
1787-
</html>
1787+
</html>

docs/py_api/trtorch.html

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1308,4 +1308,4 @@ <h2 id="submodules">
13081308
app.initialize({version: "1.0.4", url: {base: ".."}})
13091309
</script>
13101310
</body>
1311-
</html>
1311+
</html>

docs/searchindex.js

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ test_suite(
33
tests = [
44
"//tests/core:core_tests",
55
"//tests/modules:module_tests"
6+
"//tests/api:test_apis"
67
],
78
)
89

0 commit comments

Comments
 (0)