Commit f7bef90

Author: Anurag Dixit

refactor: Review comments incorporated

Signed-off-by: Anurag Dixit <[email protected]>
1 parent 5559627 commit f7bef90

File tree

10 files changed: +91 additions, -34 deletions


core/compiler.cpp

Lines changed: 23 additions & 3 deletions
@@ -27,13 +27,33 @@
 namespace trtorch {
 namespace core {

+static std::unordered_map<int, runtime::CudaDevice> cuda_device_list;
+
+void update_cuda_device_list(void) {
+  int num_devices = 0;
+  auto status = cudaGetDeviceCount(&num_devices);
+  TRTORCH_ASSERT((status == cudaSuccess), "Unable to read CUDA capable devices. Return status: " << status);
+  cudaDeviceProp device_prop;
+  for (int i = 0; i < num_devices; i++) {
+    TRTORCH_CHECK(
+        (cudaGetDeviceProperties(&device_prop, i) == cudaSuccess),
+        "Unable to read CUDA Device Properties for device id: " << i);
+    std::string device_name(device_prop.name);
+    runtime::CudaDevice device = {
+        i, device_prop.major, device_prop.minor, nvinfer1::DeviceType::kGPU, device_name.size(), device_name};
+    cuda_device_list[i] = device;
+  }
+}
+
 void AddEngineToGraph(
     torch::jit::script::Module mod,
     std::shared_ptr<torch::jit::Graph>& g,
     const std::string& serialized_engine,
     runtime::CudaDevice& device_info,
     std::string engine_id = "",
     bool fallback = false) {
+  // Scan and update the list of available CUDA devices
+  update_cuda_device_list();
   auto engine_ptr =
       c10::make_intrusive<runtime::TRTEngine>(mod._ivalue()->name() + engine_id, serialized_engine, device_info);
   // Get required metadata about the engine out
@@ -277,13 +297,13 @@ torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod, C
   return new_mod;
 }

-torch::jit::script::Module EmbedEngineInNewModule(const std::string& engine, CompileSpec cfg) {
+torch::jit::script::Module EmbedEngineInNewModule(
+    const std::string& engine,
+    trtorch::core::runtime::CudaDevice cuda_device) {
   std::ostringstream engine_id;
   engine_id << reinterpret_cast<const int*>(&engine);
   torch::jit::script::Module new_mod("tensorrt_engine_mod_" + engine_id.str());
   auto new_g = std::make_shared<torch::jit::Graph>();
-  auto device_spec = cfg.convert_info.engine_settings.device;
-  auto cuda_device = runtime::get_device_info(device_spec.gpu_id, device_spec.device_type);
   AddEngineToGraph(new_mod, new_g, engine, cuda_device);
   auto new_method = new_mod._ivalue()->compilation_unit()->create_function("forward", new_g);
   auto schema = util::GenerateGraphSchema(new_method->name(), new_g);
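For orientation, here is a minimal sketch of how the refactored core-level entry point could now be called; the wrapper function and the serialized_engine argument are assumptions for illustration, not part of this commit:

// Hedged sketch: embedding a pre-built engine via the new core-level signature.
// Assumes serialized_engine holds a serialized TensorRT engine blob.
torch::jit::script::Module embed_on_gpu0(const std::string& serialized_engine) {
  // get_device_info() is declared in core/runtime/runtime.h
  auto cuda_device = trtorch::core::runtime::get_device_info(0, nvinfer1::DeviceType::kGPU);
  return trtorch::core::EmbedEngineInNewModule(serialized_engine, cuda_device);
}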

core/compiler.h

Lines changed: 1 addition & 1 deletion
@@ -23,7 +23,7 @@ std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::

 torch::jit::script::Module CompileGraph(const torch::jit::script::Module& module, CompileSpec cfg);

-torch::jit::script::Module EmbedEngineInNewModule(const std::string& engine, CompileSpec cfg);
+torch::jit::script::Module EmbedEngineInNewModule(const std::string& engine, runtime::CudaDevice cuda_device);

 void set_device(const int gpu_id);

core/runtime/TRTEngine.cpp

Lines changed: 2 additions & 2 deletions
@@ -31,9 +31,9 @@ TRTEngine::TRTEngine(std::vector<std::string> serialized_info)
       util::logging::get_logger().get_reportable_severity(),
       util::logging::get_logger().get_is_colored_output_on()) {
   std::string _name = "deserialized_trt";
-  std::string engine_info = serialized_info[EngineIdx];
+  std::string engine_info = serialized_info[ENGINE_IDX];

-  CudaDevice cuda_device = deserialize_device(serialized_info[DeviceIdx]);
+  CudaDevice cuda_device = deserialize_device(serialized_info[DEVICE_IDX]);
   new (this) TRTEngine(_name, engine_info, cuda_device);
 }

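The renamed constants index the two entries of the serialized_info vector; a short illustrative sketch follows (device_blob and engine_blob are placeholder strings, not real serialized output):

// Hedged sketch: the serialized_info layout addressed by the renamed constants.
std::string device_blob = "...";  // placeholder: serialized CudaDevice
std::string engine_blob = "...";  // placeholder: serialized engine bytes
std::vector<std::string> serialized_info(2);
serialized_info[DEVICE_IDX] = device_blob;
serialized_info[ENGINE_IDX] = engine_blob;
TRTEngine engine(serialized_info);  // dispatches to the constructor above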
core/runtime/register_trt_op.cpp

Lines changed: 17 additions & 20 deletions
@@ -10,6 +10,13 @@ namespace trtorch {
 namespace core {
 namespace runtime {

+// SM Compute capability <Compute Capability, Device Name> map
+const std::unordered_map<std::string, std::string>& get_dla_supported_SM() {
+  // Xavier SM Compute Capability
+  static std::unordered_map<std::string, std::string> dla_supported_SM = {{"7.2", "Xavier"}};
+  return dla_supported_SM;
+}
+
 // Checks if a context switch is required for the device ID
 bool is_switch_required(const CudaDevice& curr_device, const CudaDevice& conf_device) {
   // If SM capability is not the same as configured then switch
@@ -40,34 +47,23 @@ bool is_switch_required(const CudaDevice& curr_device, const CudaDevice& conf_de

 int select_cuda_device(const CudaDevice& conf_device) {
   int device_id = 0;
-  int num_devices = 0;
-  // SM Compute capability <major,minor> pair
-  std::unordered_map<std::string, std::string> dla_supported_SM;
+  auto dla_supported = get_dla_supported_SM();

-  // Xavier SM Compute Capability
-  dla_supported_SM.insert(std::make_pair("7.2", "Xavier"));
-  auto status = cudaGetDeviceCount(&num_devices);
-  TRTORCH_CHECK((status == cudaSuccess), "Unable to read CUDA capable devices. Return status: " << status);
-
-  cudaDeviceProp device_prop;
+  auto cuda_device_list = DeviceList::instance().get_devices();

-  for (int i = 0; i < num_devices; i++) {
-    TRTORCH_CHECK(
-        (cudaGetDeviceProperties(&device_prop, i) == cudaSuccess),
-        "Unable to read CUDA Device Properties for device id: " << i);
-    auto compute_cap = std::to_string(device_prop.major) + "." + std::to_string(device_prop.minor);
-    std::string device_name{device_prop.name};
+  for (auto device : cuda_device_list) {
+    auto compute_cap = std::to_string(device.second.major) + "." + std::to_string(device.second.minor);
     // In case of DLA select the DLA supported device ID
     if (conf_device.device_type == nvinfer1::DeviceType::kDLA) {
-      if (dla_supported_SM.find(compute_cap) != dla_supported_SM.end() &&
-          dla_supported_SM[compute_cap] == device_name) {
-        device_id = i;
+      if (dla_supported.find(compute_cap) != dla_supported.end() &&
+          dla_supported[compute_cap] == device.second.device_name) {
+        device_id = device.second.id;
         break;
       }
     } else if (conf_device.device_type == nvinfer1::DeviceType::kGPU) {
       auto conf_sm = std::to_string(conf_device.major) + "." + std::to_string(conf_device.minor);
-      if (compute_cap == conf_sm && device_name == conf_device.device_name) {
-        device_id = i;
+      if (compute_cap == conf_sm && device.second.device_name == conf_device.device_name) {
+        device_id = device.second.id;
         break;
       }
     } else {
@@ -83,6 +79,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr

   CudaDevice curr_device;
   get_cuda_device(curr_device);
+  LOG_DEBUG("Current Device ID: " << curr_device.id);

   if (is_switch_required(curr_device, compiled_engine->device_info)) {
     // Scan through available CUDA devices and set the CUDA device context correctly
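To make the DLA path concrete, here is a small sketch of the lookup select_cuda_device performs for kDLA requests; since the map currently only contains {"7.2", "Xavier"}, any other compute capability falls through to the default device id. The values below are placeholders:

// Hedged sketch: matching a DLA-capable device by compute capability.
auto dla_supported = get_dla_supported_SM(); // copied, so operator[] is usable
std::string compute_cap = "7.2";             // would be built from device major/minor
if (dla_supported.find(compute_cap) != dla_supported.end()) {
  // dla_supported[compute_cap] == "Xavier"; compared against CudaDevice::device_name
}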

core/runtime/runtime.h

Lines changed: 23 additions & 1 deletion
@@ -11,7 +11,7 @@ namespace runtime {

 using EngineID = int64_t;

-typedef enum { DeviceIdx = 0, EngineIdx } SerializedInfoIndex;
+typedef enum { DEVICE_IDX = 0, ENGINE_IDX } SerializedInfoIndex;

 struct CudaDevice {
   int64_t id; // CUDA device id
@@ -78,6 +78,28 @@ CudaDevice deserialize_device(std::string device_info);

 CudaDevice get_device_info(int64_t gpu_id, nvinfer1::DeviceType device_type);

+class DeviceList {
+  using DeviceMap = std::unordered_map<int, CudaDevice>;
+  DeviceMap device_list;
+  DeviceList() {}
+
+ public:
+  static DeviceList& instance() {
+    static DeviceList obj;
+    return obj;
+  }
+
+  void insert(int device_id, CudaDevice cuda_device) {
+    device_list[device_id] = cuda_device;
+  }
+  CudaDevice find(int device_id) {
+    return device_list[device_id];
+  }
+  DeviceMap get_devices() {
+    return device_list;
+  }
+};
+
 struct TRTEngine : torch::CustomClassHolder {
   // Each engine needs its own runtime object
   nvinfer1::IRuntime* rt;
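For reference, a minimal usage sketch of the new DeviceList singleton; the device values below are placeholders:

// Hedged sketch: populating and querying the DeviceList singleton.
using namespace trtorch::core::runtime;

std::string name = "SomeGPU"; // placeholder device name
CudaDevice dev = {0, 7, 5, nvinfer1::DeviceType::kGPU, name.size(), name};
DeviceList::instance().insert(0, dev);

CudaDevice found = DeviceList::instance().find(0);
for (const auto& entry : DeviceList::instance().get_devices()) {
  // entry.first is the device id, entry.second the cached CudaDevice record
}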

cpp/api/include/trtorch/trtorch.h

Lines changed: 2 additions & 2 deletions
@@ -517,15 +517,15 @@ TRTORCH_API std::string ConvertGraphToTRTEngine(
  * in a TorchScript module
  *
  * @param engine: std::string - Pre-built serialized TensorRT engine
- * @param info: trtorch::CompileSpec - Compilation settings
+ * @param device: CompileSpec::Device - Device information
  *
  * Takes a pre-built serialized TensorRT engine and embeds it in a TorchScript
  * module. Registers execution of the engine as the forward method of the module
  * Forward is defined as: forward(Tensor[]) -> Tensor[]
  *
  * @return: A new module targeting a TensorRT engine
  */
-TRTORCH_API torch::jit::Module EmbedEngineInNewModule(const std::string& engine, CompileSpec info);
+TRTORCH_API torch::jit::Module EmbedEngineInNewModule(const std::string& engine, CompileSpec::Device device);

 /**
  * @brief Set gpu device id
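A hedged sketch of the new public call pattern; it mirrors the test change at the end of this commit, and engine is assumed to hold a serialized TensorRT engine:

// Hedged sketch: embedding a pre-built engine through the public API.
// Assumes: std::string engine = <serialized TensorRT engine>;
trtorch::CompileSpec::Device device;
device.device_type = trtorch::CompileSpec::Device::DeviceType::kGPU;
device.gpu_id = 0; // target GPU id
auto trt_mod = trtorch::EmbedEngineInNewModule(engine, device);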

cpp/api/src/compile_spec.cpp

Lines changed: 13 additions & 0 deletions
@@ -74,6 +74,19 @@ std::vector<core::ir::InputRange> to_vec_internal_input_ranges(std::vector<Compi
   return internal;
 }

+core::runtime::CudaDevice to_internal_cuda_device(CompileSpec::Device device) {
+  auto device_type = nvinfer1::DeviceType::kGPU;
+  switch (device.device_type) {
+    case CompileSpec::Device::DeviceType::kDLA:
+      device_type = nvinfer1::DeviceType::kDLA;
+      break;
+    case CompileSpec::Device::DeviceType::kGPU:
+    default:
+      device_type = nvinfer1::DeviceType::kGPU;
+  }
+  return core::runtime::get_device_info(device.gpu_id, device_type);
+}
+
 core::CompileSpec to_internal_compile_spec(CompileSpec external) {
   core::CompileSpec internal(to_vec_internal_input_ranges(external.input_ranges));

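To illustrate the conversion, a short sketch of what to_internal_cuda_device yields for a DLA request; this is illustrative only, since the helper is internal to the cpp/api layer:

// Hedged sketch: a kDLA request maps to nvinfer1::DeviceType::kDLA
// combined with the queried properties of the given gpu_id.
trtorch::CompileSpec::Device d;
d.device_type = trtorch::CompileSpec::Device::DeviceType::kDLA;
d.gpu_id = 0;
auto internal = trtorch::to_internal_cuda_device(d);
// internal.device_type == nvinfer1::DeviceType::kDLA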
cpp/api/src/trtorch.cpp

Lines changed: 3 additions & 2 deletions
@@ -9,6 +9,7 @@ namespace trtorch {

 // Defined in compile_spec.cpp
 core::CompileSpec to_internal_compile_spec(CompileSpec external);
+core::runtime::CudaDevice to_internal_cuda_device(CompileSpec::Device device);

 bool CheckMethodOperatorSupport(const torch::jit::script::Module& module, std::string method_name) {
   return core::CheckMethodOperatorSupport(module, method_name);
@@ -31,8 +32,8 @@ torch::jit::script::Module CompileGraph(const torch::jit::script::Module& module
   return core::CompileGraph(module, to_internal_compile_spec(info));
 }

-torch::jit::Module EmbedEngineInNewModule(const std::string& engine, CompileSpec info) {
-  return core::EmbedEngineInNewModule(engine, to_internal_compile_spec(info));
+torch::jit::Module EmbedEngineInNewModule(const std::string& engine, CompileSpec::Device device) {
+  return core::EmbedEngineInNewModule(engine, to_internal_cuda_device(device));
 }

 std::string get_build_info() {

py/trtorch/csrc/trtorch_py.cpp

Lines changed: 2 additions & 2 deletions
@@ -119,8 +119,8 @@ bool CheckMethodOperatorSupport(const torch::jit::Module& module, const std::str
   return core::CheckMethodOperatorSupport(module, method_name);
 }

-torch::jit::Module EmbedEngineInNewModule(const py::bytes& engine, CompileSpec& info) {
-  return core::EmbedEngineInNewModule(engine, info.toInternalCompileSpec());
+torch::jit::Module EmbedEngineInNewModule(const py::bytes& engine, core::runtime::CudaDevice& device) {
+  return core::EmbedEngineInNewModule(engine, device);
 }

 std::string get_build_info() {

tests/modules/test_modules_as_engines.cpp

Lines changed: 5 additions & 1 deletion
@@ -36,8 +36,12 @@ TEST_P(ModuleTests, ModuleToEngineToModuleIsClose) {
   }

   auto compile_spec = trtorch::CompileSpec({input_ranges});
+  int device_id = 0;
+  cudaGetDevice(&device_id);
+  compile_spec.device.device_type = trtorch::CompileSpec::Device::DeviceType::kGPU;
+  compile_spec.device.gpu_id = device_id;
   auto engine = trtorch::ConvertGraphToTRTEngine(mod, "forward", input_ranges);
-  auto trt_mod = trtorch::EmbedEngineInNewModule(engine, compile_spec);
+  auto trt_mod = trtorch::EmbedEngineInNewModule(engine, compile_spec.device);

   torch::jit::IValue trt_results_ivalues = trtorch::tests::util::RunModuleForward(mod, inputs_ivalues);
   std::vector<at::Tensor> trt_results;
