Commit 41aef0f

feat: Runtime output buffer optimization

1 parent 283a983 · commit 41aef0f

File tree

6 files changed: +108 −54 lines
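This commit lets the TensorRT runtime reuse output tensors across engine executions: instead of allocating fresh output buffers on every forward call, the runtime allocates the next call's outputs after launching the current execution and hands them back on the following call, as long as the input shape key has not changed. The sketch below distills that control flow from the C++ and Python runtimes in this diff; it is illustrative only, and the callables and `state` dict are stand-ins rather than part of the API.

    from typing import Callable, Dict, List

    import torch


    def run_with_prealloc(
        engine_step: Callable[[List[torch.Tensor]], None],  # enqueues TRT execution into `outputs`
        make_outputs: Callable[[], List[torch.Tensor]],  # allocates empty outputs for current shapes
        state: Dict,  # mirrors the new engine fields
        shape_changed: bool,  # result of the shape-key check
    ) -> List[torch.Tensor]:
        if not state["use_pre_allocated_outputs"] or shape_changed:
            outputs = make_outputs()  # first call or new shapes: allocate fresh
        else:
            outputs = state["pre_allocated_outputs"]  # reuse tensors prepared by the previous call
        engine_step(outputs)
        if state["use_pre_allocated_outputs"]:
            # prepare the next call's outputs for reuse on the next invocation
            state["pre_allocated_outputs"] = make_outputs()
        return outputs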

core/runtime/TRTEngine.cpp

Lines changed: 4 additions & 0 deletions
@@ -319,6 +319,10 @@ int64_t TRTEngine::get_automatic_device_memory_budget() {
   return cuda_engine->getWeightStreamingAutomaticBudget();
 }
 
+void TRTEngine::set_pre_allocated_outputs(bool enable) {
+  use_pre_allocated_outputs = enable;
+}
+
 std::string TRTEngine::to_str() const {
   // clang-format off
   std::stringstream ss;

core/runtime/TRTEngine.h

Lines changed: 4 additions & 0 deletions
@@ -88,6 +88,7 @@ struct TRTEngine : torch::CustomClassHolder {
   int64_t get_streamable_device_memory_budget();
   int64_t get_automatic_device_memory_budget();
   std::vector<at::Tensor> infer_outputs(std::vector<std::vector<int64_t>> input_shapes);
+  void set_pre_allocated_outputs(bool enable);
   friend std::ostream& operator<<(std::ostream& os, const TRTEngine& engine);
   static const char BINDING_DELIM = '%';

@@ -102,6 +103,9 @@ struct TRTEngine : torch::CustomClassHolder {
   std::vector<at::Tensor> input_buffers = {};
   std::vector<at::Tensor> output_buffers = {};
   std::string shape_key;
+  bool cudagraphs_enabled = false;
+  bool use_pre_allocated_outputs = true;
+  std::vector<at::Tensor> pre_allocated_outputs;
 
   // TODO: Implement a call method
   // c10::List<at::Tensor> Run(c10::List<at::Tensor> inputs);

core/runtime/execute_engine.cpp

Lines changed: 45 additions & 19 deletions
@@ -5,6 +5,7 @@
 #include "torch/csrc/jit/runtime/custom_operator.h"
 #include "torch/torch.h"
 
+#include <ATen/record_function.h>
 #include "core/runtime/TRTEngineProfiler.h"
 #include "core/runtime/runtime.h"
 #include "core/util/prelude.h"

@@ -60,9 +61,8 @@ RTDevice select_rt_device(const RTDevice& engine_device, const RTDevice& curr_de
   return new_target_device_opt.value();
 }
 
-bool _cudagraphs_validate_shapes(std::vector<at::Tensor> inputs, c10::intrusive_ptr<TRTEngine> compiled_engine) {
-  // Validate whether the current input shapes to the engine
-  // invalidate the existing cudagraphs object
+bool _validate_shapes(std::vector<at::Tensor> inputs, c10::intrusive_ptr<TRTEngine> compiled_engine) {
+  // Validate whether the current input shapes to the engine has changed
 
   // Populate the shape key for the inputs
   // x: (3, 4), y: (4, 5) --> Key: (3,4)(4,5)

@@ -83,15 +83,32 @@ bool _cudagraphs_validate_shapes(std::vector<at::Tensor> inputs, c10::intrusive_
 
   auto new_shape_key = new_shape_key_ss.str();
 
-  // Compare the shape key to the original key and invalidate shapes if they do not match
+  // Compare the shape key to the original key
   if (new_shape_key != compiled_engine->shape_key) {
-    LOG_DEBUG("Resetting Cudagraph on New Shape Key " << new_shape_key);
+    LOG_DEBUG("Input shape changed " << compiled_engine->shape_key << " -> " << new_shape_key);
     compiled_engine->shape_key = new_shape_key;
-    compiled_engine->cudagraph.reset();
-    return false;
+    return true;
   }
 
-  return true;
+  return false;
+}
+
+std::vector<at::Tensor> create_output_tensors(c10::intrusive_ptr<TRTEngine> compiled_engine) {
+  std::vector<at::Tensor> outputs(compiled_engine->num_io.second);
+  for (auto output_indices : compiled_engine->out_binding_map) {
+    // out_binding_map stores TRT_IDX: PYT_IDX
+    auto pyt_idx = output_indices.second;
+
+    std::string name = compiled_engine->out_binding_names[pyt_idx];
+    auto out_shape = compiled_engine->exec_ctx->getTensorShape(name.c_str());
+    LOG_DEBUG("Output Name: " << name << " Shape: " << out_shape);
+
+    auto dims = core::util::toVec(out_shape);
+    auto type = util::TRTDataTypeToScalarType(compiled_engine->exec_ctx->getEngine().getTensorDataType(name.c_str()));
+    outputs[pyt_idx] = std::move(at::empty(dims, {at::kCUDA}).to(type).contiguous());
+  }
+
+  return outputs;
 }
 
 std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intrusive_ptr<TRTEngine> compiled_engine) {

@@ -116,10 +133,15 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
     compiled_engine->cudagraph.enable_debug_mode();
   }
 
+  bool shape_changed = _validate_shapes(inputs, compiled_engine);
+
   // Whether cudagraphs needs to record the graph on this pass
-  bool need_cudagraphs_record = (CUDAGRAPHS_MODE && (!_cudagraphs_validate_shapes(inputs, compiled_engine)));
+  // Cudagraphs record is required if cudagraphs_enabled is switched to True regardless of shape change
+  bool need_cudagraphs_record =
+      (((!compiled_engine->cudagraphs_enabled) && CUDAGRAPHS_MODE) || (CUDAGRAPHS_MODE && shape_changed));
+  compiled_engine->cudagraphs_enabled = CUDAGRAPHS_MODE;
 
-  if (!CUDAGRAPHS_MODE) {
+  if (!CUDAGRAPHS_MODE || shape_changed) {
     compiled_engine->cudagraph.reset();
   }

@@ -180,6 +202,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
 
   { // Input Setup
     std::unique_ptr<torch::autograd::profiler::RecordProfile> input_profiler_guard;
+    RECORD_FUNCTION("process input", std::vector<c10::IValue>());
     if (compiled_engine->profile_execution) {
       input_profiler_guard =
           std::make_unique<torch::autograd::profiler::RecordProfile>(compiled_engine->input_profile_path);

@@ -261,23 +284,20 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
 
   { // Output Setup
     std::unique_ptr<torch::autograd::profiler::RecordProfile> output_profiler_guard;
+    RECORD_FUNCTION("process output", std::vector<c10::IValue>());
     if (compiled_engine->profile_execution) {
       output_profiler_guard =
          std::make_unique<torch::autograd::profiler::RecordProfile>(compiled_engine->output_profile_path);
     }
+    if ((false == compiled_engine->use_pre_allocated_outputs) || shape_changed) {
+      outputs = create_output_tensors(compiled_engine);
+    } else {
+      outputs = compiled_engine->pre_allocated_outputs;
+    }
 
     for (auto output_indices : compiled_engine->out_binding_map) {
-      // out_binding_map stores TRT_IDX: PYT_IDX
       auto pyt_idx = output_indices.second;
-
       std::string name = compiled_engine->out_binding_names[pyt_idx];
-      auto out_shape = compiled_engine->exec_ctx->getTensorShape(name.c_str());
-      LOG_DEBUG("Output Name: " << name << " Shape: " << out_shape);
-
-      auto dims = core::util::toVec(out_shape);
-      auto type = util::TRTDataTypeToScalarType(compiled_engine->exec_ctx->getEngine().getTensorDataType(name.c_str()));
-      outputs[pyt_idx] = std::move(at::empty(dims, {at::kCUDA}).to(type).contiguous());
-
       if (need_cudagraphs_record) {
         // If we are recording the cuda graph then we need to update the persistent output buffer
         compiled_engine->output_buffers[pyt_idx] = std::move(outputs[pyt_idx].clone());

@@ -310,6 +330,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
   }
 
   { // Engine Execution (execute on engine stream)
+    RECORD_FUNCTION("Trt runtime", std::vector<c10::IValue>());
    c10::cuda::CUDAStreamGuard stream_guard(compiled_engine->engine_stream);
 
    std::unique_ptr<torch::autograd::profiler::RecordProfile> enqueue_profiler_guard;

@@ -344,6 +365,11 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
     }
   } // End engine exeuction (resets to caller stream)
 
+  // Create output buffer for next execution of graph or trt context.
+  if (compiled_engine->use_pre_allocated_outputs) {
+    compiled_engine->pre_allocated_outputs = create_output_tensors(compiled_engine);
+  }
+
   // Block caller stream until engine execution is complete
   at::cuda::CUDAEvent trt_exec_complete;
   trt_exec_complete.record(compiled_engine->engine_stream);
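The recording decision above now has three inputs: the global CUDAGRAPHS_MODE, the engine's cudagraphs_enabled flag from the previous call, and whether the shape key changed. Below is a small Python mirror of the C++ condition, with the cases it is meant to cover (illustrative only, not library code):

    def cudagraphs_plan(mode_on: bool, prev_enabled: bool, shape_changed: bool):
        # mirrors: need_cudagraphs_record = ((!prev_enabled && MODE) || (MODE && shape_changed))
        need_record = (mode_on and not prev_enabled) or (mode_on and shape_changed)
        # mirrors: if (!CUDAGRAPHS_MODE || shape_changed) cudagraph.reset();
        reset_graph = (not mode_on) or shape_changed
        return need_record, reset_graph

    assert cudagraphs_plan(True, False, False) == (True, False)  # graphs just switched on: re-record
    assert cudagraphs_plan(True, True, True) == (True, True)  # shape change under graphs: reset, re-record
    assert cudagraphs_plan(False, True, False) == (False, True)  # graphs switched off: reset only
    assert cudagraphs_plan(True, True, False) == (False, False)  # steady state: replay the recorded graph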

core/runtime/register_jit_hooks.cpp

Lines changed: 1 addition & 0 deletions
@@ -88,6 +88,7 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion =
         .def("dump_engine_layer_info", &TRTEngine::dump_engine_layer_info)
         .def("get_engine_layer_info", &TRTEngine::get_engine_layer_info)
         .def("infer_outputs", &TRTEngine::infer_outputs)
+        .def("set_pre_allocated_outputs", &TRTEngine::set_pre_allocated_outputs)
         .def_property(
             "device_memory_budget",
             &TRTEngine::get_device_memory_budget,
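With this registration the setter becomes callable on any torch.classes.tensorrt.Engine instance from Python. A hypothetical direct call is sketched below; in this commit the only in-tree caller is _TorchTensorRTModule.setup_engine (last file in the diff), and `engine_info` is a placeholder for the packed engine payload, not a variable from the diff.

    import torch

    engine = torch.classes.tensorrt.Engine(engine_info)  # engine_info: placeholder payload
    engine.set_pre_allocated_outputs(True)  # reuse output tensors across executions
    engine.set_pre_allocated_outputs(False)  # revert to per-call allocation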

py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py

Lines changed: 53 additions & 35 deletions
@@ -108,6 +108,9 @@ def __init__(
         self.engine = None
         self.weight_name_map = weight_name_map
         self.target_platform = Platform.current_platform()
+        self.cudagraphs_enabled = False
+        self.pre_allocated_outputs: List[torch.Tensor] = []
+        self.use_pre_allocated_outputs = False
 
         if self.serialized_engine is not None and not self.settings.lazy_engine_init:
             self.setup_engine()

@@ -172,7 +175,7 @@ def setup_engine(self) -> None:
             self.engine.get_tensor_shape(input_name) for input_name in self.input_names
         ]
         self.output_dtypes = [
-            dtype._from(self.engine.get_tensor_dtype(output_name))
+            dtype._from(self.engine.get_tensor_dtype(output_name)).to(torch.dtype)
             for output_name in self.output_names
         ]
         self.output_shapes = [

@@ -233,6 +236,19 @@ def __del__(self) -> None:
         if self.cudagraph:
             self.cudagraph.reset()
 
+    def create_output_tensors(self) -> List[torch.Tensor]:
+        # create output tensors
+        outputs: List[torch.Tensor] = []
+
+        for o, _ in enumerate(self.output_names):
+            output = torch.empty(
+                size=self.output_shapes[o],
+                dtype=self.output_dtypes[o],
+                device=torch.cuda.current_device(),
+            )
+            outputs.append(output)
+        return outputs
+
     def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
         # Ensure inputs are available in all scopes and cast symbolic integers to Tensors
         contiguous_inputs: List[torch.Tensor] = [

@@ -248,19 +264,25 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
             self._check_initialized()
 
             cudagraphs_enabled = torch_tensorrt.runtime.get_cudagraphs_mode()
-            need_cudagraphs_record = (
-                cudagraphs_enabled and not self.cudagraphs_validate_shapes(inputs)
-            )
+            shape_changed = self.validate_input_shapes(inputs)
+            # Cudagraphs record is required if cudagraphs_enabled is toggled to True regardless of shape change
+            if not self.cudagraphs_enabled and cudagraphs_enabled:
+                need_cudagraphs_record = True
+            else:
+                need_cudagraphs_record = cudagraphs_enabled and shape_changed
+            self.cudagraphs_enabled = cudagraphs_enabled
 
             if need_cudagraphs_record:
+                if self.cudagraph:
+                    self.cudagraph.reset()
                 self._input_buffers = [None] * len(self.input_names)
                 self._output_buffers = [None] * len(self.output_names)
 
             if not cudagraphs_enabled and self.cudagraph:
                 self.cudagraph.reset()
                 self.cudagraph = None
 
-            # If in safe mode, check at each iteration for for whether a switch is required
+            # If in safe mode, check at each iteration for whether a switch is required
             if (
                 torch_tensorrt.runtime._multi_device_safe_mode._PY_RT_MULTI_DEVICE_SAFE_MODE
             ):

@@ -351,14 +373,14 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
                     self.context.set_tensor_address(
                         input_name, contiguous_inputs[i].data_ptr()
                     )
-
-                # Check if input shapes can be inferred.
-                uninferred_input_names = self.context.infer_shapes()
-                if uninferred_input_names:
-                    logger.warning(
-                        f"The shapes of the inputs: {uninferred_input_names} cannot be inferred and could lead to undefined behavior. \
-                            This could happen if the input tensor addresses/shapes haven't been configured correctly"
-                    )
+                if shape_changed:
+                    # Check if input shapes can be inferred.
+                    uninferred_input_names = self.context.infer_shapes()
+                    if uninferred_input_names:
+                        logger.warning(
+                            f"The shapes of the inputs: {uninferred_input_names} cannot be inferred and could lead to undefined behavior. \
+                                This could happen if the input tensor addresses/shapes haven't been configured correctly"
+                        )
 
             with (
                 torch.autograd.profiler.record_function(

@@ -367,24 +389,20 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
                 if self.profiling_enabled
                 else nullcontext()
             ):
-                # create output tensors
-                outputs: List[torch.Tensor] = []
-
-                for o, output_name in enumerate(self.output_names):
-                    shape = tuple(self.context.get_tensor_shape(output_name))
-
-                    if DYNAMIC_DIM in shape:
+                if not self.use_pre_allocated_outputs or shape_changed:
+                    self.output_shapes = [
+                        tuple(self.context.get_tensor_shape(output_name))
+                        for output_name in self.output_names
+                    ]
+                    if DYNAMIC_DIM in self.output_shapes:
                         raise ValueError(
                             "Encountered dynamic output shapes during runtime. This could mean the network has data-dependent output shapes which is not currently supported."
                         )
+                    outputs = self.create_output_tensors()
+                else:
+                    outputs = self.pre_allocated_outputs
 
-                    output = torch.empty(
-                        size=shape,
-                        dtype=self.output_dtypes[o].to(torch.dtype),
-                        device=torch.cuda.current_device(),
-                    )
-
-                    outputs.append(output)
+                for o, output_name in enumerate(self.output_names):
 
                     if need_cudagraphs_record:
                         self._output_buffers[o] = outputs[o].clone()

@@ -445,6 +463,9 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
 
             self._caller_stream.wait_stream(self._engine_stream)
 
+            if self.use_pre_allocated_outputs:
+                self.pre_allocated_outputs = self.create_output_tensors()
+
             if cudagraphs_enabled:
                 for idx, o in enumerate(outputs):
                     o.copy_(self._output_buffers[idx])

@@ -486,10 +507,9 @@ def get_layer_info(self) -> str:
         )
         return engine_json
 
-    def cudagraphs_validate_shapes(self, inputs: Sequence[torch.Tensor]) -> bool:
+    def validate_input_shapes(self, inputs: Sequence[torch.Tensor]) -> bool:
         """
-        Validates the input shapes of the forward function
-        versus the version currently active for the
+        Validates the input shapes of the forward function has changed
         """
         # Representation of input shapes to a given model
         # Shapes are concatenated as so:

@@ -499,10 +519,8 @@ def cudagraphs_validate_shapes(self, inputs: Sequence[torch.Tensor]) -> bool:
         # If the new shape key differs from the existing one,
         # invalidate the old shape key and remove the CUDAGraph
         if new_shape_key != self.shape_key:
-            logger.debug(f"Resetting Cudagraph on new shape key {new_shape_key}")
+            logger.debug(f"Input shape changed {self.shape_key} -> {new_shape_key}")
             self.shape_key = new_shape_key
-            if self.cudagraph:
-                self.cudagraph.reset()
-            return False
+            return True
 
-        return True
+        return False
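On the Python runtime the feature is off by default (use_pre_allocated_outputs = False in __init__) and this commit adds no setter for it, so enabling it means assigning the attribute directly. A hypothetical usage sketch, assuming trt_mod is a PythonTorchTensorRTModule and x / x_small are CUDA inputs with different shapes:

    trt_mod.use_pre_allocated_outputs = True  # plain attribute; no public toggle exists yet

    y1 = trt_mod(x)  # new shape key: outputs allocated fresh, next set pre-allocated after execution
    y2 = trt_mod(x)  # same shapes: returns the pre-allocated tensors directly
    y3 = trt_mod(x_small)  # shape change detected: falls back to fresh allocation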

py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py

Lines changed: 1 addition & 0 deletions
@@ -208,6 +208,7 @@ def setup_engine(self) -> None:
         if self.engine is not None:
             return
         self.engine = torch.classes.tensorrt.Engine(self._pack_engine_info())
+        self.engine.set_pre_allocated_outputs(True)
 
     def encode_metadata(self, metadata: Any) -> str:
         metadata = copy.deepcopy(metadata)
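Note the defaults differ between runtimes: the C++ flag is initialized to true in TRTEngine.h and the TorchScript wrapper re-asserts it here at setup, while the Python runtime starts with use_pre_allocated_outputs = False. A hypothetical opt-out on the TorchScript path, assuming ts_mod is a TorchTensorRTModule whose engine has been set up:

    ts_mod.engine.set_pre_allocated_outputs(False)  # restore per-call output allocation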
