Skip to content

Commit 23131c3

Browse files
committed
chore: Runtime api for pre-allocated outputs
1 parent f480353 commit 23131c3

File tree

8 files changed

+259
-21
lines changed

8 files changed

+259
-21
lines changed

core/runtime/TRTEngine.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,9 @@ TRTEngine::TRTEngine(
9999
exec_ctx = make_trt(cuda_engine->createExecutionContext());
100100
TORCHTRT_CHECK((exec_ctx.get() != nullptr), "Unable to create TensorRT execution context");
101101

102+
runtime_states.prev_cudagraphs_enabled = CUDAGRAPHS_MODE;
103+
runtime_states.prev_pre_allocated_outputs_enabled = false;
104+
102105
if (_in_binding_names.size() == 0 && _out_binding_names.size() == 0) {
103106
uint64_t inputs = 0;
104107
uint64_t outputs = 0;

core/runtime/TRTEngine.h

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,37 @@ using FlattenedState = std::tuple<
3030
std::tuple<std::string, std::string>, // serialized metadata
3131
std::tuple<std::string, std::string>>; // Platform
3232

33+
// Per-execution decisions produced by TorchTRTRuntimeStates::validate_states.
struct RuntimeStates {
  // True when a CUDA Graph has to be recorded on the current pass.
  bool need_cudagraphs_record = false;
  // True when the outputs pre-allocated on the previous pass may be reused.
  bool can_use_pre_allocated_outputs = false;
};

// Tracks the runtime toggles observed on the previous execution so that the
// current execution can tell whether a toggle was just switched.
struct TorchTRTRuntimeStates {
  // Previous runtime states. Default to "disabled" so that a default-constructed
  // instance never reads an indeterminate value in validate_states.
  bool prev_cudagraphs_enabled = false;
  bool prev_pre_allocated_outputs_enabled = false;

  // Evaluates whether the conditions are met to record a CUDA Graph or to reuse
  // pre-allocated outputs, based on the current and previous states and on
  // whether the input shape has changed.
  //
  // cudagraphs_enabled:            current CUDA Graphs toggle
  // pre_allocated_outputs_enabled: current pre-allocated outputs toggle
  // shape_changed:                 whether the input shape differs from the previous pass
  //
  // Returns the decisions for this pass. As a side effect the current toggles
  // are stored as the "previous" states for the next call.
  RuntimeStates validate_states(bool cudagraphs_enabled, bool pre_allocated_outputs_enabled, bool shape_changed) {
    RuntimeStates values;

    // A CUDA Graph record is required when cudagraphs is enabled and either it
    // was disabled on the previous pass or the input shape has changed.
    values.need_cudagraphs_record = cudagraphs_enabled && (!prev_cudagraphs_enabled || shape_changed);

    // Pre-allocated outputs can be used only when the feature was enabled on
    // both the previous and the current pass and the shape did not change.
    values.can_use_pre_allocated_outputs =
        prev_pre_allocated_outputs_enabled && pre_allocated_outputs_enabled && !shape_changed;

    prev_cudagraphs_enabled = cudagraphs_enabled;
    prev_pre_allocated_outputs_enabled = pre_allocated_outputs_enabled;

    return values;
  }
};
63+
3364
struct TRTEngine : torch::CustomClassHolder {
3465
// Each engine needs its own runtime object
3566
std::shared_ptr<nvinfer1::IRuntime> rt;
@@ -89,6 +120,7 @@ struct TRTEngine : torch::CustomClassHolder {
89120
int64_t get_automatic_device_memory_budget();
90121
std::vector<at::Tensor> infer_outputs(std::vector<std::vector<int64_t>> input_shapes);
91122
void set_pre_allocated_outputs(bool enable);
123+
TorchTRTRuntimeStates runtime_states;
92124
friend std::ostream& operator<<(std::ostream& os, const TRTEngine& engine);
93125
static const char BINDING_DELIM = '%';
94126

@@ -103,8 +135,7 @@ struct TRTEngine : torch::CustomClassHolder {
103135
std::vector<at::Tensor> input_buffers = {};
104136
std::vector<at::Tensor> output_buffers = {};
105137
std::string shape_key = "None";
106-
bool prev_cudagraphs_enabled = false;
107-
bool use_pre_allocated_outputs = true;
138+
bool use_pre_allocated_outputs = false;
108139
std::vector<at::Tensor> pre_allocated_outputs;
109140

110141
// TODO: Implement a call method

core/runtime/execute_engine.cpp

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -203,10 +203,9 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
203203
bool shape_changed = _validate_shapes(inputs, compiled_engine);
204204

205205
// Whether cudagraphs needs to record the graph on this pass
206-
// Cudagraphs record is required if cudagraphs_enabled is switched to True regardless of shape change
207-
bool need_cudagraphs_record =
208-
(((!compiled_engine->prev_cudagraphs_enabled) && CUDAGRAPHS_MODE) || (CUDAGRAPHS_MODE && shape_changed));
209-
compiled_engine->prev_cudagraphs_enabled = CUDAGRAPHS_MODE;
206+
RuntimeStates states = compiled_engine->runtime_states.validate_states(
207+
CUDAGRAPHS_MODE, compiled_engine->use_pre_allocated_outputs, shape_changed);
208+
bool need_cudagraphs_record = states.need_cudagraphs_record;
210209

211210
if (!CUDAGRAPHS_MODE || shape_changed) {
212211
compiled_engine->cudagraph.reset();
@@ -289,10 +288,10 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
289288
output_profiler_guard =
290289
std::make_unique<torch::autograd::profiler::RecordProfile>(compiled_engine->output_profile_path);
291290
}
292-
if (!compiled_engine->use_pre_allocated_outputs || shape_changed) {
293-
outputs = create_output_tensors(compiled_engine);
294-
} else {
291+
if (states.can_use_pre_allocated_outputs) {
295292
outputs = compiled_engine->pre_allocated_outputs;
293+
} else {
294+
outputs = create_output_tensors(compiled_engine);
296295
}
297296

298297
for (auto output_indices : compiled_engine->out_binding_map) {

py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py

Lines changed: 45 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,40 @@
2323
logger = logging.getLogger(__name__)
2424

2525

26+
class TorchTRTRuntimeStates:
    """Remembers the runtime toggles from the previous call.

    The stored flags let ``validate_states`` detect when cudagraphs or
    pre-allocated outputs were switched between two consecutive calls.
    """

    def __init__(self, cudagraphs_enabled: bool, pre_allocated_outputs_enabled: bool):
        # Seed the "previous" flags with the values supplied by the caller
        self.prev_cudagraphs_enabled = cudagraphs_enabled
        self.prev_pre_allocated_outputs_enabled = pre_allocated_outputs_enabled

    def validate_states(
        self,
        cudagraphs_enabled: bool,
        pre_allocated_outputs_enabled: bool,
        shape_changed: bool,
    ) -> Tuple[bool, bool]:
        """Decide what the current pass has to do and remember the current flags.

        Args:
            cudagraphs_enabled: Current cudagraphs toggle.
            pre_allocated_outputs_enabled: Current pre-allocated outputs toggle.
            shape_changed: Whether the input shape changed since the last pass.

        Returns:
            ``(need_cudagraphs_record, can_use_pre_allocated_outputs)``
        """
        # Recording is needed when cudagraphs is on and it was off on the
        # previous pass, or the shape has changed
        record_graph = bool(
            cudagraphs_enabled
            and (not self.prev_cudagraphs_enabled or shape_changed)
        )

        # Reuse is allowed only when the feature was on for both the previous
        # and the current pass and the shape is unchanged
        reuse_outputs = bool(
            self.prev_pre_allocated_outputs_enabled
            and pre_allocated_outputs_enabled
            and (not shape_changed)
        )

        # The current flags become the "previous" flags for the next call
        self.prev_cudagraphs_enabled = cudagraphs_enabled
        self.prev_pre_allocated_outputs_enabled = pre_allocated_outputs_enabled

        return record_graph, reuse_outputs
58+
59+
2660
class PythonTorchTensorRTModule(Module): # type: ignore[misc]
2761
"""PythonTorchTensorRTModule is a PyTorch module which encompasses an arbitrary TensorRT Engine.
2862
@@ -107,7 +141,9 @@ def __init__(
107141
self.engine = None
108142
self.weight_name_map = weight_name_map
109143
self.target_platform = Platform.current_platform()
110-
self.prev_cudagraphs_enabled = False
144+
self.runtime_states = TorchTRTRuntimeStates(
145+
torch_tensorrt.runtime.get_cudagraphs_mode(), False
146+
)
111147
self.pre_allocated_outputs: List[torch.Tensor] = []
112148
self.use_pre_allocated_outputs = False
113149

@@ -318,13 +354,11 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
318354

319355
cudagraphs_enabled = torch_tensorrt.runtime.get_cudagraphs_mode()
320356
shape_changed = self.validate_input_shapes(inputs)
321-
# Cudagraphs record is required if cudagraphs_enabled is toggled to True regardless of shape change
322-
if not self.prev_cudagraphs_enabled and cudagraphs_enabled:
323-
need_cudagraphs_record = True
324-
else:
325-
need_cudagraphs_record = cudagraphs_enabled and shape_changed
326-
327-
self.prev_cudagraphs_enabled = cudagraphs_enabled
357+
need_cudagraphs_record, can_use_pre_allocated_outputs = (
358+
self.runtime_states.validate_states(
359+
cudagraphs_enabled, self.use_pre_allocated_outputs, shape_changed
360+
)
361+
)
328362

329363
if need_cudagraphs_record:
330364
if self.cudagraph:
@@ -399,7 +433,9 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
399433
if self.profiling_enabled
400434
else nullcontext()
401435
):
402-
if not self.use_pre_allocated_outputs or shape_changed:
436+
if can_use_pre_allocated_outputs:
437+
outputs = self.pre_allocated_outputs
438+
else:
403439
self.output_shapes = [
404440
tuple(self.context.get_tensor_shape(output_name))
405441
for output_name in self.output_names
@@ -409,8 +445,6 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
409445
"Encountered dynamic output shapes during runtime. This could mean the network has data-dependent output shapes which is not currently supported."
410446
)
411447
outputs = self.create_output_tensors()
412-
else:
413-
outputs = self.pre_allocated_outputs
414448

415449
for o, output_name in enumerate(self.output_names):
416450

py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,6 @@ def setup_engine(self) -> None:
207207
if self.engine is not None:
208208
return
209209
self.engine = torch.classes.tensorrt.Engine(self._pack_engine_info())
210-
self.set_pre_allocated_outputs(False)
211210

212211
def encode_metadata(self, metadata: Any) -> str:
213212
metadata = copy.deepcopy(metadata)

py/torch_tensorrt/runtime/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,5 @@
88
set_cudagraphs_mode,
99
)
1010
from torch_tensorrt.runtime._multi_device_safe_mode import set_multi_device_safe_mode
11+
from torch_tensorrt.runtime._pre_allocated_outputs import enable_pre_allocated_outputs
1112
from torch_tensorrt.runtime._weight_streaming import weight_streaming
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import logging
2+
from typing import Any
3+
4+
import torch
5+
from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule
6+
7+
logger = logging.getLogger(__name__)
8+
9+
10+
class _PreAllocatedOutputContextManager:
    """
    Context manager that switches the pre-allocated output feature on for the
    TensorRT runtime submodules of a module, and off again on exit
    """

    def __init__(self, module: torch.fx.GraphModule) -> None:
        # Collect the direct children whose name contains "_run_on_acc" and
        # which are one of the two TensorRT runtime module types
        self.rt_mods = [
            child
            for child_name, child in module.named_children()
            if "_run_on_acc" in child_name
            and isinstance(child, (PythonTorchTensorRTModule, TorchTensorRTModule))
        ]

    def set_pre_allocated_output(self, enable: bool) -> None:
        # Forward the flag to every collected runtime module
        for rt_mod in self.rt_mods:
            rt_mod.set_pre_allocated_outputs(enable)

    def __enter__(self) -> "_PreAllocatedOutputContextManager":
        # Turn the feature on for the duration of the `with` block
        self.set_pre_allocated_output(True)
        return self

    def __exit__(self, *args: Any) -> None:
        # Turn the feature off; returning None lets any exception propagate
        self.set_pre_allocated_output(False)
36+
37+
38+
def enable_pre_allocated_outputs(
    module: torch.fx.GraphModule,
) -> _PreAllocatedOutputContextManager:
    """
    Build the context manager that toggles pre-allocated outputs for `module`

    Nothing is enabled until the returned object is entered with `with`, or
    its `set_pre_allocated_output` method is called directly
    """
    ctx_manager = _PreAllocatedOutputContextManager(module)
    return ctx_manager
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
import torch
2+
import torch_tensorrt as torchtrt
3+
from parameterized import parameterized
4+
from torch.testing._internal.common_utils import TestCase, run_tests
5+
6+
INPUT_SIZE = (3, 16, 16)
7+
TRIALS = 5
8+
9+
10+
class TestPreAllocatedOutputs(TestCase):
    """Compares Torch-TRT results against the traced Torch graph while the
    pre-allocated outputs feature is enabled, for both runtimes."""

    @parameterized.expand(
        [
            ("python_runtime", True),
            ("cpp_runtime", False),
        ]
    )
    def test_pre_allocated_outputs_default(self, _, use_python_runtime):
        """Static-shape inputs with pre-allocated outputs enabled for every call."""

        class SampleModel(torch.nn.Module):
            def forward(self, x):
                return torch.softmax((x + 2) * 7, dim=0)

        model = SampleModel().eval().cuda()
        inputs = [torch.randn(*INPUT_SIZE).cuda() for _ in range(TRIALS)]
        fx_graph = torch.fx.symbolic_trace(model)

        # Validate that the results between Torch and Torch-TRT are similar
        optimized_model = torchtrt.compile(
            fx_graph,
            "torch_compile",
            inputs[0],
            min_block_size=1,
            pass_through_build_failures=True,
            use_python_runtime=use_python_runtime,
        )

        ref_out_list = []
        trt_out_list = []
        with torchtrt.runtime.enable_pre_allocated_outputs(optimized_model):
            for i in inputs:
                ref_out_list.append(fx_graph(i).detach().cpu())
                trt_out_list.append(optimized_model(i).detach().cpu())

        for torch_model_results, optimized_model_results in zip(
            ref_out_list, trt_out_list
        ):
            torch.testing.assert_close(
                torch_model_results,
                optimized_model_results,
                rtol=5e-03,
                atol=5e-03,
                equal_nan=True,
                check_dtype=True,
            )

        torch._dynamo.reset()

    @parameterized.expand(
        [
            ("python_runtime", True),
            ("cpp_runtime", False),
        ]
    )
    def test_pre_allocated_outputs_dynamic(self, _, use_python_runtime):
        """Dynamic-shape inputs while cudagraphs mode and pre-allocated outputs
        are toggled between calls."""

        class SampleModel(torch.nn.Module):
            def forward(self, x):
                return torch.relu((x + 2) * 0.5)

        inputs = torchtrt.Input(
            min_shape=(1, 3, 128, 224),
            opt_shape=(8, 3, 192, 224),
            max_shape=(16, 3, 224, 224),
            dtype=torch.float,
            name="x",
        )
        fx_graph = torch.fx.symbolic_trace(SampleModel())

        optimized_model = torchtrt.compile(
            fx_graph,
            "dynamo",
            inputs,
            min_block_size=1,
            pass_through_build_failures=True,
            torch_executed_ops={"torch.ops.aten.mul.Tensor"},
            use_python_runtime=use_python_runtime,
        )

        input_list = []
        ref_out_list = []
        trt_out_list = []
        # Build 25 inputs: every batch size paired with every height. The
        # repeated heights produce consecutive calls with an unchanged shape.
        for i in [1, 3, 8, 11, 16]:
            for j in [128, 128, 222, 222, 224]:
                input_list.append(torch.randn((i, 3, j, 224)).cuda())

        pre_allocated_output_ctx = torchtrt.runtime.enable_pre_allocated_outputs(
            optimized_model
        )
        pre_allocated_output = False
        # The cudagraphs mode is set through the runtime API below; remember the
        # value it had so that this test does not leak it into later tests.
        prev_cudagraphs_mode = torchtrt.runtime.get_cudagraphs_mode()
        try:
            for enable_cuda_graphs in [False, True]:
                for i in range(len(input_list)):
                    # Use `enable_cuda_graphs` on the indices where
                    # i % TRIALS == i // TRIALS and its inverse everywhere else
                    if i % TRIALS == i // TRIALS:
                        cuda_graphs = enable_cuda_graphs
                    else:
                        cuda_graphs = not enable_cuda_graphs
                    # Flip the pre-allocated outputs toggle on every third index
                    if i % 3 == 0:
                        pre_allocated_output = not pre_allocated_output

                    torchtrt.runtime.set_cudagraphs_mode(cuda_graphs)
                    pre_allocated_output_ctx.set_pre_allocated_output(
                        pre_allocated_output
                    )

                    ref_out_list.append(fx_graph(input_list[i]))
                    trt_out_list.append(optimized_model(input_list[i]))

            for torch_model_results, optimized_model_results in zip(
                ref_out_list, trt_out_list
            ):
                torch.testing.assert_close(
                    torch_model_results,
                    optimized_model_results,
                    rtol=5e-03,
                    atol=5e-03,
                    equal_nan=True,
                    check_dtype=True,
                )
        finally:
            # Restore the toggles even when an assertion or a forward pass fails
            pre_allocated_output_ctx.set_pre_allocated_output(False)
            torchtrt.runtime.set_cudagraphs_mode(prev_cudagraphs_mode)
        torch._dynamo.reset()
127+
128+
129+
if __name__ == "__main__":
130+
run_tests()

0 commit comments

Comments
 (0)