
Commit c1645ba

Author: Anurag Dixit (committed)
(//core): Rebase with master branch
Signed-off-by: Anurag Dixit <[email protected]>
1 parent ae2281e · commit c1645ba

File tree: 11 files changed (+180 / -20 lines)

core/compiler.cpp

Lines changed: 14 additions & 1 deletion
@@ -2,6 +2,7 @@
 #include <memory>
 #include <sstream>
 #include <vector>
+#include <cuda_runtime.h>
 
 #include "NvInfer.h"
 
@@ -42,7 +43,15 @@ c10::FunctionSchema GenerateGraphSchema(torch::jit::script::Module mod, std::str
 
 
 void AddEngineToGraph(torch::jit::script::Module mod, std::shared_ptr<torch::jit::Graph>& g, std::string& serialized_engine) {
-    auto engine_ptr = c10::make_intrusive<execution::TRTEngine>(mod._ivalue()->name(), serialized_engine);
+    execution::CudaDevice device;
+
+    // Read current CUDA device properties
+    execution::get_cuda_device(device);
+
+    // Serialize current device information
+    auto device_info = execution::serialize_device(device);
+
+    auto engine_ptr = c10::make_intrusive<execution::TRTEngine>(mod._ivalue()->name(), serialized_engine, device_info);
     // Get required metadata about the engine out
     auto num_io = engine_ptr->num_io;
     auto name = engine_ptr->name;
@@ -164,6 +173,10 @@ torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod,
     return new_mod;
 }
 
+void set_device(const int gpu_id) {
+    TRTORCH_CHECK((cudaSetDevice(gpu_id) != cudaSuccess), "Unable to set CUDA device: " << gpu_id);
+}
+
 } // namespace core
 } // namespace trtorch
 

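Note on the compiler changes above: AddEngineToGraph now snapshots the active CUDA device (id plus compute capability) and hands the serialized blob to the TRTEngine constructor. The standalone sketch below is not part of the commit; it only shows the CUDA runtime calls that execution::get_cuda_device wraps, with error handling reduced to return codes.

// Standalone sketch (not from the commit) of querying the current device
// and its compute capability with the CUDA runtime API.
#include <cstdio>
#include <cuda_runtime.h>

int main() {
    int device_id = 0;
    if (cudaGetDevice(&device_id) != cudaSuccess) {
        std::fprintf(stderr, "Unable to query current CUDA device\n");
        return 1;
    }
    cudaDeviceProp prop;
    if (cudaGetDeviceProperties(&prop, device_id) != cudaSuccess) {
        std::fprintf(stderr, "Unable to query properties for device %d\n", device_id);
        return 1;
    }
    // id + compute capability (major.minor) is the metadata the commit
    // serializes next to the TensorRT engine.
    std::printf("device %d: compute capability %d.%d\n", device_id, prop.major, prop.minor);
    return 0;
}
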
core/compiler.h

Lines changed: 2 additions & 0 deletions
@@ -20,5 +20,7 @@ std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod,
 
 torch::jit::script::Module CompileGraph(const torch::jit::script::Module& module, ExtraInfo cfg);
 
+void set_device(const int gpu_id);
+
 } // namespace core
 } // namespace trtorch

core/execution/TRTEngine.cpp

Lines changed: 119 additions & 7 deletions
@@ -1,4 +1,5 @@
 #include <algorithm>
+#include <cuda_runtime.h>
 
 #include "NvInfer.h"
 #include "torch/csrc/jit/frontend/function_schema_parser.h"
@@ -10,6 +11,8 @@ namespace trtorch {
 namespace core {
 namespace execution {
 
+const std::string empty_string = std::string();
+
 std::string slugify(std::string s) {
     std::replace(s.begin(), s.end(), '.', '_');
     return s;
@@ -20,14 +23,34 @@ TRTEngine::TRTEngine(std::string serialized_engine)
       util::logging::get_logger().get_reportable_severity(),
       util::logging::get_logger().get_is_colored_output_on()) {
     std::string _name = "deserialized_trt";
-    new (this) TRTEngine(_name, serialized_engine);
+    new (this) TRTEngine(_name, serialized_engine, empty_string);
 }
 
-TRTEngine::TRTEngine(std::string mod_name, std::string serialized_engine)
+TRTEngine::TRTEngine(std::vector<std::string> serialized_info)
+    : logger(std::string("[] - "),
+      util::logging::get_logger().get_reportable_severity(),
+      util::logging::get_logger().get_is_colored_output_on()) {
+    std::string _name = "deserialized_trt";
+    std::string device_info = serialized_info[0];
+    std::string engine_info = serialized_info[1];
+
+    new (this) TRTEngine(_name, engine_info, device_info);
+}
+
+TRTEngine::TRTEngine(std::string mod_name, std::string serialized_engine, std::string serialized_device_info = empty_string)
     : logger(std::string("[") + mod_name + std::string("_engine] - "),
       util::logging::get_logger().get_reportable_severity(),
       util::logging::get_logger().get_is_colored_output_on()) {
 
+    CudaDevice cuda_device;
+    // Deserialize device meta data if device_info is non-empty
+    if (!serialized_device_info.empty())
+    {
+        cuda_device = deserialize_device(serialized_device_info);
+        // Set CUDA device as configured in serialized meta data
+        set_cuda_device(cuda_device);
+    }
+
     rt = nvinfer1::createInferRuntime(logger);
 
     name = slugify(mod_name) + "_engine";
@@ -62,6 +85,7 @@ TRTEngine& TRTEngine::operator=(const TRTEngine& other) {
     id = other.id;
     rt = other.rt;
     cuda_engine = other.cuda_engine;
+    device_info = other.device_info;
     exec_ctx = other.exec_ctx;
     num_io = other.num_io;
     return (*this);
@@ -73,6 +97,7 @@ TRTEngine::~TRTEngine() {
     rt->destroy();
 }
 
+
 // TODO: Implement a call method
 // c10::List<at::Tensor> TRTEngine::Run(c10::List<at::Tensor> inputs) {
 //     auto input_vec = inputs.vec();
@@ -86,15 +111,102 @@ static auto TRTORCH_UNUSED TRTEngineTSRegistrtion = torch::class_<TRTEngine>("te
     // TODO: .def("__call__", &TRTEngine::Run)
     // TODO: .def("run", &TRTEngine::Run)
     .def_pickle(
-        [](const c10::intrusive_ptr<TRTEngine>& self) -> std::string {
-            auto serialized_engine = self->cuda_engine->serialize();
-            return std::string((const char*)serialized_engine->data(), serialized_engine->size());
+        [](const c10::intrusive_ptr<TRTEngine>& self) -> std::vector<std::string> {
+            // Serialize TensorRT engine
+            auto serialized_trt_engine = self->cuda_engine->serialize();
+
+            // Adding device info related meta data to the serialized file
+            auto trt_engine = std::string((const char*)serialized_trt_engine->data(), serialized_trt_engine->size());
+
+            CudaDevice cuda_device;
+            std::vector<std::string> serialize_info;
+            serialize_info.push_back(serialize_device(cuda_device));
+            serialize_info.push_back(trt_engine);
+            return serialize_info;
         },
-        [](std::string seralized_engine) -> c10::intrusive_ptr<TRTEngine> {
-            return c10::make_intrusive<TRTEngine>(std::move(seralized_engine));
+        [](std::vector<std::string> seralized_info) -> c10::intrusive_ptr<TRTEngine> {
+            return c10::make_intrusive<TRTEngine>(std::move(seralized_info));
         }
     );
 
+
+int CudaDevice::get_id(void) {
+    return this->id;
+}
+
+void CudaDevice::set_id(int id) {
+    this->id = id;
+}
+
+int CudaDevice::get_major(void) {
+    return this->major;
+}
+
+void CudaDevice::set_major(int major) {
+    this->major = major;
+}
+
+int CudaDevice::get_minor(void) {
+    return this->minor;
+}
+
+void CudaDevice::set_minor(int minor) {
+    this->minor = minor;
+}
+
+void set_cuda_device(CudaDevice& cuda_device) {
+    TRTORCH_CHECK((cudaSetDevice(cuda_device.id) != cudaSuccess), "Unable to set device: " << cuda_device.id);
+}
+
+void get_cuda_device(CudaDevice& cuda_device) {
+    TRTORCH_CHECK((cudaGetDevice(&cuda_device.id) != cudaSuccess), "Unable to get current device: " << cuda_device.id);
+    cudaDeviceProp device_prop;
+    TRTORCH_CHECK((cudaGetDeviceProperties(&device_prop, cuda_device.id) != cudaSuccess), "Unable to get CUDA properties from device:" << cuda_device.id);
+    cuda_device.set_major(device_prop.major);
+    cuda_device.set_minor(device_prop.minor);
+}
+
+std::string serialize_device(CudaDevice& cuda_device) {
+    void *buffer = new char[sizeof(cuda_device)];
+    void *ref_buf = buffer;
+
+    int temp = cuda_device.get_id();
+    memcpy(buffer, reinterpret_cast<int*>(&temp), sizeof(int));
+    buffer = static_cast<char*>(buffer) + sizeof(int);
+
+    temp = cuda_device.get_major();
+    memcpy(buffer, reinterpret_cast<int*>(&temp), sizeof(int));
+    buffer = static_cast<char*>(buffer) + sizeof(int);
+
+    temp = cuda_device.get_minor();
+    memcpy(buffer, reinterpret_cast<int*>(&temp), sizeof(int));
+    buffer = static_cast<char*>(buffer) + sizeof(int);
+
+    return std::string((const char*)ref_buf, sizeof(int)*3);
+}
+
+CudaDevice deserialize_device(std::string device_info) {
+    CudaDevice ret;
+    char *buffer = new char[device_info.size() + 1];
+    std::copy(device_info.begin(), device_info.end(), buffer);
+    int temp = 0;
+
+    memcpy(&temp, reinterpret_cast<char*>(buffer), sizeof(int));
+    buffer += sizeof(int);
+    ret.set_id(temp);
+
+    memcpy(&temp, reinterpret_cast<char*>(buffer), sizeof(int));
+    buffer += sizeof(int);
+    ret.set_major(temp);
+
+    memcpy(&temp, reinterpret_cast<char*>(buffer), sizeof(int));
+    buffer += sizeof(int);
+    ret.set_minor(temp);
+
+    return ret;
+}
+
+
 } // namespace execution
 } // namespace core
 } // namespace trtorch
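
The new serialize_device/deserialize_device helpers above pack three ints (id, major, minor) back to back into a 12-byte string. The sketch below is not part of the commit and uses hypothetical names; it shows an equivalent round trip with a fixed-size stack buffer, assuming the same three-int layout.

// Sketch (hypothetical names): pack/unpack the device blob as three
// consecutive ints, mirroring the layout written by serialize_device.
#include <cstring>
#include <string>

struct DeviceMeta { int id; int major; int minor; };

std::string pack_device(const DeviceMeta& d) {
    char buf[3 * sizeof(int)];
    std::memcpy(buf, &d.id, sizeof(int));
    std::memcpy(buf + sizeof(int), &d.major, sizeof(int));
    std::memcpy(buf + 2 * sizeof(int), &d.minor, sizeof(int));
    return std::string(buf, sizeof(buf));  // 12-byte blob: id | major | minor
}

DeviceMeta unpack_device(const std::string& blob) {
    DeviceMeta d{};
    std::memcpy(&d.id, blob.data(), sizeof(int));
    std::memcpy(&d.major, blob.data() + sizeof(int), sizeof(int));
    std::memcpy(&d.minor, blob.data() + 2 * sizeof(int), sizeof(int));
    return d;
}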

core/execution/execution.h

Lines changed: 24 additions & 1 deletion
@@ -12,6 +12,27 @@ namespace execution {
 
 using EngineID = int64_t;
 
+struct CudaDevice {
+    int id;    // CUDA device id
+    int major; // CUDA compute major version
+    int minor; // CUDA compute minor version
+
+    int get_id(void);
+    void set_id(int id);
+
+    int get_major(void);
+    void set_major(int major);
+
+    int get_minor(void);
+    void set_minor(int minor);
+};
+
+void set_cuda_device(CudaDevice& cuda_device);
+void get_cuda_device(CudaDevice& cuda_device);
+
+std::string serialize_device(CudaDevice& cuda_device);
+CudaDevice deserialize_device(std::string device_info);
+
 struct TRTEngine : torch::CustomClassHolder {
     // Each engine needs it's own runtime object
     nvinfer1::IRuntime* rt;
@@ -20,14 +41,16 @@ struct TRTEngine : torch::CustomClassHolder {
     std::pair<uint64_t, uint64_t> num_io;
     EngineID id;
     std::string name;
+    CudaDevice device_info;
     util::logging::TRTorchLogger logger;
 
     std::unordered_map<uint64_t, uint64_t> in_binding_map;
     std::unordered_map<uint64_t, uint64_t> out_binding_map;
 
     ~TRTEngine();
     TRTEngine(std::string serialized_engine);
-    TRTEngine(std::string mod_name, std::string serialized_engine);
+    TRTEngine(std::vector<std::string> serialized_info);
+    TRTEngine(std::string mod_name, std::string serialized_engine, std::string device_info);
     TRTEngine& operator=(const TRTEngine& other);
     // TODO: Implement a call method
     //c10::List<at::Tensor> Run(c10::List<at::Tensor> inputs);
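
The pickle pair in TRTEngine.cpp and the new TRTEngine(std::vector<std::string>) constructor rely on a two-slot convention: index 0 holds the serialized CudaDevice metadata and index 1 holds the raw TensorRT engine bytes. A minimal sketch of that convention (the helper name is hypothetical, not part of the commit):

#include <string>
#include <vector>

// Hypothetical helper illustrating the ordering expected by
// TRTEngine(std::vector<std::string> serialized_info).
std::vector<std::string> make_serialized_info(const std::string& device_blob,
                                              const std::string& engine_blob) {
    std::vector<std::string> serialized_info(2);
    serialized_info[0] = device_blob;  // read back via deserialize_device()
    serialized_info[1] = engine_blob;  // handed to the engine-deserializing constructor
    return serialized_info;
}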

core/util/BUILD

Lines changed: 0 additions & 1 deletion
@@ -84,7 +84,6 @@ cc_library(
     })
 )
 
-
 load("@rules_pkg//:pkg.bzl", "pkg_tar")
 
 pkg_tar(

cpp/api/include/trtorch/trtorch.h

Lines changed: 9 additions & 0 deletions
@@ -404,4 +404,13 @@ TRTORCH_API torch::jit::Module CompileGraph(const torch::jit::Module& module, Ex
  * @return: std::string: Serialized TensorRT engine equivilant to the method graph
  */
 TRTORCH_API std::string ConvertGraphToTRTEngine(const torch::jit::Module& module, std::string method_name, ExtraInfo info);
+
+/**
+ * @brief Set gpu device id
+ *
+ * @param gpu_id
+ *
+ * Sets gpu id using cudaSetDevice
+ */
+TRTORCH_API void set_device(const int gpu_id);
 } // namespace trtorch
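
Caller-side sketch of the new public set_device API. This is not from the commit; it assumes the vector-of-shapes ExtraInfo constructor from this version of the API and uses a placeholder input shape.

#include <cstdint>
#include <vector>
#include "torch/script.h"
#include "trtorch/trtorch.h"

// Pin compilation to GPU 1 before building the TensorRT-backed module.
torch::jit::Module compile_on_gpu_1(torch::jit::Module& mod) {
    trtorch::set_device(1);  // wraps cudaSetDevice without exposing CUDA headers
    std::vector<std::vector<int64_t>> dims = {{1, 3, 224, 224}};  // placeholder shape
    return trtorch::CompileGraph(mod, trtorch::ExtraInfo(dims));
}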

cpp/api/src/trtorch.cpp

Lines changed: 5 additions & 0 deletions
@@ -39,4 +39,9 @@ void dump_build_info() {
     std::cout << get_build_info() << std::endl;
 }
 
+void set_device(const int gpu_id) {
+    // Want to export a much simpler (non CUDA header dependent) API
+    core::set_device(gpu_id);
+}
+
 } // namespace trtorch
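
The wrapper above keeps CUDA headers out of the public surface by forwarding to core::set_device. If a caller wants to validate the id first, a guard along these lines could sit on the application side (hypothetical helper, not part of the commit):

#include <cuda_runtime.h>
#include "trtorch/trtorch.h"

// Returns false instead of asserting when the requested GPU is not visible.
bool try_set_device(int gpu_id) {
    int count = 0;
    if (cudaGetDeviceCount(&count) != cudaSuccess || gpu_id < 0 || gpu_id >= count) {
        return false;
    }
    trtorch::set_device(gpu_id);
    return true;
}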

tests/modules/test_serialization.cpp

Lines changed: 3 additions & 1 deletion
@@ -19,6 +19,7 @@ std::vector<trtorch::ExtraInfo::InputRange> toInputRangesDynamic(std::vector<std
 }
 
 TEST_P(ModuleTests, SerializedModuleIsStillCorrect) {
+    trtorch::set_device(0);
     std::vector<torch::jit::IValue> post_serialized_inputs_ivalues;
     std::vector<torch::jit::IValue> pre_serialized_inputs_ivalues;
     for (auto in_shape : input_shapes) {
@@ -45,6 +46,7 @@ TEST_P(ModuleTests, SerializedModuleIsStillCorrect) {
 }
 
 TEST_P(ModuleTests, SerializedDynamicModuleIsStillCorrect) {
+    trtorch::set_device(0);
     std::vector<torch::jit::IValue> post_serialized_inputs_ivalues;
     std::vector<torch::jit::IValue> pre_serialized_inputs_ivalues;
     for (auto in_shape : input_shapes) {
@@ -77,4 +79,4 @@ INSTANTIATE_TEST_SUITE_P(CompiledModuleForwardIsCloseSuite,
                          PathAndInSize({"tests/modules/resnet18_traced.jit.pt",
                                         {{1,3,224,224}}}),
                          PathAndInSize({"tests/modules/pooling_traced.jit.pt",
-                                        {{1,3,10,10}}})));
+                                        {{1,3,10,10}}})));
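
For context on what these tests exercise, the rough shape of a save/reload round trip is sketched below. The paths, shapes, and function name are placeholders and are not taken from the test fixtures; the ExtraInfo constructor is assumed to accept a vector of shapes as in this version of the API.

#include <cstdint>
#include <vector>
#include "torch/script.h"
#include "trtorch/trtorch.h"

void roundtrip(torch::jit::Module& mod) {
    trtorch::set_device(0);  // compile and reload on the same GPU
    std::vector<std::vector<int64_t>> dims = {{1, 3, 224, 224}};  // placeholder shape
    auto trt_mod = trtorch::CompileGraph(mod, trtorch::ExtraInfo(dims));
    trt_mod.save("trt_module.ts");                      // engine bytes + device blob get pickled
    auto reloaded = torch::jit::load("trt_module.ts");  // TRTEngine ctor restores the device
}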

tests/util/run_graph_engine.cpp

Lines changed: 1 addition & 1 deletion
@@ -43,7 +43,7 @@ std::vector<core::conversion::InputRange> toInputRangesDynamic(std::vector<at::T
 
 std::vector<at::Tensor> RunEngine(std::string& eng, std::vector<at::Tensor> inputs) {
     LOG_DEBUG("Running TRT version");
-    auto engine_ptr = c10::make_intrusive<trtorch::core::execution::TRTEngine>("test_engine", eng);
+    auto engine_ptr = c10::make_intrusive<trtorch::core::execution::TRTEngine>("test_engine", eng, "");
     auto outputs = trtorch::core::execution::execute_engine(inputs, engine_ptr);
     return outputs;
 }

third_party/cuda/BUILD

Lines changed: 3 additions & 3 deletions
@@ -19,7 +19,7 @@ cc_library(
     name = "cudart",
     srcs = select({
         ":aarch64_linux": [
-            "targets/aarch64-linux/lib/libcudart.so",
+            "targets/aarch64-linux-gnu/lib/libcudart.so",
         ],
         ":windows": [
             "lib/x64/cudart.lib",
@@ -40,7 +40,7 @@ cc_library(
     name = "nvToolsExt",
     srcs = select({
         ":aarch64_linux": [
-            "targets/aarch64-linux/lib/libnvToolsExt.so.1",
+            "targets/aarch64-linux-gnu/lib/libnvToolsExt.so.1",
         ],
         ":windows": [
             "bin/nvToolsExt64_1.dll",
@@ -55,7 +55,7 @@ cc_library(
     name = "cuda",
     srcs = select({
         ":aarch64_linux": glob([
-            "targets/aarch64-linux/lib/**/lib*.so",
+            "targets/aarch64-linux-gnu/lib/**/lib*.so",
         ]),
         ":windows": [
             "bin/*.dll",

third_party/tensorrt/local/BUILD

Lines changed: 0 additions & 5 deletions
@@ -67,11 +67,6 @@ cc_import(
         ":windows": "lib/nvinfer.dll",
         "//conditions:default": "lib/x86_64-linux-gnu/libnvinfer.so",
     }),
-    static_library = select({
-        ":aarch64_linux": "lib/aarch64-linux-gnu/libnvinfer_static.a",
-        ":windows": "lib/nvinfer.lib",
-        "//conditions:default": "lib/x86_64-linux-gnu/libnvinfer_static.a"
-    }),
     visibility = ["//visibility:private"],
 )
 
