Skip to content

Commit 1ea0b45

Browse files
author
Anurag Dixit
committed
Rebased with DLA device support
Signed-off-by: Anurag Dixit <[email protected]>
1 parent d24574a commit 1ea0b45

File tree

4 files changed

+16
-23
lines changed

4 files changed

+16
-23
lines changed

core/compiler.cpp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -49,13 +49,15 @@ void AddEngineToGraph(
4949
std::shared_ptr<torch::jit::Graph>& g,
5050
std::string& serialized_engine) {
5151

52+
runtime::CudaDevice device;
53+
5254
// Read current CUDA device properties
5355
runtime::get_cuda_device(device);
5456

5557
// Serialize current device information
5658
auto device_info = runtime::serialize_device(device);
5759

58-
auto engine_ptr = c10::make_intrusive<execution::TRTEngine>(mod._ivalue()->name(), serialized_engine, device_info);
60+
auto engine_ptr = c10::make_intrusive<runtime::TRTEngine>(mod._ivalue()->name(), serialized_engine, device_info);
5961
// Get required metadata about the engine out
6062
auto num_io = engine_ptr->num_io;
6163
auto name = engine_ptr->name;
@@ -181,10 +183,6 @@ torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod, C
181183
return new_mod;
182184
}
183185

184-
void set_device(const int gpu_id) {
185-
TRTORCH_ASSERT(cudaSetDevice(gpu_id) == cudaSuccess, "Unable to set CUDA device: " << gpu_id);
186-
}
187-
188186
void set_device(const int gpu_id) {
189187
TRTORCH_CHECK((cudaSetDevice(gpu_id) == cudaSuccess), "Unable to set CUDA device: " << gpu_id);
190188
}

core/runtime/TRTEngine.cpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include <algorithm>
22

3+
#include <cuda_runtime.h>
34
#include "NvInfer.h"
45
#include "torch/csrc/jit/frontend/function_schema_parser.h"
56

@@ -23,7 +24,7 @@ TRTEngine::TRTEngine(std::string serialized_engine)
2324
util::logging::get_logger().get_reportable_severity(),
2425
util::logging::get_logger().get_is_colored_output_on()) {
2526
std::string _name = "deserialized_trt";
26-
new (this) TRTEngine(_name, serialized_engine);
27+
new (this) TRTEngine(_name, serialized_engine, empty_string);
2728
}
2829

2930
TRTEngine::TRTEngine(std::vector<std::string> serialized_info)
@@ -37,7 +38,8 @@ TRTEngine::TRTEngine(std::vector<std::string> serialized_info)
3738
new (this) TRTEngine(_name, engine_info, device_info);
3839
}
3940

40-
TRTEngine::TRTEngine(std::string mod_name, std::string serialized_engine)
41+
TRTEngine::TRTEngine(std::string mod_name, std::string serialized_engine,
42+
std::string serialized_device_info = empty_string)
4143
: logger(
4244
std::string("[") + mod_name + std::string("_engine] - "),
4345
util::logging::get_logger().get_reportable_severity(),
@@ -105,7 +107,6 @@ TRTEngine::~TRTEngine() {
105107
// return c10::List<at::Tensor>(output_vec);
106108
// }
107109

108-
namespace {
109110
static auto TRTORCH_UNUSED TRTEngineTSRegistrtion =
110111
torch::class_<TRTEngine>("tensorrt", "Engine")
111112
.def(torch::init<std::string>())
@@ -120,13 +121,14 @@ static auto TRTORCH_UNUSED TRTEngineTSRegistrtion =
120121
auto trt_engine = std::string((const char*)serialized_trt_engine->data(), serialized_trt_engine->size());
121122

122123
CudaDevice cuda_device;
124+
get_cuda_device(cuda_device);
123125
std::vector<std::string> serialize_info;
124126
serialize_info.push_back(serialize_device(cuda_device));
125127
serialize_info.push_back(trt_engine);
126128
return serialize_info;
127129
},
128-
[](std::string seralized_engine) -> c10::intrusive_ptr<TRTEngine> {
129-
return c10::make_intrusive<TRTEngine>(std::move(seralized_engine));
130+
[](std::vector<std::string> seralized_info) -> c10::intrusive_ptr<TRTEngine> {
131+
return c10::make_intrusive<TRTEngine>(std::move(seralized_info));
130132
});
131133

132134
int CudaDevice::get_id(void) {
@@ -154,13 +156,13 @@ void CudaDevice::set_minor(int minor) {
154156
}
155157

156158
void set_cuda_device(CudaDevice& cuda_device) {
157-
TRTORCH_CHECK((cudaSetDevice(cuda_device.id) != cudaSuccess), "Unable to set device: " << cuda_device.id);
159+
TRTORCH_CHECK((cudaSetDevice(cuda_device.id) == cudaSuccess), "Unable to set device: " << cuda_device.id);
158160
}
159161

160162
void get_cuda_device(CudaDevice& cuda_device) {
161-
TRTORCH_CHECK((cudaGetDevice(&cuda_device.id) != cudaSuccess), "Unable to get current device: " << cuda_device.id);
163+
TRTORCH_CHECK((cudaGetDevice(&cuda_device.id) == cudaSuccess), "Unable to get current device: " << cuda_device.id);
162164
cudaDeviceProp device_prop;
163-
TRTORCH_CHECK((cudaGetDeviceProperties(&device_prop, cuda_device.id) != cudaSuccess), "Unable to get CUDA properties from device:" << cuda_device.id);
165+
TRTORCH_CHECK((cudaGetDeviceProperties(&device_prop, cuda_device.id) == cudaSuccess), "Unable to get CUDA properties from device:" << cuda_device.id);
164166
cuda_device.set_major(device_prop.major);
165167
cuda_device.set_minor(device_prop.minor);
166168
}
@@ -205,8 +207,6 @@ CudaDevice deserialize_device(std::string device_info) {
205207
return ret;
206208
}
207209

208-
209-
} // namespace
210210
} // namespace runtime
211211
} // namespace core
212212
} // namespace trtorch

cpp/api/src/trtorch.cpp

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,4 @@ void set_device(const int gpu_id) {
4545
core::set_device(gpu_id);
4646
}
4747

48-
void set_device(const int gpu_id) {
49-
// Want to export a much simpler (non CUDA header dependent) API
50-
core::set_device(gpu_id);
51-
}
52-
5348
} // namespace trtorch

third_party/cuda/BUILD

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ cc_library(
1919
name = "cudart",
2020
srcs = select({
2121
":aarch64_linux": [
22-
"targets/aarch64-linux-gnu/lib/libcudart.so",
22+
"targets/aarch64-linux/lib/libcudart.so",
2323
],
2424
":windows": [
2525
"lib/x64/cudart.lib",
@@ -40,7 +40,7 @@ cc_library(
4040
name = "nvToolsExt",
4141
srcs = select({
4242
":aarch64_linux": [
43-
"targets/aarch64-linux-gnu/lib/libnvToolsExt.so.1",
43+
"targets/aarch64-linux/lib/libnvToolsExt.so.1",
4444
],
4545
":windows": [
4646
"bin/nvToolsExt64_1.dll",
@@ -55,7 +55,7 @@ cc_library(
5555
name = "cuda",
5656
srcs = select({
5757
":aarch64_linux": glob([
58-
"targets/aarch64-linux-gnu/lib/**/lib*.so",
58+
"targets/aarch64-linux/lib/**/lib*.so",
5959
]),
6060
":windows": [
6161
"bin/*.dll",

0 commit comments

Comments
 (0)