Skip to content

Commit 009c32a

Browse files
author
Anurag Dixit
committed
(//core): Rebase with master branch
Signed-off-by: Anurag Dixit <[email protected]>
1 parent 6c6fcbc commit 009c32a

File tree

11 files changed

+496
-72
lines changed

11 files changed

+496
-72
lines changed

core/compiler.cpp

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#include <memory>
33
#include <sstream>
44
#include <vector>
5+
#include <cuda_runtime.h>
56

67
#include <cuda_runtime.h>
78
#include "NvInfer.h"
@@ -47,7 +48,14 @@ void AddEngineToGraph(
4748
torch::jit::script::Module mod,
4849
std::shared_ptr<torch::jit::Graph>& g,
4950
std::string& serialized_engine) {
50-
auto engine_ptr = c10::make_intrusive<runtime::TRTEngine>(mod._ivalue()->name(), serialized_engine);
51+
52+
// Read current CUDA device properties
53+
runtime::get_cuda_device(device);
54+
55+
// Serialize current device information
56+
auto device_info = runtime::serialize_device(device);
57+
58+
auto engine_ptr = c10::make_intrusive<execution::TRTEngine>(mod._ivalue()->name(), serialized_engine, device_info);
5159
// Get required metadata about the engine out
5260
auto num_io = engine_ptr->num_io;
5361
auto name = engine_ptr->name;
@@ -177,5 +185,9 @@ void set_device(const int gpu_id) {
177185
TRTORCH_ASSERT(cudaSetDevice(gpu_id) == cudaSuccess, "Unable to set CUDA device: " << gpu_id);
178186
}
179187

188+
void set_device(const int gpu_id) {
189+
TRTORCH_CHECK((cudaSetDevice(gpu_id) != cudaSuccess), "Unable to set CUDA device: " << gpu_id);
190+
}
191+
180192
} // namespace core
181193
} // namespace trtorch

core/compiler.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,5 +21,7 @@ torch::jit::script::Module CompileGraph(const torch::jit::script::Module& module
2121

2222
void set_device(const int gpu_id);
2323

24+
void set_device(const int gpu_id);
25+
2426
} // namespace core
2527
} // namespace trtorch

core/execution/TRTEngine.cpp

