Skip to content

Commit 9b78a0e

Browse files
committed
fix: Add support for Dynamic Shapes
1 parent 998e560 commit 9b78a0e

File tree

2 files changed

+23
-4
lines changed

2 files changed

+23
-4
lines changed

py/torch_tensorrt/dynamo/_DryRunTracker.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,21 @@ def input_formatter_helper(shapes: Any, dtypes: Any) -> str:
229229
if isinstance(shapes, tuple) and all(isinstance(elt, int) for elt in shapes):
230230
return f"Tensor: {shapes}@{str(dtypes)[6:]}, "
231231

232+
# Base case - dynamic shape, single dtype
233+
elif (
234+
isinstance(shapes, dict)
235+
and len(shapes) == 3
236+
and all(
237+
(
238+
isinstance(shape, tuple)
239+
and all(isinstance(elt, int) for elt in shape)
240+
and k in ("min_shape", "opt_shape", "max_shape")
241+
)
242+
for k, shape in shapes.items()
243+
)
244+
):
245+
return f"Tensor: {shapes}@{str(dtypes)[6:]}, "
246+
232247
# Shapes is a sequence
233248
elif isinstance(shapes, (list, tuple)):
234249
formatted_str = "List[" if isinstance(shapes, list) else "Tuple("

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,7 @@ def compile_module(
244244
dryrun_tracker.total_ops_in_graph = total_ops
245245
dryrun_tracker.supported_ops_in_graph = num_supported_ops
246246
dryrun_tracker.graph_input_shapes = parse_complex_tensor_structs(
247-
sample_inputs, "shape", tuple
247+
sample_inputs, "shape", lambda x: dict(x) if isinstance(x, dict) else tuple(x)
248248
)
249249
dryrun_tracker.graph_input_dtypes = parse_complex_tensor_structs(
250250
sample_inputs, "torch_dtype"
@@ -356,7 +356,9 @@ def compile_module(
356356
)
357357

358358
subgraph_data.subgraph_input_shapes = parse_complex_tensor_structs(
359-
submodule_inputs, "shape", tuple
359+
submodule_inputs,
360+
"shape",
361+
lambda x: dict(x) if isinstance(x, dict) else tuple(x),
360362
)
361363
subgraph_data.subgraph_input_dtypes = parse_complex_tensor_structs(
362364
submodule_inputs, "torch_dtype"
@@ -367,7 +369,9 @@ def compile_module(
367369
)
368370

369371
subgraph_data.subgraph_output_shapes = parse_complex_tensor_structs(
370-
submodule_outputs, "shape", tuple
372+
submodule_outputs,
373+
"shape",
374+
lambda x: dict(x) if isinstance(x, dict) else tuple(x),
371375
)
372376
subgraph_data.subgraph_output_dtypes = parse_complex_tensor_structs(
373377
submodule_outputs, "dtype"
@@ -395,7 +399,7 @@ def compile_module(
395399
sample_outputs = [sample_outputs]
396400

397401
dryrun_tracker.graph_output_shapes = parse_complex_tensor_structs(
398-
sample_outputs, "shape", tuple
402+
sample_outputs, "shape", lambda x: dict(x) if isinstance(x, dict) else tuple(x)
399403
)
400404
dryrun_tracker.graph_output_dtypes = parse_complex_tensor_structs(
401405
sample_outputs, "dtype"

0 commit comments

Comments
 (0)