
Commit ac65885

Merge pull request #484 from NVIDIA/anuragd/dev_serdes
feat(//core)!: Added support for Device meta data serialization and d…
2 parents: 7367fe2 + 9d77206

19 files changed: 532 additions, 28 deletions

core/compiler.cpp

Lines changed: 11 additions & 5 deletions
@@ -31,9 +31,11 @@ void AddEngineToGraph(
     torch::jit::script::Module mod,
     std::shared_ptr<torch::jit::Graph>& g,
     const std::string& serialized_engine,
+    runtime::CudaDevice& device_info,
     std::string engine_id = "",
     bool fallback = false) {
-  auto engine_ptr = c10::make_intrusive<runtime::TRTEngine>(mod._ivalue()->name() + engine_id, serialized_engine);
+  auto engine_ptr =
+      c10::make_intrusive<runtime::TRTEngine>(mod._ivalue()->name() + engine_id, serialized_engine, device_info);
   // Get required metadata about the engine out
   auto num_io = engine_ptr->num_io;
   auto name = engine_ptr->name;
@@ -220,7 +222,9 @@ torch::jit::script::Module CompileGraphWithFallback(const torch::jit::script::Mo
       convert_cfg.input_ranges = input_ranges;
       auto engine = conversion::ConvertBlockToEngine(seg_block.block(), convert_cfg, named_params);
       auto temp_g = std::make_shared<torch::jit::Graph>();
-      AddEngineToGraph(new_mod, temp_g, engine, trt_engine_id.str(), true);
+      auto device_spec = convert_cfg.engine_settings.device;
+      auto cuda_device = runtime::CudaDevice(device_spec.gpu_id, device_spec.device_type);
+      AddEngineToGraph(new_mod, temp_g, engine, cuda_device, trt_engine_id.str(), true);

       seg_block.update_graph(temp_g);
       AddSegmentedBlockToGraph(new_g, seg_block, old_to_new_g);
@@ -260,7 +264,9 @@ torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod, C
     if (method.name().compare("forward") == 0) {
       auto engine = ConvertGraphToTRTEngine(mod, method.name(), cfg);
       auto new_g = std::make_shared<torch::jit::Graph>();
-      AddEngineToGraph(new_mod, new_g, engine);
+      auto device_spec = cfg.convert_info.engine_settings.device;
+      auto cuda_device = runtime::CudaDevice(device_spec.gpu_id, device_spec.device_type);
+      AddEngineToGraph(new_mod, new_g, engine, cuda_device);
       auto new_method = new_mod._ivalue()->compilation_unit()->create_function(method.name(), new_g);
       auto schema = util::GenerateGraphSchema(new_method->name(), new_g);
       new_mod.type()->addMethod(new_method);
@@ -271,12 +277,12 @@ torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod, C
   return new_mod;
 }

-torch::jit::script::Module EmbedEngineInNewModule(const std::string& engine) {
+torch::jit::script::Module EmbedEngineInNewModule(const std::string& engine, runtime::CudaDevice cuda_device) {
   std::ostringstream engine_id;
   engine_id << reinterpret_cast<const int*>(&engine);
   torch::jit::script::Module new_mod("tensorrt_engine_mod_" + engine_id.str());
   auto new_g = std::make_shared<torch::jit::Graph>();
-  AddEngineToGraph(new_mod, new_g, engine);
+  AddEngineToGraph(new_mod, new_g, engine, cuda_device);
   auto new_method = new_mod._ivalue()->compilation_unit()->create_function("forward", new_g);
   auto schema = util::GenerateGraphSchema(new_method->name(), new_g);
   new_mod.type()->addMethod(new_method);
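
The new device_info parameter threads target-device metadata through every place an engine is embedded in a graph. A minimal sketch of the resulting call pattern, assuming a serialized TensorRT engine blob is already in hand (the wrapper function here is illustrative, not part of the diff):

#include "NvInfer.h"
#include "core/compiler.h"
#include "core/runtime/runtime.h"

// Hypothetical helper: wrap a pre-built engine in a TorchScript module,
// pinned to a specific GPU. Only EmbedEngineInNewModule and CudaDevice
// come from this change; the helper itself is for illustration.
torch::jit::script::Module WrapEngineForGpu(const std::string& engine_blob, int64_t gpu_id) {
  auto cuda_device = trtorch::core::runtime::CudaDevice(gpu_id, nvinfer1::DeviceType::kGPU);
  return trtorch::core::EmbedEngineInNewModule(engine_blob, cuda_device);
}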

core/compiler.h

Lines changed: 2 additions & 1 deletion
@@ -5,6 +5,7 @@
 #include "core/conversion/conversion.h"
 #include "core/ir/ir.h"
 #include "core/partitioning/partitioning.h"
+#include "core/runtime/runtime.h"
 #include "torch/csrc/jit/api/module.h"

 namespace trtorch {
@@ -22,7 +23,7 @@ std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::

 torch::jit::script::Module CompileGraph(const torch::jit::script::Module& module, CompileSpec cfg);

-torch::jit::script::Module EmbedEngineInNewModule(const std::string& engine);
+torch::jit::script::Module EmbedEngineInNewModule(const std::string& engine, runtime::CudaDevice cuda_device);

 void set_device(const int gpu_id);

core/runtime/BUILD

Lines changed: 3 additions & 0 deletions
@@ -10,8 +10,11 @@ config_setting(
 cc_library(
     name = "runtime",
     srcs = [
+        "CudaDevice.cpp",
+        "DeviceList.cpp",
         "TRTEngine.cpp",
         "register_trt_op.cpp",
+        "runtime.cpp"
     ],
     hdrs = [
         "runtime.h",

core/runtime/CudaDevice.cpp

Lines changed: 106 additions & 0 deletions
@@ -0,0 +1,106 @@
+#include "cuda_runtime.h"
+
+#include "core/runtime/runtime.h"
+#include "core/util/prelude.h"
+
+namespace trtorch {
+namespace core {
+namespace runtime {
+
+const std::string DEVICE_INFO_DELIM = "%";
+
+typedef enum { ID_IDX = 0, SM_MAJOR_IDX, SM_MINOR_IDX, DEVICE_TYPE_IDX, DEVICE_NAME_IDX } SerializedDeviceInfoIndex;
+
+CudaDevice::CudaDevice() : id{-1}, major{-1}, minor{-1}, device_type{nvinfer1::DeviceType::kGPU} {}
+
+CudaDevice::CudaDevice(int64_t gpu_id, nvinfer1::DeviceType device_type) {
+  CudaDevice cuda_device;
+  cudaDeviceProp device_prop;
+
+  // Device ID
+  this->id = gpu_id;
+
+  // Get Device Properties
+  cudaGetDeviceProperties(&device_prop, gpu_id);
+
+  // Compute capability major version
+  this->major = device_prop.major;
+
+  // Compute capability minor version
+  this->minor = device_prop.minor;
+
+  std::string device_name(device_prop.name);
+
+  // Set Device name
+  this->device_name = device_name;
+
+  // Set Device Type
+  this->device_type = device_type;
+}
+
+// NOTE: Serialization Format for Device Info:
+// id%major%minor%(enum)device_type%device_name
+
+CudaDevice::CudaDevice(std::string device_info) {
+  LOG_DEBUG("Deserializing Device Info: " << device_info);
+
+  std::vector<std::string> tokens;
+  int64_t start = 0;
+  int64_t end = device_info.find(DEVICE_INFO_DELIM);
+
+  while (end != -1) {
+    tokens.push_back(device_info.substr(start, end - start));
+    start = end + DEVICE_INFO_DELIM.size();
+    end = device_info.find(DEVICE_INFO_DELIM, start);
+  }
+  tokens.push_back(device_info.substr(start, end - start));
+
+  TRTORCH_CHECK(tokens.size() == DEVICE_NAME_IDX + 1, "Unable to deserialize program target device information");
+
+  id = std::stoi(tokens[ID_IDX]);
+  major = std::stoi(tokens[SM_MAJOR_IDX]);
+  minor = std::stoi(tokens[SM_MINOR_IDX]);
+  device_type = (nvinfer1::DeviceType)(std::stoi(tokens[DEVICE_TYPE_IDX]));
+  device_name = tokens[DEVICE_NAME_IDX];
+
+  LOG_DEBUG("Deserialized Device Info: " << *this);
+}
+
+std::string CudaDevice::serialize() {
+  std::vector<std::string> content;
+  content.resize(DEVICE_NAME_IDX + 1);
+
+  content[ID_IDX] = std::to_string(id);
+  content[SM_MAJOR_IDX] = std::to_string(major);
+  content[SM_MINOR_IDX] = std::to_string(minor);
+  content[DEVICE_TYPE_IDX] = std::to_string((int64_t)device_type);
+  content[DEVICE_NAME_IDX] = device_name;
+
+  std::stringstream ss;
+  for (size_t i = 0; i < content.size() - 1; i++) {
+    ss << content[i] << DEVICE_INFO_DELIM;
+  }
+  ss << content[DEVICE_NAME_IDX];
+
+  std::string serialized_device_info = ss.str();
+
+  LOG_DEBUG("Serialized Device Info: " << serialized_device_info);
+
+  return serialized_device_info;
+}
+
+std::string CudaDevice::getSMCapability() const {
+  std::stringstream ss;
+  ss << major << "." << minor;
+  return ss.str();
+}
+
+std::ostream& operator<<(std::ostream& os, const CudaDevice& device) {
+  os << "Device(ID: " << device.id << ", Name: " << device.device_name << ", SM Capability: " << device.major << '.'
+     << device.minor << ", Type: " << device.device_type << ')';
+  return os;
+}
+
+} // namespace runtime
+} // namespace core
+} // namespace trtorch
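
serialize() and the string constructor above form a round trip over the "%"-delimited format noted in the file. A small sketch, assuming a single visible GPU (the SM version and device name shown are illustrative):

#include <cassert>
#include <string>

#include "NvInfer.h"
#include "core/runtime/runtime.h"

void DeviceRoundTripExample() {
  using trtorch::core::runtime::CudaDevice;

  // Queries cudaGetDeviceProperties for GPU 0; on a V100 the serialized
  // blob would look like "0%7%0%0%Tesla V100-SXM2-16GB".
  CudaDevice original(0, nvinfer1::DeviceType::kGPU);
  std::string blob = original.serialize();

  // The string constructor splits on DEVICE_INFO_DELIM and restores each field.
  CudaDevice restored(blob);
  assert(restored.getSMCapability() == original.getSMCapability());
}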

core/runtime/DeviceList.cpp

Lines changed: 45 additions & 0 deletions
@@ -0,0 +1,45 @@
+#include "cuda_runtime.h"
+
+#include "core/runtime/runtime.h"
+#include "core/util/prelude.h"
+
+namespace trtorch {
+namespace core {
+namespace runtime {
+
+DeviceList::DeviceList() {
+  int num_devices = 0;
+  auto status = cudaGetDeviceCount(&num_devices);
+  TRTORCH_ASSERT((status == cudaSuccess), "Unable to read CUDA capable devices. Return status: " << status);
+  for (int i = 0; i < num_devices; i++) {
+    device_list[i] = CudaDevice(i, nvinfer1::DeviceType::kGPU);
+  }
+
+  // REVIEW: DO WE CARE ABOUT DLA?
+
+  LOG_DEBUG("Runtime:\n Available CUDA Devices: \n" << this->dump_list());
+}
+
+void DeviceList::insert(int device_id, CudaDevice cuda_device) {
+  device_list[device_id] = cuda_device;
+}
+
+CudaDevice DeviceList::find(int device_id) {
+  return device_list[device_id];
+}
+
+DeviceList::DeviceMap DeviceList::get_devices() {
+  return device_list;
+}
+
+std::string DeviceList::dump_list() {
+  std::stringstream ss;
+  for (auto it = device_list.begin(); it != device_list.end(); ++it) {
+    ss << " " << it->second << std::endl;
+  }
+  return ss.str();
+}
+
+} // namespace runtime
+} // namespace core
+} // namespace trtorch
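
DeviceList amounts to a map from CUDA device id to CudaDevice, filled in once by enumerating every visible GPU. A short usage sketch, assuming at least one GPU is present (the function name is illustrative):

#include <iostream>

#include "core/runtime/runtime.h"

void DumpVisibleGpus() {
  // The constructor calls cudaGetDeviceCount and records one CudaDevice per GPU.
  trtorch::core::runtime::DeviceList devices;

  // Look up GPU 0, then print the whole table (one Device(...) line per GPU).
  auto dev0 = devices.find(0);
  std::cout << "GPU 0 SM capability: " << dev0.getSMCapability() << std::endl;
  std::cout << devices.dump_list();
}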

core/runtime/TRTEngine.cpp

Lines changed: 46 additions & 9 deletions
@@ -1,5 +1,6 @@
 #include <algorithm>

+#include <cuda_runtime.h>
 #include "NvInfer.h"
 #include "torch/csrc/jit/frontend/function_schema_parser.h"

@@ -10,30 +11,55 @@ namespace trtorch {
 namespace core {
 namespace runtime {

+typedef enum { ABI_TARGET_IDX = 0, DEVICE_IDX, ENGINE_IDX } SerializedInfoIndex;
+
 std::string slugify(std::string s) {
   std::replace(s.begin(), s.end(), '.', '_');
   return s;
 }

-TRTEngine::TRTEngine(std::string serialized_engine)
+TRTEngine::TRTEngine(std::string serialized_engine, CudaDevice cuda_device)
     : logger(
           std::string("[] - "),
           util::logging::get_logger().get_reportable_severity(),
           util::logging::get_logger().get_is_colored_output_on()) {
   std::string _name = "deserialized_trt";
-  new (this) TRTEngine(_name, serialized_engine);
+  new (this) TRTEngine(_name, serialized_engine, cuda_device);
 }

-TRTEngine::TRTEngine(std::string mod_name, std::string serialized_engine)
+TRTEngine::TRTEngine(std::vector<std::string> serialized_info)
+    : logger(
+          std::string("[] = "),
+          util::logging::get_logger().get_reportable_severity(),
+          util::logging::get_logger().get_is_colored_output_on()) {
+  TRTORCH_CHECK(
+      serialized_info.size() == ENGINE_IDX + 1, "Program to be deserialized targets an incompatible TRTorch ABI");
+  TRTORCH_CHECK(
+      serialized_info[ABI_TARGET_IDX] == ABI_VERSION,
+      "Program to be deserialized targets a different TRTorch ABI Version ("
+          << serialized_info[ABI_TARGET_IDX] << ") than the TRTorch Runtime ABI (" << ABI_VERSION << ")");
+  std::string _name = "deserialized_trt";
+  std::string engine_info = serialized_info[ENGINE_IDX];
+
+  CudaDevice cuda_device = deserialize_device(serialized_info[DEVICE_IDX]);
+  new (this) TRTEngine(_name, engine_info, cuda_device);
+}
+
+TRTEngine::TRTEngine(std::string mod_name, std::string serialized_engine, CudaDevice cuda_device)
     : logger(
           std::string("[") + mod_name + std::string("_engine] - "),
           util::logging::get_logger().get_reportable_severity(),
           util::logging::get_logger().get_is_colored_output_on()) {
+  device_info = cuda_device;
+  set_cuda_device(device_info);
+
   rt = nvinfer1::createInferRuntime(logger);

   name = slugify(mod_name) + "_engine";

   cuda_engine = rt->deserializeCudaEngine(serialized_engine.c_str(), serialized_engine.size());
+  TRTORCH_CHECK((cuda_engine != nullptr), "Unable to deserialize the TensorRT engine");
+
   // Easy way to get a unique name for each engine, maybe there is a more
   // descriptive way (using something associated with the graph maybe)
   id = reinterpret_cast<EngineID>(cuda_engine);
@@ -63,6 +89,7 @@ TRTEngine& TRTEngine::operator=(const TRTEngine& other) {
   id = other.id;
   rt = other.rt;
   cuda_engine = other.cuda_engine;
+  device_info = other.device_info;
   exec_ctx = other.exec_ctx;
   num_io = other.num_io;
   return (*this);
@@ -85,18 +112,28 @@ TRTEngine::~TRTEngine() {
 namespace {
 static auto TRTORCH_UNUSED TRTEngineTSRegistrtion =
     torch::class_<TRTEngine>("tensorrt", "Engine")
-        .def(torch::init<std::string>())
+        .def(torch::init<std::vector<std::string>>())
         // TODO: .def("__call__", &TRTEngine::Run)
         // TODO: .def("run", &TRTEngine::Run)
         .def_pickle(
-            [](const c10::intrusive_ptr<TRTEngine>& self) -> std::string {
-              auto serialized_engine = self->cuda_engine->serialize();
-              return std::string((const char*)serialized_engine->data(), serialized_engine->size());
+            [](const c10::intrusive_ptr<TRTEngine>& self) -> std::vector<std::string> {
+              // Serialize TensorRT engine
+              auto serialized_trt_engine = self->cuda_engine->serialize();
+
+              // Adding device info related meta data to the serialized file
+              auto trt_engine = std::string((const char*)serialized_trt_engine->data(), serialized_trt_engine->size());
+
+              std::vector<std::string> serialize_info;
+              serialize_info.push_back(ABI_VERSION);
+              serialize_info.push_back(serialize_device(self->device_info));
+              serialize_info.push_back(trt_engine);
+              return serialize_info;
             },
-            [](std::string seralized_engine) -> c10::intrusive_ptr<TRTEngine> {
-              return c10::make_intrusive<TRTEngine>(std::move(seralized_engine));
+            [](std::vector<std::string> seralized_info) -> c10::intrusive_ptr<TRTEngine> {
+              return c10::make_intrusive<TRTEngine>(std::move(seralized_info));
             });
 } // namespace
+
 } // namespace runtime
 } // namespace core
 } // namespace trtorch
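
The net effect of the def_pickle change is that a saved tensorrt::Engine is now a three-element vector of strings rather than a bare engine string, ordered by SerializedInfoIndex. A sketch of that layout, assuming ABI_VERSION and serialize_device are declared in runtime.h as this diff implies (some_engine and engine_bytes are illustrative placeholders):

// Index                 Contents
// ABI_TARGET_IDX (0)    TRTorch ABI version string, checked on load
// DEVICE_IDX     (1)    "%"-delimited device blob (see CudaDevice.cpp)
// ENGINE_IDX     (2)    raw serialized TensorRT engine bytes
std::vector<std::string> serialized_info = {
    ABI_VERSION,
    serialize_device(some_engine.device_info),
    engine_bytes};

// Deserializing re-runs the ABI and device checks in the vector constructor.
auto restored = c10::make_intrusive<trtorch::core::runtime::TRTEngine>(serialized_info);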
