Commit 65ea0b1

feat: Runtime output buffer optimization
Parent: 8e2c82d

6 files changed (+108 -54 lines)

6 files changed

+108
-54
lines changed

core/runtime/TRTEngine.cpp

Lines changed: 4 additions & 0 deletions
@@ -296,6 +296,10 @@ int64_t TRTEngine::get_automatic_device_memory_budget() {
   return cuda_engine->getWeightStreamingAutomaticBudget();
 }
 
+void TRTEngine::set_pre_allocated_outputs(bool enable) {
+  use_pre_allocated_outputs = enable;
+}
+
 std::string TRTEngine::to_str() const {
   // clang-format off
   std::stringstream ss;

core/runtime/TRTEngine.h

Lines changed: 4 additions & 0 deletions
@@ -75,6 +75,7 @@ struct TRTEngine : torch::CustomClassHolder {
   bool set_device_memory_budget(int64_t budget);
   int64_t get_streamable_device_memory_budget();
   int64_t get_automatic_device_memory_budget();
+  void set_pre_allocated_outputs(bool enable);
   friend std::ostream& operator<<(std::ostream& os, const TRTEngine& engine);
   static const char BINDING_DELIM = '%';
 
@@ -85,6 +86,9 @@ struct TRTEngine : torch::CustomClassHolder {
   std::vector<at::Tensor> input_buffers = {};
   std::vector<at::Tensor> output_buffers = {};
   std::string shape_key;
+  bool cudagraphs_enabled = false;
+  bool use_pre_allocated_outputs = true;
+  std::vector<at::Tensor> pre_allocated_outputs;
 
   // TODO: Implement a call method
   // c10::List<at::Tensor> Run(c10::List<at::Tensor> inputs);
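
For orientation, a minimal Python analogue of the three new members and the roles they play across successive executions; the class is purely illustrative and not part of the codebase:

    # Sketch only: mirrors the new TRTEngine state added above.
    class EngineState:
        def __init__(self) -> None:
            self.cudagraphs_enabled = False        # CUDAGRAPHS_MODE observed on the previous call
            self.use_pre_allocated_outputs = True  # buffer reuse defaults to on in the C++ runtime
            self.pre_allocated_outputs = []        # output tensors built at the end of the last run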

core/runtime/execute_engine.cpp

Lines changed: 45 additions & 19 deletions
@@ -5,6 +5,7 @@
 #include "torch/csrc/jit/runtime/custom_operator.h"
 #include "torch/torch.h"
 
+#include <ATen/record_function.h>
 #include "core/runtime/TRTEngineProfiler.h"
 #include "core/runtime/runtime.h"
 #include "core/util/prelude.h"
@@ -60,9 +61,8 @@ RTDevice select_rt_device(const RTDevice& engine_device, const RTDevice& curr_de
   return new_target_device_opt.value();
 }
 
-bool _cudagraphs_validate_shapes(std::vector<at::Tensor> inputs, c10::intrusive_ptr<TRTEngine> compiled_engine) {
-  // Validate whether the current input shapes to the engine
-  // invalidate the existing cudagraphs object
+bool _validate_shapes(std::vector<at::Tensor> inputs, c10::intrusive_ptr<TRTEngine> compiled_engine) {
+  // Validate whether the current input shapes to the engine have changed
 
   // Populate the shape key for the inputs
   // x: (3, 4), y: (4, 5) --> Key: (3,4)(4,5)
@@ -83,15 +83,32 @@ bool _cudagraphs_validate_shapes(std::vector<at::Tensor> inputs, c10::intrusive_
 
   auto new_shape_key = new_shape_key_ss.str();
 
-  // Compare the shape key to the original key and invalidate shapes if they do not match
+  // Compare the shape key to the original key
   if (new_shape_key != compiled_engine->shape_key) {
-    LOG_DEBUG("Resetting Cudagraph on New Shape Key " << new_shape_key);
+    LOG_DEBUG("Input shape changed " << compiled_engine->shape_key << " -> " << new_shape_key);
     compiled_engine->shape_key = new_shape_key;
-    compiled_engine->cudagraph.reset();
-    return false;
+    return true;
   }
 
-  return true;
+  return false;
+}
+
+std::vector<at::Tensor> create_output_tensors(c10::intrusive_ptr<TRTEngine> compiled_engine) {
+  std::vector<at::Tensor> outputs(compiled_engine->num_io.second);
+  for (auto output_indices : compiled_engine->out_binding_map) {
+    // out_binding_map stores TRT_IDX: PYT_IDX
+    auto pyt_idx = output_indices.second;
+
+    std::string name = compiled_engine->out_binding_names[pyt_idx];
+    auto out_shape = compiled_engine->exec_ctx->getTensorShape(name.c_str());
+    LOG_DEBUG("Output Name: " << name << " Shape: " << out_shape);
+
+    auto dims = core::util::toVec(out_shape);
+    auto type = util::TRTDataTypeToScalarType(compiled_engine->exec_ctx->getEngine().getTensorDataType(name.c_str()));
+    outputs[pyt_idx] = std::move(at::empty(dims, {at::kCUDA}).to(type).contiguous());
+  }
+
+  return outputs;
 }
 
 std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intrusive_ptr<TRTEngine> compiled_engine) {
@@ -114,10 +131,15 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
     compiled_engine->cudagraph.enable_debug_mode();
   }
 
+  bool shape_changed = _validate_shapes(inputs, compiled_engine);
+
   // Whether cudagraphs needs to record the graph on this pass
-  bool need_cudagraphs_record = (CUDAGRAPHS_MODE && (!_cudagraphs_validate_shapes(inputs, compiled_engine)));
+  // Cudagraphs record is required if cudagraphs_enabled is switched to True regardless of shape change
+  bool need_cudagraphs_record =
+      (((!compiled_engine->cudagraphs_enabled) && CUDAGRAPHS_MODE) || (CUDAGRAPHS_MODE && shape_changed));
+  compiled_engine->cudagraphs_enabled = CUDAGRAPHS_MODE;
 
-  if (!CUDAGRAPHS_MODE) {
+  if (!CUDAGRAPHS_MODE || shape_changed) {
     compiled_engine->cudagraph.reset();
   }
 
@@ -178,6 +200,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
 
   { // Input Setup
     std::unique_ptr<torch::autograd::profiler::RecordProfile> input_profiler_guard;
+    RECORD_FUNCTION("process input", std::vector<c10::IValue>());
    if (compiled_engine->profile_execution) {
      input_profiler_guard =
          std::make_unique<torch::autograd::profiler::RecordProfile>(compiled_engine->input_profile_path);
@@ -259,23 +282,20 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
 
   { // Output Setup
     std::unique_ptr<torch::autograd::profiler::RecordProfile> output_profiler_guard;
+    RECORD_FUNCTION("process output", std::vector<c10::IValue>());
     if (compiled_engine->profile_execution) {
       output_profiler_guard =
           std::make_unique<torch::autograd::profiler::RecordProfile>(compiled_engine->output_profile_path);
     }
+    if ((false == compiled_engine->use_pre_allocated_outputs) || shape_changed) {
+      outputs = create_output_tensors(compiled_engine);
+    } else {
+      outputs = compiled_engine->pre_allocated_outputs;
+    }
 
     for (auto output_indices : compiled_engine->out_binding_map) {
-      // out_binding_map stores TRT_IDX: PYT_IDX
       auto pyt_idx = output_indices.second;
-
       std::string name = compiled_engine->out_binding_names[pyt_idx];
-      auto out_shape = compiled_engine->exec_ctx->getTensorShape(name.c_str());
-      LOG_DEBUG("Output Name: " << name << " Shape: " << out_shape);
-
-      auto dims = core::util::toVec(out_shape);
-      auto type = util::TRTDataTypeToScalarType(compiled_engine->exec_ctx->getEngine().getTensorDataType(name.c_str()));
-      outputs[pyt_idx] = std::move(at::empty(dims, {at::kCUDA}).to(type).contiguous());
-
       if (need_cudagraphs_record) {
         // If we are recording the cuda graph then we need to update the persistent output buffer
         compiled_engine->output_buffers[pyt_idx] = std::move(outputs[pyt_idx].clone());
@@ -311,6 +331,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
  std::unique_lock<std::mutex> lock(compiled_engine->mu);
 
   { // Engine Execution (execute on engine stream)
+    RECORD_FUNCTION("Trt runtime", std::vector<c10::IValue>());
    c10::cuda::CUDAStreamGuard stream_guard(compiled_engine->engine_stream);
 
    std::unique_ptr<torch::autograd::profiler::RecordProfile> enqueue_profiler_guard;
@@ -345,6 +366,11 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
     }
   } // End engine execution (resets to caller stream)
 
+  // Create output buffer for next execution of graph or trt context.
+  if (compiled_engine->use_pre_allocated_outputs) {
+    compiled_engine->pre_allocated_outputs = create_output_tensors(compiled_engine);
+  }
+
   // Block caller stream until engine execution is complete
   at::cuda::CUDAEvent trt_exec_complete;
   trt_exec_complete.record(compiled_engine->engine_stream);
core/runtime/register_jit_hooks.cpp

Lines changed: 1 addition & 0 deletions
@@ -86,6 +86,7 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion =
         .def("dump_engine_layer_info_to_file", &TRTEngine::dump_engine_layer_info_to_file)
         .def("dump_engine_layer_info", &TRTEngine::dump_engine_layer_info)
         .def("get_engine_layer_info", &TRTEngine::get_engine_layer_info)
+        .def("set_pre_allocated_outputs", &TRTEngine::set_pre_allocated_outputs)
         .def_property(
             "device_memory_budget",
             &TRTEngine::get_device_memory_budget,
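
With this binding registered, the flag is reachable from Python on any torch.classes.tensorrt.Engine instance. A hedged usage sketch (it assumes the torch_tensorrt extension library is loaded so the class is registered; _TorchTensorRTModule.setup_engine below makes exactly this call):

    import torch

    def enable_output_buffer_reuse(engine, enable: bool = True) -> None:
        # engine: a torch.classes.tensorrt.Engine instance; calls the
        # binding registered above.
        engine.set_pre_allocated_outputs(enable)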

py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py

Lines changed: 53 additions & 35 deletions
@@ -107,6 +107,9 @@ def __init__(
         self.engine = None
         self.weight_name_map = weight_name_map
         self.target_platform = Platform.current_platform()
+        self.cudagraphs_enabled = False
+        self.pre_allocated_outputs: List[torch.Tensor] = []
+        self.use_pre_allocated_outputs = False
 
         if self.serialized_engine is not None and not self.settings.lazy_engine_init:
             self.setup_engine()
@@ -171,7 +174,7 @@ def setup_engine(self) -> None:
             self.engine.get_tensor_shape(input_name) for input_name in self.input_names
         ]
         self.output_dtypes = [
-            dtype._from(self.engine.get_tensor_dtype(output_name))
+            dtype._from(self.engine.get_tensor_dtype(output_name)).to(torch.dtype)
             for output_name in self.output_names
         ]
         self.output_shapes = [
@@ -232,6 +235,19 @@ def __del__(self) -> None:
         if self.cudagraph:
             self.cudagraph.reset()
 
+    def create_output_tensors(self) -> List[torch.Tensor]:
+        # create output tensors
+        outputs: List[torch.Tensor] = []
+
+        for o, _ in enumerate(self.output_names):
+            output = torch.empty(
+                size=self.output_shapes[o],
+                dtype=self.output_dtypes[o],
+                device=torch.cuda.current_device(),
+            )
+            outputs.append(output)
+        return outputs
+
     def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
         # Ensure inputs are available in all scopes and cast symbolic integers to Tensors
         contiguous_inputs: List[torch.Tensor] = [
@@ -247,19 +263,25 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
             self._check_initialized()
 
             cudagraphs_enabled = torch_tensorrt.runtime.get_cudagraphs_mode()
-            need_cudagraphs_record = (
-                cudagraphs_enabled and not self.cudagraphs_validate_shapes(inputs)
-            )
+            shape_changed = self.validate_input_shapes(inputs)
+            # Cudagraphs record is required if cudagraphs_enabled is toggled to True regardless of shape change
+            if not self.cudagraphs_enabled and cudagraphs_enabled:
+                need_cudagraphs_record = True
+            else:
+                need_cudagraphs_record = cudagraphs_enabled and shape_changed
+            self.cudagraphs_enabled = cudagraphs_enabled
 
             if need_cudagraphs_record:
+                if self.cudagraph:
+                    self.cudagraph.reset()
                 self._input_buffers = [None] * len(self.input_names)
                 self._output_buffers = [None] * len(self.output_names)
 
             if not cudagraphs_enabled and self.cudagraph:
                 self.cudagraph.reset()
                 self.cudagraph = None
 
-            # If in safe mode, check at each iteration for for whether a switch is required
+            # If in safe mode, check at each iteration for whether a switch is required
             if (
                 torch_tensorrt.runtime._multi_device_safe_mode._PY_RT_MULTI_DEVICE_SAFE_MODE
             ):
@@ -350,14 +372,14 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
                     self.context.set_tensor_address(
                         input_name, contiguous_inputs[i].data_ptr()
                     )
-
-                # Check if input shapes can be inferred.
-                uninferred_input_names = self.context.infer_shapes()
-                if uninferred_input_names:
-                    logger.warning(
-                        f"The shapes of the inputs: {uninferred_input_names} cannot be inferred and could lead to undefined behavior. \
-                            This could happen if the input tensor addresses/shapes haven't been configured correctly"
-                    )
+                if shape_changed:
+                    # Check if input shapes can be inferred.
+                    uninferred_input_names = self.context.infer_shapes()
+                    if uninferred_input_names:
+                        logger.warning(
+                            f"The shapes of the inputs: {uninferred_input_names} cannot be inferred and could lead to undefined behavior. \
+                                This could happen if the input tensor addresses/shapes haven't been configured correctly"
+                        )
 
             with (
                 torch.autograd.profiler.record_function(
@@ -366,24 +388,20 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
                 if self.profiling_enabled
                 else nullcontext()
             ):
-                # create output tensors
-                outputs: List[torch.Tensor] = []
-
-                for o, output_name in enumerate(self.output_names):
-                    shape = tuple(self.context.get_tensor_shape(output_name))
-
-                    if DYNAMIC_DIM in shape:
+                if not self.use_pre_allocated_outputs or shape_changed:
+                    self.output_shapes = [
+                        tuple(self.context.get_tensor_shape(output_name))
+                        for output_name in self.output_names
+                    ]
+                    if DYNAMIC_DIM in self.output_shapes:
                         raise ValueError(
                             "Encountered dynamic output shapes during runtime. This could mean the network has data-dependent output shapes which is not currently supported."
                         )
+                    outputs = self.create_output_tensors()
+                else:
+                    outputs = self.pre_allocated_outputs
 
-                    output = torch.empty(
-                        size=shape,
-                        dtype=self.output_dtypes[o].to(torch.dtype),
-                        device=torch.cuda.current_device(),
-                    )
-
-                    outputs.append(output)
+                for o, output_name in enumerate(self.output_names):
 
                     if need_cudagraphs_record:
                         self._output_buffers[o] = outputs[o].clone()
@@ -444,6 +462,9 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
 
             self._caller_stream.wait_stream(self._engine_stream)
 
+            if self.use_pre_allocated_outputs:
+                self.pre_allocated_outputs = self.create_output_tensors()
+
             if cudagraphs_enabled:
                 for idx, o in enumerate(outputs):
                     o.copy_(self._output_buffers[idx])
@@ -485,10 +506,9 @@ def get_layer_info(self) -> str:
         )
         return engine_json
 
-    def cudagraphs_validate_shapes(self, inputs: Sequence[torch.Tensor]) -> bool:
+    def validate_input_shapes(self, inputs: Sequence[torch.Tensor]) -> bool:
         """
-        Validates the input shapes of the forward function
-        versus the version currently active for the
+        Validates whether the input shapes of the forward function have changed
         """
         # Representation of input shapes to a given model
         # Shapes are concatenated as so:
@@ -498,10 +518,8 @@ def cudagraphs_validate_shapes(self, inputs: Sequence[torch.Tensor]) -> bool:
         # If the new shape key differs from the existing one,
         # invalidate the old shape key and remove the CUDAGraph
         if new_shape_key != self.shape_key:
-            logger.debug(f"Resetting Cudagraph on new shape key {new_shape_key}")
+            logger.debug(f"Input shape changed {self.shape_key} -> {new_shape_key}")
             self.shape_key = new_shape_key
-            if self.cudagraph:
-                self.cudagraph.reset()
-            return False
+            return True
 
-        return True
+        return False
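
Taken together, the Python module's new output path reduces to a small standalone rule, sketched below with illustrative stand-ins for the engine metadata (it mirrors create_output_tensors and the branch in forward above):

    from typing import List, Tuple
    import torch

    def select_outputs(
        use_pre_allocated: bool,
        shape_changed: bool,
        pre_allocated: List[torch.Tensor],
        shapes: List[Tuple[int, ...]],
    ) -> List[torch.Tensor]:
        if not use_pre_allocated or shape_changed:
            # Fresh allocation, as create_output_tensors() does when shapes change
            return [torch.empty(s) for s in shapes]
        # Otherwise reuse the buffers built at the end of the previous forward()
        return pre_allocated

The trade-off is deliberate: after each execution, the buffers for the next run are allocated eagerly (the use_pre_allocated_outputs branch after wait_stream), spending memory to keep allocation off the critical path of the next call.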

py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py

Lines changed: 1 addition & 0 deletions
@@ -203,6 +203,7 @@ def setup_engine(self) -> None:
         if self.engine is not None:
             return
         self.engine = torch.classes.tensorrt.Engine(self._pack_engine_info())
+        self.engine.set_pre_allocated_outputs(True)
 
     def encode_metadata(self, metadata: Any) -> str:
         metadata = copy.deepcopy(metadata)
