Skip to content

Commit 4aaf9f9

Browse files
committed
feat: Add hardware compatibility option in Dynamo
- Add support for hardware compatibility for Ampere and later architectures
- Add the necessary functions to support the modification throughout the stack, including the C++ and Python components
- Update the ABI version to reflect the new metadata format for TRT engines
- Update the engine serialization schema accordingly
- Add test cases to validate the feature
1 parent 78756c4 commit 4aaf9f9

File tree

14 files changed

+151
-23
lines changed

14 files changed

+151
-23
lines changed

core/runtime/TRTEngine.cpp

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,24 +32,35 @@ TRTEngine::TRTEngine(
3232
const std::string& serialized_engine,
3333
const RTDevice& cuda_device,
3434
const std::vector<std::string>& _in_binding_names,
35-
const std::vector<std::string>& _out_binding_names)
36-
: TRTEngine("deserialized_trt", serialized_engine, cuda_device, _in_binding_names, _out_binding_names) {}
35+
const std::vector<std::string>& _out_binding_names,
36+
bool hardware_compatible)
37+
: TRTEngine(
38+
"deserialized_trt",
39+
serialized_engine,
40+
cuda_device,
41+
_in_binding_names,
42+
_out_binding_names,
43+
hardware_compatible) {}
3744

3845
TRTEngine::TRTEngine(std::vector<std::string> serialized_info)
3946
: TRTEngine(
4047
serialized_info[NAME_IDX],
4148
serialized_info[ENGINE_IDX],
4249
RTDevice(serialized_info[DEVICE_IDX]),
4350
split(serialized_info[INPUT_BINDING_NAMES_IDX], BINDING_DELIM),
44-
split(serialized_info[OUTPUT_BINDING_NAMES_IDX], BINDING_DELIM)) {}
51+
split(serialized_info[OUTPUT_BINDING_NAMES_IDX], BINDING_DELIM),
52+
static_cast<bool>(std::stoi(serialized_info[HW_COMPATIBLE_IDX]))) {}
4553

4654
TRTEngine::TRTEngine(
4755
const std::string& mod_name,
4856
const std::string& serialized_engine,
4957
const RTDevice& cuda_device,
5058
const std::vector<std::string>& _in_binding_names,
51-
const std::vector<std::string>& _out_binding_names) {
52-
auto most_compatible_device = get_most_compatible_device(cuda_device);
59+
const std::vector<std::string>& _out_binding_names,
60+
bool hardware_compatible) {
61+
this->hardware_compatible = hardware_compatible;
62+
63+
auto most_compatible_device = get_most_compatible_device(cuda_device, RTDevice(), hardware_compatible);
5364
TORCHTRT_CHECK(most_compatible_device, "No compatible device was found for instantiating TensorRT engine");
5465
device_info = most_compatible_device.value();
5566
multi_gpu_device_check();
@@ -232,6 +243,7 @@ std::string TRTEngine::to_str() const {
232243
}
233244
ss << " }" << std::endl;
234245
ss << " Device: " << device_info << std::endl;
246+
ss << " Hardware Compatibility: " << (hardware_compatible ? "Enabled" : "Disabled") << std::endl;
235247
// clang-format on
236248
return ss.str();
237249
}

core/runtime/TRTEngine.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,19 +34,23 @@ struct TRTEngine : torch::CustomClassHolder {
3434
std::vector<std::string> in_binding_names = {}; // ITO: PYT IDX
3535
std::vector<std::string> out_binding_names = {}; // ITO: PYT IDX
3636

37+
bool hardware_compatible = false; // Whether the engine was compiled in hardware compatible mode
38+
3739
~TRTEngine();
3840
TRTEngine(
3941
const std::string& serialized_engine,
4042
const RTDevice& cuda_device,
4143
const std::vector<std::string>& in_binding_names,
42-
const std::vector<std::string>& out_binding_names);
44+
const std::vector<std::string>& out_binding_names,
45+
bool hardware_compatible = false);
4346
TRTEngine(std::vector<std::string> serialized_info);
4447
TRTEngine(
4548
const std::string& mod_name,
4649
const std::string& serialized_engine,
4750
const RTDevice& cuda_device,
4851
const std::vector<std::string>& in_binding_names,
49-
const std::vector<std::string>& out_binding_names);
52+
const std::vector<std::string>& out_binding_names,
53+
bool hardware_compatible = false);
5054
TRTEngine& operator=(const TRTEngine& other);
5155
std::string to_str() const;
5256
static void verify_serialization_fmt(const std::vector<std::string>& serialized_info);

core/runtime/execute_engine.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,8 @@ bool is_switch_required(const RTDevice& curr_device, const RTDevice& engine_devi
4343
return false;
4444
}
4545

