Skip to content

Commit 420af23

Browse files
committed
check output shape to implicitly decide whether network is dds
1 parent 9e60482 commit 420af23

File tree

6 files changed

+86
-22
lines changed

6 files changed

+86
-22
lines changed

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,9 @@
3333
from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
3434
DYNAMO_CONVERTERS as CONVERTERS,
3535
)
36-
from torch_tensorrt.dynamo.conversion._ConverterRegistry import CallingConvention
36+
from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
37+
CallingConvention,
38+
)
3739
from torch_tensorrt.dynamo.conversion._TRTBuilderMonitor import TRTBulderMonitor
3840
from torch_tensorrt.dynamo.conversion.converter_utils import (
3941
get_node_io,
@@ -62,6 +64,7 @@ class TRTInterpreterResult(NamedTuple):
6264
input_names: Sequence[str]
6365
output_names: Sequence[str]
6466
weight_name_map: Optional[dict[Any, Any]]
67+
engine_is_dds: bool
6568

6669

6770
class TRTInterpreter(torch.fx.Interpreter): # type: ignore[misc]
@@ -136,6 +139,9 @@ def __init__(
136139
# Engine cache for storing and reusing TRT engines
137140
self.engine_cache = engine_cache
138141

142+
# Whether the built engine has data-dependent output shapes (DDS)
143+
self.engine_is_dds: bool = False
144+
139145
def validate_conversion(self) -> Set[str]:
140146
missing_converters: Set[str] = set()
141147

@@ -575,6 +581,7 @@ def _insert_engine_to_cache(self, hash_val: str, serialized_engine: bytes) -> No
575581
self.input_specs,
576582
self.compilation_settings,
577583
self.weight_name_map,
584+
self.engine_is_dds,
578585
),
579586
)
580587

@@ -589,6 +596,7 @@ def _pull_cached_engine(self, hash_val: str) -> Optional[TRTInterpreterResult]:
589596
cached_engine_input_specs,
590597
engine_compilation_settings,
591598
self.weight_name_map,
599+
self.engine_is_dds,
592600
) = cached_data
593601

594602
setting_compatiblity, incompattible_settings = settings_are_compatible(
@@ -650,9 +658,20 @@ def _pull_cached_engine(self, hash_val: str) -> Optional[TRTInterpreterResult]:
650658
self._input_names,
651659
self._output_names,
652660
self.weight_name_map,
661+
self.engine_is_dds,
653662
)
654663
return None
655664

665+
def check_dds(self, serialized_engine: bytes, output_names: List[str]) -> bool:
    """Report whether any listed engine output has a ``-1`` dimension.

    The serialized engine is deserialized with a fresh TensorRT runtime and
    the shape of every tensor named in ``output_names`` is inspected. An
    output whose shape contains ``-1`` marks the engine as having a
    data-dependent shape (DDS).

    Args:
        serialized_engine: Serialized TensorRT engine to inspect.
        output_names: Names of the engine output tensors to check.

    Returns:
        ``True`` if at least one listed output shape contains ``-1``;
        ``False`` otherwise (including when ``output_names`` is empty).

    Raises:
        RuntimeError: If the serialized engine cannot be deserialized.
    """
    runtime = trt.Runtime(TRT_LOGGER)
    engine = runtime.deserialize_cuda_engine(serialized_engine)
    if engine is None:
        # Fail with a clear message instead of an AttributeError on the
        # ``get_tensor_shape`` call below.
        raise RuntimeError(
            "Failed to deserialize TensorRT engine while checking for DDS outputs"
        )

    # NOTE(review): -1 may also be reported for dimensions that are dynamic
    # only because the inputs have dynamic shapes -- confirm that such
    # engines are meant to be classified as DDS as well.
    return any(-1 in engine.get_tensor_shape(name) for name in output_names)
