Skip to content

Commit 611f6a1

Browse files
author
Anurag Dixit
committed
refactor: Review comments incorporated
Signed-off-by: Anurag Dixit <[email protected]>
1 parent 0bd8d28 commit 611f6a1

File tree

5 files changed

+48
-45
lines changed

5 files changed

+48
-45
lines changed

core/compiler.cpp

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -27,33 +27,13 @@
2727
namespace trtorch {
2828
namespace core {
2929

30-
static std::unordered_map<int, runtime::CudaDevice> cuda_device_list;
31-
32-
void update_cuda_device_list(void) {
33-
int num_devices = 0;
34-
auto status = cudaGetDeviceCount(&num_devices);
35-
TRTORCH_ASSERT((status == cudaSuccess), "Unable to read CUDA capable devices. Return status: " << status);
36-
cudaDeviceProp device_prop;
37-
for (int i = 0; i < num_devices; i++) {
38-
TRTORCH_CHECK(
39-
(cudaGetDeviceProperties(&device_prop, i) == cudaSuccess),
40-
"Unable to read CUDA Device Properies for device id: " << i);
41-
std::string device_name(device_prop.name);
42-
runtime::CudaDevice device = {
43-
i, device_prop.major, device_prop.minor, nvinfer1::DeviceType::kGPU, device_name.size(), device_name};
44-
cuda_device_list[i] = device;
45-
}
46-
}
47-
4830
void AddEngineToGraph(
4931
torch::jit::script::Module mod,
5032
std::shared_ptr<torch::jit::Graph>& g,
5133
const std::string& serialized_engine,
5234
runtime::CudaDevice& device_info,
5335
std::string engine_id = "",
5436
bool fallback = false) {
55-
// Scan and Update the list of available cuda devices
56-
update_cuda_device_list();
5737
auto engine_ptr =
5838
c10::make_intrusive<runtime::TRTEngine>(mod._ivalue()->name() + engine_id, serialized_engine, device_info);
5939
// Get required metadata about the engine out

core/runtime/TRTEngine.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@ TRTEngine::TRTEngine(std::string mod_name, std::string serialized_engine, CudaDe
5050
name = slugify(mod_name) + "_engine";
5151

5252
cuda_engine = rt->deserializeCudaEngine(serialized_engine.c_str(), serialized_engine.size());
53+
TRTORCH_CHECK((cuda_engine != nullptr), "Unable to deserialize the TensorRT engine");
54+
5355
// Easy way to get a unique name for each engine, maybe there is a more
5456
// descriptive way (using something associated with the graph maybe)
5557
id = reinterpret_cast<EngineID>(cuda_engine);

core/runtime/register_trt_op.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,9 @@ int select_cuda_device(const CudaDevice& conf_device) {
4949
int device_id = 0;
5050
auto dla_supported = get_dla_supported_SM();
5151

52-
auto cuda_device_list = DeviceList::instance().get_devices();
52+
auto device_list = cuda_device_list.instance().get_devices();
5353

54-
for (auto device : cuda_device_list) {
54+
for (auto device : device_list) {
5555
auto compute_cap = std::to_string(device.second.major) + "." + std::to_string(device.second.minor);
5656
// In case of DLA select the DLA supported device ID
5757
if (conf_device.device_type == nvinfer1::DeviceType::kDLA) {

core/runtime/runtime.h

Lines changed: 43 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -78,28 +78,6 @@ CudaDevice deserialize_device(std::string device_info);
7878

7979
CudaDevice get_device_info(int64_t gpu_id, nvinfer1::DeviceType device_type);
8080

81-
class DeviceList {
82-
using DeviceMap = std::unordered_map<int, CudaDevice>;
83-
DeviceMap device_list;
84-
DeviceList() {}
85-
86-
public:
87-
static DeviceList& instance() {
88-
static DeviceList obj;
89-
return obj;
90-
}
91-
92-
void insert(int device_id, CudaDevice cuda_device) {
93-
device_list[device_id] = cuda_device;
94-
}
95-
CudaDevice find(int device_id) {
96-
return device_list[device_id];
97-
}
98-
DeviceMap get_devices() {
99-
return device_list;
100-
}
101-
};
102-
10381
struct TRTEngine : torch::CustomClassHolder {
10482
// Each engine needs it's own runtime object
10583
nvinfer1::IRuntime* rt;
@@ -125,6 +103,49 @@ struct TRTEngine : torch::CustomClassHolder {
125103

126104
std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intrusive_ptr<TRTEngine> compiled_engine);
127105

106+
class DeviceList {
107+
using DeviceMap = std::unordered_map<int, CudaDevice>;
108+
DeviceMap device_list;
109+
110+
public:
111+
// Scans and updates the list of available CUDA devices
112+
DeviceList(void) {
113+
int num_devices = 0;
114+
auto status = cudaGetDeviceCount(&num_devices);
115+
TRTORCH_ASSERT((status == cudaSuccess), "Unable to read CUDA capable devices. Return status: " << status);
116+
cudaDeviceProp device_prop;
117+
for (int i = 0; i < num_devices; i++) {
118+
TRTORCH_CHECK(
119+
(cudaGetDeviceProperties(&device_prop, i) == cudaSuccess),
120+
"Unable to read CUDA Device Properies for device id: " << i);
121+
std::string device_name(device_prop.name);
122+
CudaDevice device = {
123+
i, device_prop.major, device_prop.minor, nvinfer1::DeviceType::kGPU, device_name.size(), device_name};
124+
device_list[i] = device;
125+
}
126+
}
127+
128+
public:
129+
static DeviceList& instance() {
130+
static DeviceList obj;
131+
return obj;
132+
}
133+
134+
void insert(int device_id, CudaDevice cuda_device) {
135+
device_list[device_id] = cuda_device;
136+
}
137+
CudaDevice find(int device_id) {
138+
return device_list[device_id];
139+
}
140+
DeviceMap get_devices() {
141+
return device_list;
142+
}
143+
};
144+
145+
namespace {
146+
static DeviceList cuda_device_list;
147+
}
148+
128149
} // namespace runtime
129150
} // namespace core
130151
} // namespace trtorch

cpp/api/include/trtorch/trtorch.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -517,7 +517,7 @@ TRTORCH_API std::string ConvertGraphToTRTEngine(
517517
* in a TorchScript module
518518
*
519519
* @param engine: std::string - Pre-built serialized TensorRT engine
520-
* @param info: CompileSepc::Device - Device information
520+
* @param device: CompileSepc::Device - Device information
521521
*
522522
* Takes a pre-built serialized TensorRT engine and embeds it in a TorchScript
523523
* module. Registers execution of the engine as the forward method of the module

0 commit comments

Comments
 (0)