46-
RTDevice select_rt_device(const RTDevice& engine_device, const RTDevice& curr_device) {
47-
auto new_target_device_opt = get_most_compatible_device(engine_device, curr_device);
46+
RTDevice select_rt_device(const RTDevice& engine_device, const RTDevice& curr_device, bool hardware_compatible) {
47+
auto new_target_device_opt = get_most_compatible_device(engine_device, curr_device, hardware_compatible);
4848

4949
// REVIEW: THIS DOES NOT LIST DLA PROBABLY, WHICH WE SHOULD
5050
// TODO: I think this logic could be way simpler at execution time since if the tensors arent on the right
@@ -59,7 +59,9 @@ RTDevice select_rt_device(const RTDevice& engine_device, const RTDevice& curr_de
5959
}
6060

6161
std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intrusive_ptr<TRTEngine> compiled_engine) {
62-
LOG_DEBUG("Attempting to run engine (ID: " << compiled_engine->name << ")");
62+
LOG_DEBUG(
63+
"Attempting to run engine (ID: " << compiled_engine->name
64+
<< "); Hardware Compatible: " << compiled_engine->hardware_compatible);
6365

6466
if (compiled_engine->profile_execution) {
6567
std::stringstream ss;
@@ -89,7 +91,8 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
8991

9092
if (is_switch_required(curr_device, compiled_engine->device_info)) {
9193
// Scan through available CUDA devices and set the CUDA device context correctly
92-
RTDevice device = select_rt_device(compiled_engine->device_info, curr_device);
94+
RTDevice device =
95+
select_rt_device(compiled_engine->device_info, curr_device, compiled_engine->hardware_compatible);
9396
set_rt_device(device);
9497

9598
// Target device is new device

core/runtime/register_jit_hooks.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,9 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion =
101101
serialize_info[ENGINE_IDX] = base64_encode(trt_engine);
102102
serialize_info[INPUT_BINDING_NAMES_IDX] = serialize_bindings(self->in_binding_names);
103103
serialize_info[OUTPUT_BINDING_NAMES_IDX] = serialize_bindings(self->out_binding_names);
104+
serialize_info[HW_COMPATIBLE_IDX] = self->hardware_compatible ? "1" : "0";
105+
106+
LOG_DEBUG("Serialized Hardware Compatibility: " << (self->hardware_compatible ? "Enabled" : "Disabled"));
104107

105108
return serialize_info;
106109
},

core/runtime/runtime.cpp

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,12 @@ namespace runtime {
99

1010
bool MULTI_DEVICE_SAFE_MODE = false;
1111

12-
c10::optional<RTDevice> get_most_compatible_device(const RTDevice& target_device, const RTDevice& curr_device) {
12+
c10::optional<RTDevice> get_most_compatible_device(
13+
const RTDevice& target_device,
14+
const RTDevice& curr_device,
15+
bool hardware_compatible) {
1316
LOG_DEBUG("Target Device: " << target_device);
14-
auto device_options = find_compatible_devices(target_device);
17+
auto device_options = find_compatible_devices(target_device, hardware_compatible);
1518
RTDevice current_device;
1619
if (current_device.id == -1) {
1720
current_device = get_current_device();
@@ -30,7 +33,8 @@ c10::optional<RTDevice> get_most_compatible_device(const RTDevice& target_device
3033
dev_list << "[" << std::endl;
3134
for (auto device : device_options) {
3235
dev_list << " " << device << ',' << std::endl;
33-
if (device.device_name == target_device.device_name) {
36+
// If the model is hardware compatible, any compatible device should be valid
37+
if ((device.device_name == target_device.device_name) || hardware_compatible) {
3438
// First priority is selecting a candidate which agrees with the current device ID
3539
// If such a device is found, we can select it and break out of the loop
3640
if (device.id == current_device.id) {
@@ -60,7 +64,7 @@ c10::optional<RTDevice> get_most_compatible_device(const RTDevice& target_device
6064
}
6165
}
6266

63-
std::vector<RTDevice> find_compatible_devices(const RTDevice& target_device) {
67+
std::vector<RTDevice> find_compatible_devices(const RTDevice& target_device, bool hardware_compatible) {
6468
auto dla_supported = get_dla_supported_SMs();
6569
auto device_list = get_available_device_list().get_devices();
6670

@@ -76,7 +80,8 @@ std::vector<RTDevice> find_compatible_devices(const RTDevice& target_device) {
7680
} else if (target_device.device_type == nvinfer1::DeviceType::kGPU) {
7781
auto target_dev_cc = target_device.getSMCapability();
7882
// If the SM Capabilities match, should be good enough to run
79-
if (poss_dev_cc == target_dev_cc) {
83+
// If hardware compatibility mode is enabled and the SM is at least 80, device is valid
84+
if ((poss_dev_cc == target_dev_cc) || (hardware_compatible && std::stoi(poss_dev_cc) >= 8)) {
8085
compatible_devices.push_back(device.second);
8186
}
8287
} else {

core/runtime/runtime.h

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ namespace core {
1515
namespace runtime {
1616

1717
using EngineID = int64_t;
18-
const std::string ABI_VERSION = "4";
18+
const std::string ABI_VERSION = "5";
1919
extern bool MULTI_DEVICE_SAFE_MODE;
2020
typedef enum {
2121
ABI_TARGET_IDX = 0,
@@ -24,13 +24,15 @@ typedef enum {
2424
ENGINE_IDX,
2525
INPUT_BINDING_NAMES_IDX,
2626
OUTPUT_BINDING_NAMES_IDX,
27+
HW_COMPATIBLE_IDX,
2728
SERIALIZATION_LEN, // NEVER USED FOR DATA, USED TO DETERMINE LENGTH OF SERIALIZED INFO
2829
} SerializedInfoIndex;
2930

3031
c10::optional<RTDevice> get_most_compatible_device(
3132
const RTDevice& target_device,
32-
const RTDevice& curr_device = RTDevice());
33-
std::vector<RTDevice> find_compatible_devices(const RTDevice& target_device);
33+
const RTDevice& curr_device = RTDevice(),
34+
bool hardware_compatible = false);
35+
std::vector<RTDevice> find_compatible_devices(const RTDevice& target_device, bool hardware_compatible);
3436

3537
std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intrusive_ptr<TRTEngine> compiled_engine);
3638

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from typing import Any, List, Optional, Sequence, Set, Tuple, Union
66

77
import torch
8-
import torch_tensorrt
98
from torch.export import ExportedProgram
109
from torch_tensorrt._Device import Device
1110
from torch_tensorrt._enums import ( # TODO: Should probably be the TRT EngineCapability Enum
@@ -22,6 +21,7 @@
2221
DLA_SRAM_SIZE,
2322
ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
2423
ENGINE_CAPABILITY,
24+
HARDWARE_COMPATIBLE,
2525
MAX_AUX_STREAMS,
2626
MIN_BLOCK_SIZE,
2727
NUM_AVG_TIMING_ITERS,
@@ -51,6 +51,8 @@
5151
to_torch_tensorrt_device,
5252
)
5353

54+
import torch_tensorrt
55+
5456
logger = logging.getLogger(__name__)
5557

5658

@@ -84,6 +86,7 @@ def compile(
8486
use_python_runtime: bool = USE_PYTHON_RUNTIME,
8587
use_fast_partitioner: bool = USE_FAST_PARTITIONER,
8688
enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
89+
hardware_compatible: bool = HARDWARE_COMPATIBLE,
8790
**kwargs: Any,
8891
) -> torch.fx.GraphModule:
8992
"""Compile a TorchScript module for NVIDIA GPUs using TensorRT
@@ -140,6 +143,7 @@ def compile(
140143
use_python_runtime: (bool): Return a graph using a pure Python runtime, reduces options for serialization
141144
use_fast_partitioner: (bool): Use the adjacency based partitioning scheme instead of the global partitioner. Adjacency partitioning is faster but may not be optimal. Use the global partitioner (``False``) if looking for best performance
142145
enable_experimental_decompositions (bool): Use the full set of operator decompositions. These decompositions may not be tested but serve to make the graph easier to convert to TensorRT, potentially increasing the amount of graphs run in TensorRT.
146+
hardware_compatible (bool): Build the TensorRT engines compatible with GPU architectures other than that of the GPU on which the engine was built (currently works for NVIDIA Ampere and newer)
143147
**kwargs: Any,
144148
Returns:
145149
torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
@@ -215,6 +219,7 @@ def compile(
215219
"dla_sram_size": dla_sram_size,
216220
"dla_local_dram_size": dla_local_dram_size,
217221
"dla_global_dram_size": dla_global_dram_size,
222+
"hardware_compatible": hardware_compatible,
218223
}
219224

220225
settings = CompilationSettings(**compilation_options)

py/torch_tensorrt/dynamo/_defaults.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
ENABLE_EXPERIMENTAL_DECOMPOSITIONS = False
2525
REFIT = False
2626
REQUIRE_FULL_COMPILATION = False
27+
HARDWARE_COMPATIBLE = False
2728

2829

2930
def default_device() -> Device:

py/torch_tensorrt/dynamo/_settings.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
DLA_SRAM_SIZE,
1313
ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
1414
ENGINE_CAPABILITY,
15+
HARDWARE_COMPATIBLE,
1516
MAX_AUX_STREAMS,
1617
MIN_BLOCK_SIZE,
1718
NUM_AVG_TIMING_ITERS,
@@ -63,6 +64,7 @@ class CompilationSettings:
6364
dla_sram_size (int): Fast software managed RAM used by DLA to communicate within a layer.
6465
dla_local_dram_size (int): Host RAM used by DLA to share intermediate tensor data across operations
6566
dla_global_dram_size (int): Host RAM used by DLA to store weights and metadata for execution
67+
hardware_compatible (bool): Build the TensorRT engines compatible with GPU architectures other than that of the GPU on which the engine was built (currently works for NVIDIA Ampere and newer)
6668
"""
6769

6870
precision: torch.dtype = PRECISION
@@ -88,3 +90,4 @@ class CompilationSettings:
8890
dla_sram_size: int = DLA_SRAM_SIZE
8991
dla_local_dram_size: int = DLA_LOCAL_DRAM_SIZE
9092
dla_global_dram_size: int = DLA_GLOBAL_DRAM_SIZE
93+
hardware_compatible: bool = HARDWARE_COMPATIBLE

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,11 @@ def run(
188188
if self.compilation_settings.version_compatible:
189189
_LOGGER.info("Using version compatible")
190190
builder_config.set_flag(trt.BuilderFlag.VERSION_COMPATIBLE)
191+
if self.compilation_settings.hardware_compatible:
192+
_LOGGER.info("Using hardware compatible")
193+
builder_config.hardware_compatibility_level = (
194+
trt.HardwareCompatibilityLevel.AMPERE_PLUS
195+
)
191196
if self.compilation_settings.optimization_level is not None:
192197
_LOGGER.info(
193198
f"Using optimization level {self.compilation_settings.optimization_level}"

py/torch_tensorrt/dynamo/conversion/_conversion.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,4 +76,5 @@ def convert_module(
7676
input_binding_names=list(interpreter_result.input_names),
7777
output_binding_names=list(interpreter_result.output_names),
7878
target_device=settings.device,
79+
hardware_compatible=settings.hardware_compatible,
7980
)

py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,10 @@
99
logger = logging.getLogger(__name__)
1010

1111
SerializedTensorRTEngineFmt = Tuple[
12-
str, str, bytes, str, str
12+
str, str, str, bytes, str, str, str
1313
] # Defined in //core/runtime/register_jit_hooks.cpp
1414
SerializedTorchTensorRTModuleFmt = Tuple[
15-
str, SerializedTensorRTEngineFmt, List[str], List[str]
15+
str, Optional[SerializedTensorRTEngineFmt], List[str], List[str]
1616
]
1717

1818

@@ -43,6 +43,7 @@ def __init__(
4343
input_binding_names: Optional[List[str]] = None,
4444
output_binding_names: Optional[List[str]] = None,
4545
target_device: Device = Device._current_device(),
46+
hardware_compatible: bool = False,
4647
):
4748
"""__init__ method for torch_tensorrt.dynamo.runtime._TorchTensorRTModule.TorchTensorRTModule
4849
@@ -89,6 +90,7 @@ def __init__(
8990
output_binding_names if output_binding_names is not None else []
9091
)
9192
self.name = name
93+
self.hardware_compatible = hardware_compatible
9294

9395
if serialized_engine is not None:
9496
self.engine = torch.classes.tensorrt.Engine(
@@ -99,6 +101,7 @@ def __init__(
99101
serialized_engine,
100102
TorchTensorRTModule._pack_binding_names(self.input_binding_names),
101103
TorchTensorRTModule._pack_binding_names(self.output_binding_names),
104+
str(int(hardware_compatible)),
102105
]
103106
)
104107
else:
@@ -115,7 +118,7 @@ def get_extra_state(self) -> SerializedTorchTensorRTModuleFmt:
115118
def set_extra_state(self, state: SerializedTorchTensorRTModuleFmt) -> None:
116119
self.name = state[0]
117120
if state[1] is not None:
118-
serialized_engine_info = state[1][0]
121+
serialized_engine_info: SerializedTensorRTEngineFmt = state[1]
119122
import base64
120123

121124
serialized_engine = base64.b64decode(serialized_engine_info[3])
@@ -127,13 +130,17 @@ def set_extra_state(self, state: SerializedTorchTensorRTModuleFmt) -> None:
127130
serialized_engine,
128131
serialized_engine_info[4],
129132
serialized_engine_info[5],
133+
serialized_engine_info[6],
130134
]
131135
)
132136
else:
133137
self.engine = None
134138

135139
self.input_binding_names = state[2]
136140
self.output_binding_names = state[3]
141+
self.hardware_compatible = (
142+
bool(int(state[1][6])) if state[1] is not None else False
143+
)
137144

138145
def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]:
139146
"""Implementation of the forward pass for a TensorRT engine

0 commit comments

Comments
 (0)