Lines changed: 212 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,212 @@
1+
#include <algorithm>
2+
#include <cuda_runtime.h>
3+
4+
#include "NvInfer.h"
5+
#include "torch/csrc/jit/frontend/function_schema_parser.h"
6+
7+
#include "core/util/prelude.h"
8+
#include "core/execution/execution.h"
9+
10+
namespace trtorch {
11+
namespace core {
12+
namespace execution {
13+
14+
const std::string empty_string = std::string();
15+
16+
std::string slugify(std::string s) {
17+
std::replace(s.begin(), s.end(), '.', '_');
18+
return s;
19+
}
20+
21+
TRTEngine::TRTEngine(std::string serialized_engine)
22+
: logger(std::string("[] - "),
23+
util::logging::get_logger().get_reportable_severity(),
24+
util::logging::get_logger().get_is_colored_output_on()) {
25+
std::string _name = "deserialized_trt";
26+
new (this) TRTEngine(_name, serialized_engine, empty_string);
27+
}
28+
29+
TRTEngine::TRTEngine(std::vector<std::string> serialized_info)
30+
: logger(std::string("[] - "),
31+
util::logging::get_logger().get_reportable_severity(),
32+
util::logging::get_logger().get_is_colored_output_on()) {
33+
std::string _name = "deserialized_trt";
34+
std::string device_info = serialized_info[0];
35+
std::string engine_info = serialized_info[1];
36+
37+
new (this) TRTEngine(_name, engine_info, device_info);
38+
}
39+
40+
TRTEngine::TRTEngine(std::string mod_name, std::string serialized_engine, std::string serialized_device_info = empty_string)
41+
: logger(std::string("[") + mod_name + std::string("_engine] - "),
42+
util::logging::get_logger().get_reportable_severity(),
43+
util::logging::get_logger().get_is_colored_output_on()) {
44+
45+
CudaDevice cuda_device;
46+
// Deserialize device meta data if device_info is non-empty
47+
if (!serialized_device_info.empty())
48+
{
49+
cuda_device = deserialize_device(serialized_device_info);
50+
// Set CUDA device as configured in serialized meta data
51+
set_cuda_device(cuda_device);
52+
}
53+
54+
rt = nvinfer1::createInferRuntime(logger);
55+
56+
name = slugify(mod_name) + "_engine";
57+
58+
cuda_engine = rt->deserializeCudaEngine(serialized_engine.c_str(), serialized_engine.size());
59+
// Easy way to get a unique name for each engine, maybe there is a more descriptive way (using something associated with the graph maybe)
60+
id = reinterpret_cast<EngineID>(cuda_engine);
61+
62+
exec_ctx = cuda_engine->createExecutionContext();
63+
64+
uint64_t inputs = 0;
65+
uint64_t outputs = 0;
66+
67+
for (int64_t x = 0; x < cuda_engine->getNbBindings(); x++) {
68+
std::string name = cuda_engine->getBindingName(x);
69+
std::string idx_s = name.substr(name.find("_") + 1);
70+
uint64_t idx = static_cast<uint64_t>(std::stoi(idx_s));
71+
72+
if(cuda_engine->bindingIsInput(x)) {
73+
inputs++;
74+
in_binding_map[x] = idx;
75+
} else {
76+
outputs++;
77+
out_binding_map[x] = idx;
78+
}
79+
}
80+
num_io = std::make_pair(inputs, outputs);
81+
82+
}
83+
84+
TRTEngine& TRTEngine::operator=(const TRTEngine& other) {
85+
id = other.id;
86+
rt = other.rt;
87+
cuda_engine = other.cuda_engine;
88+
device_info = other.device_info;
89+
exec_ctx = other.exec_ctx;
90+
num_io = other.num_io;
91+
return (*this);
92+
}
93+
94+
TRTEngine::~TRTEngine() {
95+
exec_ctx->destroy();
96+
cuda_engine->destroy();
97+
rt->destroy();
98+
}
99+
100+
101+
// TODO: Implement a call method
102+
// c10::List<at::Tensor> TRTEngine::Run(c10::List<at::Tensor> inputs) {
103+
// auto input_vec = inputs.vec();
104+
// auto output_vec = RunCudaEngine(exec_ctx, num_io, input_vec);
105+
//
106+
// return c10::List<at::Tensor>(output_vec);
107+
// }
108+
109+
static auto TRTORCH_UNUSED TRTEngineTSRegistrtion = torch::class_<TRTEngine>("tensorrt", "Engine")
110+
.def(torch::init<std::string>())
111+
// TODO: .def("__call__", &TRTEngine::Run)
112+
// TODO: .def("run", &TRTEngine::Run)
113+
.def_pickle(
114+
[](const c10::intrusive_ptr<TRTEngine>& self) -> std::vector<std::string> {
115+
// Serialize TensorRT engine
116+
auto serialized_trt_engine = self->cuda_engine->serialize();
117+
118+
// Adding device info related meta data to the serialized file
119+
auto trt_engine = std::string((const char*)serialized_trt_engine->data(), serialized_trt_engine->size());
120+
121+
CudaDevice cuda_device;
122+
std::vector<std::string> serialize_info;
123+
serialize_info.push_back(serialize_device(cuda_device));
124+
serialize_info.push_back(trt_engine);
125+
return serialize_info;
126+
},
127+
[](std::vector<std::string> seralized_info) -> c10::intrusive_ptr<TRTEngine> {
128+
return c10::make_intrusive<TRTEngine>(std::move(seralized_info));
129+
}
130+
);
131+
132+
133+
int CudaDevice::get_id(void) {
134+
return this->id;
135+
}
136+
137+
void CudaDevice::set_id(int id) {
138+
this->id = id;
139+
}
140+
141+
int CudaDevice::get_major(void) {
142+
return this->major;
143+
}
144+
145+
void CudaDevice::set_major(int major) {
146+
this->major = major;
147+
}
148+
149+
int CudaDevice::get_minor(void) {
150+
return this->minor;
151+
}
152+
153+
void CudaDevice::set_minor(int minor) {
154+
this->minor = minor;
155+
}
156+
157+
void set_cuda_device(CudaDevice& cuda_device) {
158+
TRTORCH_CHECK((cudaSetDevice(cuda_device.id) != cudaSuccess), "Unable to set device: " << cuda_device.id);
159+
}
160+
161+
void get_cuda_device(CudaDevice& cuda_device) {
162+
TRTORCH_CHECK((cudaGetDevice(&cuda_device.id) != cudaSuccess), "Unable to get current device: " << cuda_device.id);
163+
cudaDeviceProp device_prop;
164+
TRTORCH_CHECK((cudaGetDeviceProperties(&device_prop, cuda_device.id) != cudaSuccess), "Unable to get CUDA properties from device:" << cuda_device.id);
165+
cuda_device.set_major(device_prop.major);
166+
cuda_device.set_minor(device_prop.minor);
167+
}
168+
169+
std::string serialize_device(CudaDevice& cuda_device) {
170+
void *buffer = new char[sizeof(cuda_device)];
171+
void *ref_buf = buffer;
172+
173+
int temp = cuda_device.get_id();
174+
memcpy(buffer, reinterpret_cast<int*>(&temp), sizeof(int));
175+
buffer = static_cast<char*>(buffer) + sizeof(int);
176+
177+
temp = cuda_device.get_major();
178+
memcpy(buffer, reinterpret_cast<int*>(&temp), sizeof(int));
179+
buffer = static_cast<char*>(buffer) + sizeof(int);
180+
181+
temp = cuda_device.get_minor();
182+
memcpy(buffer, reinterpret_cast<int*>(&temp), sizeof(int));
183+
buffer = static_cast<char*>(buffer) + sizeof(int);
184+
185+
return std::string((const char*)ref_buf, sizeof(int)*3);
186+
}
187+
188+
CudaDevice deserialize_device(std::string device_info) {
189+
CudaDevice ret;
190+
char *buffer = new char[device_info.size() + 1];
191+
std::copy(device_info.begin(), device_info.end(), buffer);
192+
int temp = 0;
193+
194+
memcpy(&temp, reinterpret_cast<char*>(buffer), sizeof(int));
195+
buffer += sizeof(int);
196+
ret.set_id(temp);
197+
198+
memcpy(&temp, reinterpret_cast<char*>(buffer), sizeof(int));
199+
buffer += sizeof(int);
200+
ret.set_major(temp);
201+
202+
memcpy(&temp, reinterpret_cast<char*>(buffer), sizeof(int));
203+
buffer += sizeof(int);
204+
ret.set_minor(temp);
205+
206+
return ret;
207+
}
208+
209+
210+
} // namespace execution
211+
} // namespace core
212+
} // namespace trtorch