674+
656675
def run(
657676
self,
658677
strict_type_constraints: bool = False,
@@ -709,6 +728,8 @@ def run(
709728
)
710729
assert serialized_engine
711730

731+
self.engine_is_dds = self.check_dds(serialized_engine, self._output_names)
732+
712733
_LOGGER.info(
713734
f"Build TRT engine elapsed time: {datetime.now() - build_engine_start_time}"
714735
)
@@ -735,6 +756,7 @@ def run(
735756
self._input_names,
736757
self._output_names,
737758
self.weight_name_map,
759+
self.engine_is_dds,
738760
)
739761

740762
def run_node(self, n: torch.fx.Node) -> torch.fx.Node:

py/torch_tensorrt/dynamo/conversion/_conversion.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def infer_module_output_dtypes(
3030
"""
3131
outputs = [node for node in module.graph.nodes if node.op == "output"]
3232
outputs = outputs[0].args
33-
return get_output_dtypes(outputs, truncate_double)
33+
return get_output_dtypes(outputs, truncate_double) # type: ignore[no-any-return]
3434

3535

3636
def interpret_module_to_result(
@@ -112,4 +112,5 @@ def convert_module(
112112
name=name,
113113
settings=settings,
114114
weight_name_map=interpreter_result.weight_name_map,
115+
engine_is_dds=interpreter_result.engine_is_dds,
115116
)

py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@ def __init__(
127127
name: str = "",
128128
settings: CompilationSettings = CompilationSettings(),
129129
weight_name_map: Optional[dict[Any, Any]] = None,
130+
engine_is_dds: bool = False,
130131
):
131132
"""Takes a name, target device, serialized TensorRT engine, and binding names / order and constructs
132133
a PyTorch ``torch.nn.Module`` around it. Uses TensorRT Python APIs to run the engine
@@ -140,6 +141,7 @@ def __init__(
140141
name (str): Name for module
141142
settings (torch_tensorrt.dynamo.CompilationSettings): Settings used to compile engine, assumes engine was built with default compilation settings if object not passed
142143
weight_name_map (dict): Mapping of engine weight name to state_dict weight name
144+
engine_is_dds (bool): Whether the engine has data-dependent output shapes (DDS)
143145
144146
Example:
145147
@@ -200,7 +202,7 @@ def __init__(
200202
torch_tensorrt.runtime.get_cudagraphs_mode()
201203
)
202204

203-
self.contains_dds_layer = False
205+
self.engine_is_dds = engine_is_dds
204206
self.pre_allocated_outputs: List[torch.Tensor] = []
205207
self.use_pre_allocated_outputs = False
206208
self.output_allocator: Optional[DynamicOutputAllocator] = None
@@ -276,19 +278,12 @@ def setup_engine(self) -> None:
276278
for output_name in self.output_names
277279
]
278280

279-
self.contains_dds_layer = self._check_dds_layer()
280-
if self.contains_dds_layer:
281-
self.setup_output_allocator()
281+
if self.engine_is_dds:
282+
self.create_output_allocator()
282283

283284
if torch_tensorrt.runtime.get_cudagraphs_mode():
284285
self.cudagraph = torch.cuda.CUDAGraph()
285286

286-
def _check_dds_layer(self) -> bool:
287-
layer_info = self.get_layer_info()
288-
if "trainStation" in layer_info: # contains dds layer
289-
return True
290-
return False
291-
292287
def _check_initialized(self) -> None:
293288
if not self.initialized:
294289
raise RuntimeError("PythonTorchTensorRTModule is not initialized.")
@@ -406,19 +401,13 @@ def create_output_tensors(self) -> List[torch.Tensor]:
406401
def set_pre_allocated_outputs(self, enable: bool) -> None:
407402
self.use_pre_allocated_outputs = enable
408403

409-
def setup_output_allocator(self) -> None:
404+
def create_output_allocator(self) -> None:
    """Create the dynamic output allocator on first use.

    Does nothing when ``self.output_allocator`` is already set. Otherwise
    maps each name in ``self.output_names`` to the dtype at the same index
    in ``self.output_dtypes`` and stores a ``DynamicOutputAllocator`` built
    from that mapping on ``self.output_allocator``.
    """
    if self.output_allocator is not None:
        return

    dtype_by_output = {
        name: self.output_dtypes[idx]
        for idx, name in enumerate(self.output_names)
    }
    self.output_allocator = DynamicOutputAllocator(dtype_by_output)

416-
for output_name in self.output_names:
417-
if not self.context.set_output_allocator(
418-
output_name, self.output_allocator
419-
):
420-
raise RuntimeError(f"Failed to set output allocator for {output_name}")
421-
422411
def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
423412

424413
def run_cuda_graph() -> torch.Tensor | Tuple[torch.Tensor, ...]:
@@ -569,6 +558,23 @@ def run_output_allocator() -> torch.Tensor | Tuple[torch.Tensor, ...]:
569558

570559
self.setup_input_tensors(contiguous_inputs, False, False)
571560

561+
with (
562+
torch.autograd.profiler.record_function(
563+
"PythonTorchTensorRTModule:SetupOutputAllocator"
564+
)
565+
if self.profiling_enabled
566+
else nullcontext()
567+
):
568+
self.create_output_allocator()
569+
# The output allocator has to be set on the context again for every run
570+
for output_name in self.output_names:
571+
if not self.context.set_output_allocator(
572+
output_name, self.output_allocator
573+
):
574+
raise RuntimeError(
575+
f"Failed to set output allocator for {output_name}"
576+
)
577+
572578
with (
573579
torch.autograd.profiler.record_function(
574580
"PythonTorchTensorRTModule:TensorRTRuntime"
@@ -662,7 +668,7 @@ def run_output_allocator() -> torch.Tensor | Tuple[torch.Tensor, ...]:
662668
]
663669
logger.warning(f"Moved all input Tensors to cuda:{device_id}")
664670

665-
if self.contains_dds_layer:
671+
if self.engine_is_dds:
666672
return run_output_allocator()
667673
else:
668674
return run_cuda_graph()

py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ def __init__(
7979
name: str = "",
8080
settings: CompilationSettings = CompilationSettings(), # Assumes engine was built with default compilation settings if object not passed
8181
weight_name_map: Optional[dict[Any, Any]] = None,
82+
engine_is_dds: bool = False,
8283
):
8384
"""Takes a name, target device, serialized TensorRT engine, and binding names / order and constructs
8485
a PyTorch ``torch.nn.Module`` around it. Uses the Torch-TensorRT runtime extension to run the engines
@@ -97,6 +98,7 @@ def __init__(
9798
name (str): Name for module
9899
settings (torch_tensorrt.dynamo.CompilationSettings): Settings used to compile engine, assumes engine was built with default compilation settings if object not passed
99100
weight_name_map (dict): Mapping of engine weight name to state_dict weight name
101+
engine_is_dds (bool): Whether the engine has data-dependent output shapes (DDS)
100102
101103
Example:
102104
@@ -132,6 +134,7 @@ def __init__(
132134
self.weight_name_map = weight_name_map
133135
self.serialized_engine = serialized_engine
134136
self.engine = None
137+
self.engine_is_dds = engine_is_dds
135138

136139
if (
137140
serialized_engine
@@ -146,7 +149,11 @@ def _pack_engine_info(self) -> List[str | bytes]:
146149
if self.settings.device is not None
147150
else Device._current_device()
148151
)
149-
metadata = {"settings": self.settings, "weight_name_map": self.weight_name_map}
152+
metadata = {
153+
"settings": self.settings,
154+
"weight_name_map": self.weight_name_map,
155+
"engine_is_dds": self.engine_is_dds,
156+
}
150157
target_platform = (
151158
Platform.current_platform()
152159
if not self.settings.enable_cross_compile_for_windows
@@ -263,6 +270,7 @@ def set_extra_state(self, state: SerializedTorchTensorRTModuleFmt) -> None:
263270
metadata = TorchTensorRTModule.decode_metadata(serialized_metadata)
264271
self.settings = metadata["settings"]
265272
self.weight_name_map = metadata["weight_name_map"]
273+
self.engine_is_dds = metadata["engine_is_dds"]
266274

267275
else:
268276
self.engine = None

tests/py/dynamo/conversion/harness.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,7 @@ def run_test(
207207
input_binding_names=list(interpreter_result.input_names),
208208
output_binding_names=list(interpreter_result.output_names),
209209
name="test_engine",
210+
engine_is_dds=interpreter_result.engine_is_dds,
210211
)
211212
mod = mod.cuda()
212213
if pyt_inputs is not None:
@@ -289,6 +290,7 @@ def run_test_custom_compare_results(
289290
input_binding_names=list(interpreter_result.input_names),
290291
output_binding_names=list(interpreter_result.output_names),
291292
name="test_engine",
293+
engine_is_dds=interpreter_result.engine_is_dds,
292294
)
293295
res_trt = trt_mod(*cuda_inputs).cpu()
294296
res_cpu = mod(*cuda_inputs).cpu()

tests/py/dynamo/conversion/test_nonzero_aten.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,33 @@ class TestNonZeroConverter(DispatchTestCase):
1919
)
2020
def test_non_zero(self, input_shape, dtype):
2121
class NonZero(nn.Module):
22+
# This is a DDS network
2223
def forward(self, input):
23-
return torch.ops.aten.nonzero.default(input)
24+
out = torch.ops.aten.nonzero.default(input)
25+
return out
26+
27+
inputs = [torch.randint(low=0, high=3, size=input_shape, dtype=dtype)]
28+
self.run_test(
29+
NonZero(),
30+
inputs,
31+
)
32+
33+
@parameterized.expand(
34+
[
35+
((10,), torch.int),
36+
((1, 20), torch.int32),
37+
((2, 3), torch.int64),
38+
((2, 3, 4), torch.float),
39+
((2, 3, 4, 5), torch.float),
40+
]
41+
)
42+
def test_non_zero(self, input_shape, dtype):
43+
class NonZero(nn.Module):
44+
# This is a static network
45+
def forward(self, input):
46+
out = torch.ops.aten.nonzero.default(input)
47+
out = torch.ops.aten.sum.dim_IntList(out, 0)
48+
return out
2449

2550
inputs = [torch.randint(low=0, high=3, size=input_shape, dtype=dtype)]
2651
self.run_test(

0 commit comments

Comments
 (0)