Skip to content

Commit bf4dd23

Browse files
author
Anurag Dixit
committed
feat(//core)!: Added support for Device meta data serialization and deserialization implicitly
1 parent cbe1866 commit bf4dd23

File tree

14 files changed

+433
-28
lines changed

14 files changed

+433
-28
lines changed

core/compiler.cpp

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,11 @@ void AddEngineToGraph(
3131
torch::jit::script::Module mod,
3232
std::shared_ptr<torch::jit::Graph>& g,
3333
const std::string& serialized_engine,
34+
runtime::CudaDevice& device_info,
3435
std::string engine_id = "",
3536
bool fallback = false) {
36-
auto engine_ptr = c10::make_intrusive<runtime::TRTEngine>(mod._ivalue()->name() + engine_id, serialized_engine);
37+
auto engine_ptr =
38+
c10::make_intrusive<runtime::TRTEngine>(mod._ivalue()->name() + engine_id, serialized_engine, device_info);
3739
// Get required metadata about the engine out
3840
auto num_io = engine_ptr->num_io;
3941
auto name = engine_ptr->name;
@@ -220,7 +222,9 @@ torch::jit::script::Module CompileGraphWithFallback(const torch::jit::script::Mo
220222
convert_cfg.input_ranges = input_ranges;
221223
auto engine = conversion::ConvertBlockToEngine(seg_block.block(), convert_cfg, named_params);
222224
auto temp_g = std::make_shared<torch::jit::Graph>();
223-
AddEngineToGraph(new_mod, temp_g, engine, trt_engine_id.str(), true);
225+
auto device_spec = convert_cfg.engine_settings.device;
226+
auto cuda_device = runtime::get_device_info(device_spec.gpu_id, device_spec.device_type);
227+
AddEngineToGraph(new_mod, temp_g, engine, cuda_device, trt_engine_id.str(), true);
224228

225229
seg_block.update_graph(temp_g);
226230
AddSegmentedBlockToGraph(new_g, seg_block, old_to_new_g);
@@ -260,7 +264,9 @@ torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod, C
260264
if (method.name().rfind("_", 0)) {
261265
auto engine = ConvertGraphToTRTEngine(mod, method.name(), cfg);
262266
auto new_g = std::make_shared<torch::jit::Graph>();
263-
AddEngineToGraph(new_mod, new_g, engine);
267+
auto device_spec = cfg.convert_info.engine_settings.device;
268+
auto cuda_device = runtime::get_device_info(device_spec.gpu_id, device_spec.device_type);
269+
AddEngineToGraph(new_mod, new_g, engine, cuda_device);
264270
auto new_method = new_mod._ivalue()->compilation_unit()->create_function(method.name(), new_g);
265271
auto schema = util::GenerateGraphSchema(new_method->name(), new_g);
266272
new_mod.type()->addMethod(new_method);
@@ -271,12 +277,14 @@ torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod, C
271277
return new_mod;
272278
}
273279

274-
torch::jit::script::Module EmbedEngineInNewModule(const std::string& engine) {
280+
torch::jit::script::Module EmbedEngineInNewModule(const std::string& engine, CompileSpec cfg) {
275281
std::ostringstream engine_id;
276282
engine_id << reinterpret_cast<const int*>(&engine);
277283
torch::jit::script::Module new_mod("tensorrt_engine_mod_" + engine_id.str());
278284
auto new_g = std::make_shared<torch::jit::Graph>();
279-
AddEngineToGraph(new_mod, new_g, engine);
285+
auto device_spec = cfg.convert_info.engine_settings.device;
286+
auto cuda_device = runtime::get_device_info(device_spec.gpu_id, device_spec.device_type);
287+
AddEngineToGraph(new_mod, new_g, engine, cuda_device);
280288
auto new_method = new_mod._ivalue()->compilation_unit()->create_function("forward", new_g);
281289
auto schema = util::GenerateGraphSchema(new_method->name(), new_g);
282290
new_mod.type()->addMethod(new_method);

core/compiler.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include "core/conversion/conversion.h"
66
#include "core/ir/ir.h"
77
#include "core/partitioning/partitioning.h"
8+
#include "core/runtime/runtime.h"
89
#include "torch/csrc/jit/api/module.h"
910

1011
namespace trtorch {
@@ -22,7 +23,7 @@ std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::
2223

2324
torch::jit::script::Module CompileGraph(const torch::jit::script::Module& module, CompileSpec cfg);
2425

25-
torch::jit::script::Module EmbedEngineInNewModule(const std::string& engine);
26+
torch::jit::script::Module EmbedEngineInNewModule(const std::string& engine, CompileSpec cfg);
2627

2728
void set_device(const int gpu_id);
2829

core/runtime/TRTEngine.cpp

Lines changed: 152 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include <algorithm>
22

3+
#include <cuda_runtime.h>
34
#include "NvInfer.h"
45
#include "torch/csrc/jit/frontend/function_schema_parser.h"
56

@@ -15,20 +16,35 @@ std::string slugify(std::string s) {
1516
return s;
1617
}
1718

18-
TRTEngine::TRTEngine(std::string serialized_engine)
19+
TRTEngine::TRTEngine(std::string serialized_engine, CudaDevice cuda_device)
1920
: logger(
2021
std::string("[] - "),
2122
util::logging::get_logger().get_reportable_severity(),
2223
util::logging::get_logger().get_is_colored_output_on()) {
2324
std::string _name = "deserialized_trt";
24-
new (this) TRTEngine(_name, serialized_engine);
25+
new (this) TRTEngine(_name, serialized_engine, cuda_device);
2526
}
2627

27-
TRTEngine::TRTEngine(std::string mod_name, std::string serialized_engine)
28+
// Pickle/deserialization constructor: serialized_info holds the device metadata
// blob and the engine blob, indexed by DeviceIdx / EngineIdx. Delegates to the
// main constructor instead of the original placement `new (this)`, which
// re-constructed members (including the logger) over live objects without
// destroying them — a leak and lifetime hazard.
TRTEngine::TRTEngine(std::vector<std::string> serialized_info)
    : TRTEngine(
          "deserialized_trt",
          serialized_info[EngineIdx],
          deserialize_device(serialized_info[DeviceIdx])) {}
39+
40+
TRTEngine::TRTEngine(std::string mod_name, std::string serialized_engine, CudaDevice cuda_device)
2841
: logger(
2942
std::string("[") + mod_name + std::string("_engine] - "),
3043
util::logging::get_logger().get_reportable_severity(),
3144
util::logging::get_logger().get_is_colored_output_on()) {
45+
device_info = cuda_device;
46+
set_cuda_device(device_info);
47+
3248
rt = nvinfer1::createInferRuntime(logger);
3349

3450
name = slugify(mod_name) + "_engine";
@@ -63,6 +79,7 @@ TRTEngine& TRTEngine::operator=(const TRTEngine& other) {
6379
id = other.id;
6480
rt = other.rt;
6581
cuda_engine = other.cuda_engine;
82+
device_info = other.device_info;
6683
exec_ctx = other.exec_ctx;
6784
num_io = other.num_io;
6885
return (*this);
@@ -85,18 +102,144 @@ TRTEngine::~TRTEngine() {
85102
namespace {
86103
static auto TRTORCH_UNUSED TRTEngineTSRegistrtion =
87104
torch::class_<TRTEngine>("tensorrt", "Engine")
88-
.def(torch::init<std::string>())
105+
.def(torch::init<std::vector<std::string>>())
89106
// TODO: .def("__call__", &TRTEngine::Run)
90107
// TODO: .def("run", &TRTEngine::Run)
91108
.def_pickle(
92-
[](const c10::intrusive_ptr<TRTEngine>& self) -> std::string {
93-
auto serialized_engine = self->cuda_engine->serialize();
94-
return std::string((const char*)serialized_engine->data(), serialized_engine->size());
109+
[](const c10::intrusive_ptr<TRTEngine>& self) -> std::vector<std::string> {
110+
// Serialize TensorRT engine
111+
auto serialized_trt_engine = self->cuda_engine->serialize();
112+
113+
// Adding device info related meta data to the serialized file
114+
auto trt_engine = std::string((const char*)serialized_trt_engine->data(), serialized_trt_engine->size());
115+
116+
std::vector<std::string> serialize_info;
117+
serialize_info.push_back(serialize_device(self->device_info));
118+
serialize_info.push_back(trt_engine);
119+
return serialize_info;
95120
},
96-
[](std::string seralized_engine) -> c10::intrusive_ptr<TRTEngine> {
97-
return c10::make_intrusive<TRTEngine>(std::move(seralized_engine));
121+
[](std::vector<std::string> seralized_info) -> c10::intrusive_ptr<TRTEngine> {
122+
return c10::make_intrusive<TRTEngine>(std::move(seralized_info));
98123
});
99124
} // namespace
125+
// Makes the given device the active CUDA device for this thread; aborts via
// TRTORCH_CHECK if the CUDA runtime rejects the id.
void set_cuda_device(CudaDevice& cuda_device) {
  auto result = cudaSetDevice(cuda_device.id);
  TRTORCH_CHECK((result == cudaSuccess), "Unable to set device: " << cuda_device.id);
}
128+
129+
// Queries the currently active CUDA device and fills cuda_device with its id,
// SM compute capability (major/minor) and name.
void get_cuda_device(CudaDevice& cuda_device) {
  int device = 0;
  // NOTE: do not report cuda_device.id in this first message — it has not been
  // assigned yet at this point, so the original message printed a stale value.
  TRTORCH_CHECK(
      (cudaGetDevice(&device) == cudaSuccess),
      "Unable to get current device");
  cuda_device.id = static_cast<int64_t>(device);
  cudaDeviceProp device_prop;
  TRTORCH_CHECK(
      (cudaGetDeviceProperties(&device_prop, cuda_device.id) == cudaSuccess),
      "Unable to get CUDA properties from device:" << cuda_device.id);
  cuda_device.set_major(device_prop.major);
  cuda_device.set_minor(device_prop.minor);
  std::string device_name(device_prop.name);
  cuda_device.set_device_name(device_name);
}
144+
145+
// Serializes a CudaDevice into a flat byte string with layout:
//   [id : int64][major : int64][minor : int64][device_type][name_len : size_t][name bytes]
// Must stay in sync with deserialize_device().
//
// Fixes over the original:
//  - the name was written with memcpy(&device_name, ...) which copied the
//    std::string object's internal representation, not its characters;
//  - the buffer was sized sizeof(cuda_device), which can be smaller than the
//    payload when the device name is long (heap overflow);
//  - the new[]'d buffer was never freed. Using std::string::append avoids the
//    manual buffer entirely.
std::string serialize_device(CudaDevice& cuda_device) {
  std::string payload;
  auto append_raw = [&payload](const void* src, size_t len) {
    payload.append(static_cast<const char*>(src), len);
  };

  int64_t temp = cuda_device.get_id();
  append_raw(&temp, sizeof(int64_t));

  temp = cuda_device.get_major();
  append_raw(&temp, sizeof(int64_t));

  temp = cuda_device.get_minor();
  append_raw(&temp, sizeof(int64_t));

  auto device_type = cuda_device.get_device_type();
  append_raw(&device_type, sizeof(nvinfer1::DeviceType));

  auto device_name = cuda_device.get_device_name();
  // Write the actual character count so the length field always matches the
  // bytes that follow.
  size_t device_name_len = device_name.size();
  append_raw(&device_name_len, sizeof(size_t));
  append_raw(device_name.data(), device_name.size());

  return payload;
}
175+
176+
// Reconstructs a CudaDevice from the byte layout produced by serialize_device().
//
// Fixes over the original:
//  - the name length was read with memcpy(&size, &buffer, ...), i.e. from the
//    address of the cursor variable itself rather than from the stream;
//  - the name was memcpy'd into a std::string object (UB) instead of being
//    constructed from the character range;
//  - device_type was read but never stored in the result;
//  - the new[]'d copy of the input was leaked. Reading straight from the
//    input string removes the copy altogether.
CudaDevice deserialize_device(std::string device_info) {
  CudaDevice ret;
  const char* buffer = device_info.data();

  int64_t temp = 0;

  memcpy(&temp, buffer, sizeof(int64_t));
  buffer += sizeof(int64_t);
  ret.set_id(temp);

  memcpy(&temp, buffer, sizeof(int64_t));
  buffer += sizeof(int64_t);
  ret.set_major(temp);

  memcpy(&temp, buffer, sizeof(int64_t));
  buffer += sizeof(int64_t);
  ret.set_minor(temp);

  nvinfer1::DeviceType device_type;
  memcpy(&device_type, buffer, sizeof(nvinfer1::DeviceType));
  buffer += sizeof(nvinfer1::DeviceType);
  ret.set_device_type(device_type);

  size_t size = 0;
  memcpy(&size, buffer, sizeof(size_t));
  buffer += sizeof(size_t);
  ret.set_device_name_len(size);

  // Build the name from the raw character range that follows the length field.
  std::string device_name(buffer, size);
  buffer += size;
  ret.set_device_name(device_name);

  return ret;
}
212+
213+
// Builds a fully populated CudaDevice record (id, SM capability, name, type)
// for the given GPU id and configured device type.
CudaDevice get_device_info(int64_t gpu_id, nvinfer1::DeviceType device_type) {
  CudaDevice cuda_device;
  cudaDeviceProp device_prop;

  // Device ID
  cuda_device.set_id(gpu_id);

  // Get Device Properties — fail loudly instead of silently populating the
  // record from an uninitialized cudaDeviceProp on error.
  TRTORCH_CHECK(
      (cudaGetDeviceProperties(&device_prop, gpu_id) == cudaSuccess),
      "Unable to get CUDA properties from device: " << gpu_id);

  // Compute capability major version
  cuda_device.set_major(device_prop.major);

  // Compute capability minor version
  cuda_device.set_minor(device_prop.minor);

  std::string device_name(device_prop.name);

  // Set Device name
  cuda_device.set_device_name(device_name);

  // Set Device name len for serialization/deserialization
  cuda_device.set_device_name_len(device_name.size());

  // Set Device Type
  cuda_device.set_device_type(device_type);

  return cuda_device;
}
242+
100243
} // namespace runtime
101244
} // namespace core
102245
} // namespace trtorch

core/runtime/register_trt_op.cpp

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,87 @@ namespace trtorch {
1010
namespace core {
1111
namespace runtime {
1212

13+
// Checks if the context switch requred for device ID
14+
bool is_switch_required(const CudaDevice& curr_device, const CudaDevice& conf_device) {
15+
// If SM capability is not the same as configured then switch
16+
if ((curr_device.major != conf_device.major) || (curr_device.minor != conf_device.minor)) {
17+
LOG_WARNING("Configured SM capability does not match with current device ID. Switching context");
18+
return true;
19+
}
20+
21+
// GPU case
22+
if (conf_device.device_type == nvinfer1::DeviceType::kGPU) {
23+
if (curr_device.device_name != conf_device.device_name) {
24+
LOG_WARNING("TRTEngine compiled for " << conf_device.device_name << " but current CUDA device is " << curr_device.device_name << ". Switching the device context");
25+
return true;
26+
}
27+
}
28+
29+
if (curr_device.id != conf_device.id) {
30+
LOG_WARNING("Configured Device ID: " << conf_device.id << " is different that current device ID: " << curr_device.id << ". Switching context");
31+
return true;
32+
}
33+
34+
return false;
35+
}
36+
37+
int select_cuda_device(const CudaDevice& conf_device) {
38+
int device_id = 0;
39+
int num_devices = 0;
40+
// SM Compute capability <major,minor> pair
41+
std::unordered_map<std::string, std::string> dla_supported_SM;
42+
43+
// Xavier SM Compute Capability
44+
dla_supported_SM.insert(std::make_pair("7.2", "Xavier"));
45+
auto status = cudaGetDeviceCount(&num_devices);
46+
TRTORCH_CHECK((status == cudaSuccess), "Unable to read CUDA capable devices. Return status: " << status);
47+
48+
cudaDeviceProp device_prop;
49+
50+
for (int i=0; i < num_devices; i++) {
51+
TRTORCH_CHECK((cudaGetDeviceProperties(&device_prop, i) == cudaSuccess), "Unable to read CUDA Device Properies for device id: " << i);
52+
auto compute_cap = std::to_string(device_prop.major) + "." + std::to_string(device_prop.minor);
53+
std::string device_name{device_prop.name};
54+
// In case of DLA select the DLA supported device ID
55+
if (conf_device.device_type == nvinfer1::DeviceType::kDLA) {
56+
if (dla_supported_SM.find(compute_cap) != dla_supported_SM.end() && dla_supported_SM[compute_cap] == device_name) {
57+
device_id = i;
58+
break;
59+
}
60+
}
61+
else if (conf_device.device_type == nvinfer1::DeviceType::kGPU) {
62+
auto conf_sm = std::to_string(conf_device.major) + "." + std::to_string(conf_device.minor);
63+
if (compute_cap == conf_sm && device_name == conf_device.device_name) {
64+
device_id = i;
65+
break;
66+
}
67+
}
68+
else {
69+
LOG_ERROR("Unkown device type detected from the compiled engine");
70+
break;
71+
}
72+
}
73+
return device_id;
74+
}
75+
1376
std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intrusive_ptr<TRTEngine> compiled_engine) {
1477
LOG_DEBUG("Attempting to run engine (ID: " << compiled_engine->name << ")");
78+
79+
CudaDevice curr_device;
80+
get_cuda_device(curr_device);
81+
82+
if (is_switch_required(curr_device, compiled_engine->device_info)) {
83+
// Scan through available CUDA devices and set the CUDA device context correctly
84+
CudaDevice device{.id = select_cuda_device(compiled_engine->device_info)};
85+
set_cuda_device(device);
86+
87+
std::string target_device = "cuda:" + std::to_string(device.id);
88+
89+
for(auto& in : inputs) {
90+
in = in.to(at::kCUDA);
91+
}
92+
}
93+
1594
std::vector<void*> gpu_handles;
1695

1796
std::vector<at::Tensor> contig_inputs{};

0 commit comments

Comments
 (0)