core/execution/execution.h

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
#pragma once
2+
#include <utility>
3+
#include "NvInfer.h"
4+
#include "ATen/core/function_schema.h"
5+
#include "torch/custom_class.h"
6+
#include "core/util/prelude.h"
7+
8+
9+
namespace trtorch {
10+
namespace core {
11+
namespace execution {
12+
13+
using EngineID = int64_t;
14+
15+
struct CudaDevice {
16+
int id; // CUDA device id
17+
int major; // CUDA compute major version
18+
int minor; // CUDA compute minor version
19+
20+
int get_id(void);
21+
void set_id(int id);
22+
23+
int get_major(void);
24+
void set_major(int major);
25+
26+
int get_minor(void);
27+
void set_minor(int minor);
28+
};
29+
30+
void set_cuda_device(CudaDevice& cuda_device);
31+
void get_cuda_device(CudaDevice& cuda_device);
32+
33+
std::string serialize_device(CudaDevice& cuda_device);
34+
CudaDevice deserialize_device(std::string device_info);
35+
36+
struct TRTEngine : torch::CustomClassHolder {
37+
// Each engine needs it's own runtime object
38+
nvinfer1::IRuntime* rt;
39+
nvinfer1::ICudaEngine* cuda_engine;
40+
nvinfer1::IExecutionContext* exec_ctx;
41+
std::pair<uint64_t, uint64_t> num_io;
42+
EngineID id;
43+
std::string name;
44+
CudaDevice device_info;
45+
util::logging::TRTorchLogger logger;
46+
47+
std::unordered_map<uint64_t, uint64_t> in_binding_map;
48+
std::unordered_map<uint64_t, uint64_t> out_binding_map;
49+
50+
~TRTEngine();
51+
TRTEngine(std::string serialized_engine);
52+
TRTEngine(std::vector<std::string> serialized_info);
53+
TRTEngine(std::string mod_name, std::string serialized_engine, std::string device_info);
54+
TRTEngine& operator=(const TRTEngine& other);
55+
// TODO: Implement a call method
56+
//c10::List<at::Tensor> Run(c10::List<at::Tensor> inputs);
57+
};
58+
59+
std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intrusive_ptr<TRTEngine> compiled_engine);
60+
61+
} // namespace execution
62+
} // namespace core
63+
} // namespace trtorch

0 commit comments

Comments
 (0)