Skip to content

Commit 6ef68e7

Browse files
committed
chore: rebase
1 parent 63733ee commit 6ef68e7

File tree

5 files changed

+57
-94
lines changed

5 files changed

+57
-94
lines changed

core/runtime/TRTEngine.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,8 +130,10 @@ struct TRTEngine : torch::CustomClassHolder {
130130
at::cuda::CUDAStream caller_stream = c10::cuda::getDefaultCUDAStream();
131131
std::vector<at::Tensor> input_buffers = {};
132132
std::vector<at::Tensor> output_buffers = {};
133-
std::string shape_key;
134-
bool prev_cudagraphs_enabled = false;
133+
std::string shape_key = "None";
134+
bool use_pre_allocated_outputs = false;
135+
std::vector<at::Tensor> pre_allocated_outputs;
136+
135137
// TODO: Implement a call method
136138
// c10::List<at::Tensor> Run(c10::List<at::Tensor> inputs);
137139

core/runtime/execute_engine.cpp

Lines changed: 10 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ bool _validate_shapes(std::vector<at::Tensor> inputs, c10::intrusive_ptr<TRTEngi
9494
void setup_input_tensors(
9595
std::vector<at::Tensor> inputs,
9696
c10::intrusive_ptr<TRTEngine> compiled_engine,
97+
bool cudagraphs_enabled,
9798
bool need_cudagraphs_record) {
9899
// this is a buffer to store shape tensor input addresses throughout the runtime scope
99100
std::list<std::vector<int64_t>> inputShapeTensorValues;
@@ -127,7 +128,7 @@ void setup_input_tensors(
127128
compiled_engine->exec_ctx->setTensorAddress(name.c_str(), inputShapeTensorValues.back().data()),
128129
"Error while setting the tensor address for shape inputs");
129130

130-
if (CUDAGRAPHS_MODE) {
131+
if (cudagraphs_enabled) {
131132
// @peri044 I dont know if this makes sense since they are supposed to be GPU buffers
132133
compiled_engine->input_buffers[i] = input_cpu;
133134
}
@@ -147,7 +148,7 @@ void setup_input_tensors(
147148
TORCHTRT_CHECK(
148149
compiled_engine->exec_ctx->setInputShape(name.c_str(), dims), "Error while setting the input shape");
149150

150-
if (CUDAGRAPHS_MODE) {
151+
if (cudagraphs_enabled) {
151152
// If using CUDAGraphs copy formatted input to the corresponding persistent input buffer
152153
compiled_engine->input_buffers[i].copy_(formatted_inputs.back(), true);
153154
TORCHTRT_CHECK(
@@ -202,15 +203,16 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
202203
compiled_engine->cudagraph.enable_debug_mode();
203204
}
204205
bool cudagraphs_enabled = (CUDAGRAPHS_MODE == SUBGRAPH_CUDAGRAPHS);
206+
bool shape_changed = _validate_shapes(inputs, compiled_engine);
205207

206208
// Whether cudagraphs needs to record the graph on this pass
207-
// Cudagraphs record is required if cudagraphs_enabled is switched to True regardless of shape change
208-
bool need_cudagraphs_record = cudagraphs_enabled &&
209-
((!compiled_engine->prev_cudagraphs_enabled) || (!_cudagraphs_validate_shapes(inputs, compiled_engine)));
209+
auto result = compiled_engine->runtime_states.set_runtime_states(
210+
cudagraphs_enabled, compiled_engine->use_pre_allocated_outputs, shape_changed);
210211

211-
compiled_engine->prev_cudagraphs_enabled = cudagraphs_enabled;
212+
bool need_cudagraphs_record = std::get<0>(result);
213+
bool can_use_pre_allocated_outputs = std::get<1>(result);
212214

213-
if (!cudagraphs_enabled) {
215+
if (!cudagraphs_enabled || shape_changed) {
214216
compiled_engine->cudagraph.reset();
215217
}
216218

@@ -272,69 +274,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
272274
std::make_unique<torch::autograd::profiler::RecordProfile>(compiled_engine->input_profile_path);
273275
}
274276

275-
for (size_t i = 0; i < inputs.size(); i++) {
276-
std::string name = compiled_engine->in_binding_names[i];
277-
278-
TORCHTRT_CHECK(
279-
inputs[i].is_cuda(), "Expected input tensors to have device cuda, found device " << inputs[i].device());
280-
281-
auto expected_type =
282-
util::TRTDataTypeToScalarType(compiled_engine->exec_ctx->getEngine().getTensorDataType(name.c_str()));
283-
TORCHTRT_CHECK(
284-
inputs[i].dtype() == expected_type,
285-
"Expected input tensors to have type " << expected_type << ", found type " << inputs[i].dtype());
286-
287-
auto dims = core::util::toDims(inputs[i].sizes());
288-
auto shape = core::util::toVec(dims);
289-
LOG_DEBUG("Input Name: " << name << " Shape: " << dims);
290-
291-
if (compiled_engine->cuda_engine->isShapeInferenceIO(name.c_str())) {
292-
// Shape tensor inputs are casted to int64 explicitly.
293-
// Refer to
294-
// https://github.com/NVIDIA/TensorRT/blob/d2f4ef789a9a6ffdf37b55c3f81b486225f6b380/samples/common/sampleInference.cpp#L435
295-
auto input_cpu = inputs[i].clone().contiguous().cpu().to(torch::kInt64);
296-
std::vector<int64_t> inputs_cpu_vec(
297-
input_cpu.data_ptr<int64_t>(), input_cpu.data_ptr<int64_t>() + input_cpu.numel());
298-
inputShapeTensorValues.emplace_back(inputs_cpu_vec);
299-
TORCHTRT_CHECK(
300-
compiled_engine->exec_ctx->setTensorAddress(name.c_str(), inputShapeTensorValues.back().data()),
301-
"Error while setting the tensor address for shape inputs");
302-
303-
if (cudagraphs_enabled) {
304-
// @peri044 I dont know if this makes sense since they are supposed to be GPU buffers
305-
compiled_engine->input_buffers[i] = input_cpu;
306-
}
307-
TORCHTRT_CHECK(
308-
compiled_engine->exec_ctx->setTensorAddress(name.c_str(), inputShapeTensorValues.back().data()),
309-
"Error while setting the tensor address for shape inputs");
310-
311-
} else {
312-
at::Tensor contig_input = inputs[i].view(shape).contiguous();
313-
formatted_inputs.emplace_back(std::move(contig_input));
314-
315-
if (need_cudagraphs_record) {
316-
// Create a new persistent input buffer
317-
compiled_engine->input_buffers[i] = std::move(formatted_inputs.back().clone());
318-
}
319-
320-
TORCHTRT_CHECK(
321-
compiled_engine->exec_ctx->setInputShape(name.c_str(), dims), "Error while setting the input shape");
322-
323-
if (cudagraphs_enabled) {
324-
// If using CUDAGraphs copy formatted input to the corresponding persistent input buffer
325-
compiled_engine->input_buffers[i].copy_(formatted_inputs.back(), true);
326-
TORCHTRT_CHECK(
327-
compiled_engine->exec_ctx->setTensorAddress(name.c_str(), compiled_engine->input_buffers[i].data_ptr()),
328-
"Error while setting the input tensor address for inputs");
329-
} else {
330-
// Otherwise use the formatted buffer directly
331-
TORCHTRT_CHECK(
332-
compiled_engine->exec_ctx->setTensorAddress(name.c_str(), formatted_inputs.back().data_ptr()),
333-
"Error while setting the input tensor address for inputs");
334-
}
335-
}
336-
}
337-
277+
setup_input_tensors(inputs, compiled_engine, cudagraphs_enabled, need_cudagraphs_record);
338278
// Check if input shapes can be inferred.
339279
int32_t const io_size{compiled_engine->cuda_engine->getNbIOTensors()};
340280
std::vector<char const*> names(io_size);

docsrc/index.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ Tutorials
6767
* :ref:`custom_kernel_plugins`
6868
* :ref:`mutable_torchtrt_module_example`
6969
* :ref:`weight_streaming_example`
70+
* :ref:`pre_allocated_output_example`
7071

7172
.. toctree::
7273
:caption: Tutorials
@@ -85,6 +86,7 @@ Tutorials
8586
tutorials/_rendered_examples/dynamo/auto_generate_converters
8687
tutorials/_rendered_examples/dynamo/mutable_torchtrt_module_example
8788
tutorials/_rendered_examples/dynamo/weight_streaming_example
89+
tutorials/_rendered_examples/dynamo/pre_allocated_output_example
8890

8991
Dynamo Frontend
9092
----------------

examples/dynamo/torch_export_cudagraphs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@
4747

4848
# We can enable the cudagraphs API with a context manager
4949
with torch_tensorrt.runtime.enable_cudagraphs(opt) as cudagraphs_module:
50-
out_trt = opt(inputs)
50+
out_trt = cudagraphs_module(inputs)
5151

5252
# Alternatively, we can set the cudagraphs mode for the session
5353
torch_tensorrt.runtime.set_cudagraphs_mode(True)

py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py

Lines changed: 40 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def __init__(self, new_cudagraphs: bool, new_pre_allocated_output: bool):
3030
# Indicates whether pre-allocated output was enabled in the previous execute_engine
3131
self.old_pre_allocated_outputs = new_pre_allocated_output
3232

33-
def validate_states(
33+
def set_runtime_states(
3434
self,
3535
new_cudagraphs: bool,
3636
new_pre_allocated_output: bool,
@@ -144,8 +144,11 @@ def __init__(
144144
self.engine = None
145145
self.weight_name_map = weight_name_map
146146
self.target_platform = Platform.current_platform()
147-
# Previous cuda graphs state
148-
self.prev_cudagraphs_enabled = False
147+
self.runtime_states = TorchTRTRuntimeStates(
148+
torch_tensorrt.runtime.get_cudagraphs_mode(), False
149+
)
150+
self.pre_allocated_outputs: List[torch.Tensor] = []
151+
self.use_pre_allocated_outputs = False
149152

150153
if self.serialized_engine is not None and not self.settings.lazy_engine_init:
151154
self.setup_engine()
@@ -352,14 +355,16 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
352355
self._check_initialized()
353356

354357
cudagraphs_enabled = torch_tensorrt.runtime.get_cudagraphs_mode()
355-
shape_changed = self.cudagraphs_validate_shapes(inputs)
356-
# Cudagraphs record is required if cudagraphs_enabled is switched to True regardless of shape change
357-
need_cudagraphs_record = cudagraphs_enabled and (
358-
(not self.prev_cudagraphs_enabled) or (not shape_changed)
358+
shape_changed = self.validate_input_shapes(inputs)
359+
need_cudagraphs_record, can_use_pre_allocated_outputs = (
360+
self.runtime_states.set_runtime_states(
361+
cudagraphs_enabled, self.use_pre_allocated_outputs, shape_changed
362+
)
359363
)
360-
self.prev_cudagraphs_enabled = cudagraphs_enabled
361364

362365
if need_cudagraphs_record:
366+
if self.cudagraph:
367+
self.cudagraph.reset()
363368
self._input_buffers = [None] * len(self.input_names)
364369
self._output_buffers = [None] * len(self.output_names)
365370

@@ -423,8 +428,16 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
423428
This could happen if the input tensor addresses/shapes haven't been configured correctly"
424429
)
425430

426-
with nvtx.annotate("ProcessOutputs:1", color="red"):
427-
if not self.use_pre_allocated_outputs or shape_changed:
431+
with (
432+
torch.autograd.profiler.record_function(
433+
"PythonTorchTensorRTModule:ProcessOutputs"
434+
)
435+
if self.profiling_enabled
436+
else nullcontext()
437+
):
438+
if can_use_pre_allocated_outputs:
439+
outputs = self.pre_allocated_outputs
440+
else:
428441
self.output_shapes = [
429442
tuple(self.context.get_tensor_shape(output_name))
430443
for output_name in self.output_names
@@ -434,12 +447,12 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
434447
"Encountered dynamic output shapes during runtime. This could mean the network has data-dependent output shapes which is not currently supported."
435448
)
436449
outputs = self.create_output_tensors()
437-
else:
438-
outputs = self.pre_allocated_outputs
439450

440451
for o, output_name in enumerate(self.output_names):
452+
441453
if need_cudagraphs_record:
442454
self._output_buffers[o] = outputs[o].clone()
455+
443456
if cudagraphs_enabled:
444457
self.context.set_tensor_address(
445458
output_name, self._output_buffers[o].data_ptr()
@@ -449,7 +462,13 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
449462
output_name, outputs[o].data_ptr()
450463
)
451464

452-
with nvtx.annotate("TensorRTRuntime", color="red"):
465+
with (
466+
torch.autograd.profiler.record_function(
467+
"PythonTorchTensorRTModule:TensorRTRuntime"
468+
)
469+
if self.profiling_enabled
470+
else nullcontext()
471+
):
453472
self._caller_stream = torch.cuda.current_stream()
454473
if (
455474
self._engine_stream == torch.cuda.default_stream()
@@ -490,6 +509,9 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
490509

491510
self._caller_stream.wait_stream(self._engine_stream)
492511

512+
if self.use_pre_allocated_outputs:
513+
self.pre_allocated_outputs = self.create_output_tensors()
514+
493515
if cudagraphs_enabled:
494516
for idx, o in enumerate(outputs):
495517
o.copy_(self._output_buffers[idx])
@@ -531,10 +553,9 @@ def get_layer_info(self) -> str:
531553
)
532554
return engine_json
533555

534-
def cudagraphs_validate_shapes(self, inputs: Sequence[torch.Tensor]) -> bool:
556+
def validate_input_shapes(self, inputs: Sequence[torch.Tensor]) -> bool:
535557
"""
536-
Validates the input shapes of the forward function
537-
versus the version currently active for the
558+
Validates the input shapes of the forward function has changed
538559
"""
539560
# Representation of input shapes to a given model
540561
# Shapes are concatenated as so:
@@ -544,10 +565,8 @@ def cudagraphs_validate_shapes(self, inputs: Sequence[torch.Tensor]) -> bool:
544565
# If the new shape key differs from the existing one,
545566
# invalidate the old shape key and remove the CUDAGraph
546567
if new_shape_key != self.shape_key:
547-
logger.debug(f"Resetting Cudagraph on new shape key {new_shape_key}")
568+
logger.debug(f"Input shape changed {self.shape_key} -> {new_shape_key}")
548569
self.shape_key = new_shape_key
549-
if self.cudagraph:
550-
self.cudagraph.reset()
551-
return False
570+
return True
552571

553-
return True
572+
return False

0 commit comments

Comments
 (0)