
Commit d7768ee

feat: Add dynamic shapes support for torch.compile workflow (#2627)
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent c5d6e16 commit d7768ee
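
Since this commit targets the torch.compile path with dynamic=True, here is a minimal sketch of the workflow it enables (assumes a CUDA device and a torch_tensorrt install; the model and shapes are illustrative, not taken from the test suite):

import torch
import torch_tensorrt  # noqa: F401  (importing registers the "torch_tensorrt" backend)

model = torch.nn.Linear(64, 32).eval().cuda()

# dynamic=True makes Dynamo insert SymInt placeholders for the dynamic
# dimensions; the backend changes below strip those before TRT conversion.
compiled = torch.compile(model, backend="torch_tensorrt", dynamic=True)

# Different batch sizes reuse the same compiled artifact instead of recompiling.
print(compiled(torch.randn(4, 64, device="cuda")).shape)
print(compiled(torch.randn(16, 64, device="cuda")).shape)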

File tree

8 files changed, +157 −133 lines changed


.github/workflows/build-test.yml

Lines changed: 2 additions & 0 deletions

@@ -21,6 +21,7 @@ jobs:
       os: linux
       test-infra-repository: pytorch/test-infra
       test-infra-ref: main
+      channel: test
       with-rocm: false
       with-cpu: false

@@ -197,6 +198,7 @@ jobs:
           ${CONDA_RUN} python -m pip install --pre pytest-xdist timm transformers parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver
           ${CONDA_RUN} python -m pytest -n 10 --junitxml=${RUNNER_TEST_RESULTS_DIR}/torch_compile_be_test_results.xml backend/
           ${CONDA_RUN} python -m pytest -n 4 --junitxml=${RUNNER_TEST_RESULTS_DIR}/torch_comple_be_e2e_test_results.xml --ir torch_compile models/test_models.py
+          ${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/torch_compile_dyn_models_export.xml --ir torch_compile models/test_dyn_models.py
           popd
 
   tests-py-dynamo-core:

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 5 additions & 7 deletions

@@ -317,14 +317,12 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
                 return False
         return True
 
-    # Check if the module has metadata (shape, dtype). If not, run symbolic shape propagation.
+    # Check if the module has metadata (shape, dtype).
     if not contains_metadata(gm):
-        from torch._inductor.compile_fx import fake_tensor_prop
-
-        torch_inputs = get_torch_inputs(sample_inputs, settings.device)
-        with torch.no_grad():
-            # This fails if the module has data-dependent shape operators.
-            fake_tensor_prop(gm, torch_inputs)
+        # TODO: For future, explore when nodes don't have metadata and if fake_tensor_prop can resolve this.
+        logger.warning(
+            "Some nodes do not have metadata (shape and dtype information). This could lead to problems sometimes if the graph has PyTorch and TensorRT segments."
+        )
 
     # Partition module into components that can be TRT-accelerated
     fast_partitioner_failed = False
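
For context on what "metadata" means in contains_metadata above: FX nodes produced by Dynamo/AOTAutograd conventionally carry a FakeTensor in node.meta["val"] that records shape and dtype. A rough sketch of such a check (an illustration of the convention, not the actual implementation in _compiler.py):

import torch

def has_shape_metadata(gm: torch.fx.GraphModule) -> bool:
    # Dynamo/AOTAutograd store a FakeTensor (or SymInt) in node.meta["val"];
    # the partitioner and converters read shape/dtype information from it.
    return all(
        "val" in node.meta
        for node in gm.graph.nodes
        if node.op == "call_function"
    )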

py/torch_tensorrt/dynamo/backend/backends.py

Lines changed: 14 additions & 6 deletions

@@ -13,6 +13,7 @@
 from torch_tensorrt.dynamo.lowering import (
     apply_lowering_passes,
     get_decompositions,
+    remove_sym_nodes,
     repair_input_aliasing,
 )
 from torch_tensorrt.dynamo.utils import (

@@ -27,7 +28,7 @@
 @td.register_backend(name="tensorrt")  # type: ignore[misc]
 @td.register_backend(name="torch_tensorrt")  # type: ignore[misc]
 def torch_tensorrt_backend(
-    gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor], **kwargs: Any
+    gm: torch.fx.GraphModule, sample_inputs: Sequence[Any], **kwargs: Any
 ) -> torch.nn.Module:
     # Set log level at the top of compilation (torch_tensorrt.dynamo)
     if (

@@ -44,15 +45,15 @@ def torch_tensorrt_backend(
 
 @td.register_backend(name="aot_torch_tensorrt_aten")  # type: ignore[misc]
 def aot_torch_tensorrt_aten_backend(
-    gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor], **kwargs: Any
+    gm: torch.fx.GraphModule, sample_inputs: Sequence[Any], **kwargs: Any
 ) -> torch.nn.Module:
     settings = parse_dynamo_kwargs(kwargs)
     return _pretraced_backend(gm, sample_inputs, settings)
 
 
 def _pretraced_backend(
     gm: torch.fx.GraphModule,
-    sample_inputs: Sequence[torch.Tensor],
+    sample_inputs: Sequence[Any],
     settings: CompilationSettings = CompilationSettings(),
 ) -> torch.fx.GraphModule | Callable[..., Any]:
     """Helper function to manage translation of traced FX module to TRT engines

@@ -74,10 +75,17 @@ def _pretraced_backend(
             fake_mode, "allow_non_fake_inputs", True
         ), fake_mode:
             repair_input_aliasing(gm)
+
+            # Remove sym_int placeholders and inputs
+            remove_sym_nodes(gm)
+            torch_inputs = [
+                input for input in sample_inputs if isinstance(input, torch.Tensor)
+            ]
+
             # Invoke AOTAutograd to translate operators to aten
             gm = aot_export_joint_simple(
                 gm,
-                sample_inputs,
+                torch_inputs,
                 trace_joint=False,
                 decompositions=get_decompositions(
                     settings.enable_experimental_decompositions

@@ -86,10 +94,10 @@
 
             logger.debug("Post-AOT Autograd graph:\n" + str(gm.graph))
 
-            gm = apply_lowering_passes(gm, sample_inputs)
+            gm = apply_lowering_passes(gm, torch_inputs)
 
             torchtrt_inputs = prepare_inputs(
-                sample_inputs, disable_memory_format_check=True
+                torch_inputs, disable_memory_format_check=True
             )
             trt_compiled = compile_module(
                 gm,
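
With dynamic=True, Dynamo hands the backend a mixed sample_inputs list in which each dynamic dimension also appears as a standalone SymInt argument alongside the tensors. The filter added above keeps only the tensors, since aot_export_joint_simple and the downstream TRT conversion expect tensor inputs. A toy illustration of the same filter (a plain int stands in for the SymInt here):

import torch

sample_inputs = [3, torch.randn(3, 64)]  # the int mimics a SymInt such as s0
torch_inputs = [inp for inp in sample_inputs if isinstance(inp, torch.Tensor)]
assert len(torch_inputs) == 1 and torch_inputs[0].shape == (3, 64)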

py/torch_tensorrt/dynamo/conversion/impl/shape.py

Lines changed: 2 additions & 2 deletions

@@ -11,7 +11,6 @@
 from torch_tensorrt.dynamo.conversion.converter_utils import (
     get_positive_dim,
     get_trt_tensor,
-    to_numpy,
 )
 from torch_tensorrt.dynamo.conversion.impl.elementwise.base import (
     convert_binary_elementwise,

@@ -87,8 +86,9 @@ def get_shape_with_dynamic_shape(
     scale_res = scale_layer.get_output(0)
 
     length = input_shape.shape[0]
+
     zero_layer = ctx.net.add_constant(
-        input_shape.shape, to_numpy(torch.zeros((length), dtype=torch.int32))
+        input_shape.shape, np.zeros((length), dtype=np.int32)
     )
     set_layer_name(zero_layer, target, f"{name}_zeros")
 

py/torch_tensorrt/dynamo/lowering/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -3,6 +3,6 @@
     torch_enabled_decompositions,
 )
 from ._decompositions import get_decompositions  # noqa: F401
-from ._fusers import *  # noqa: F401
+from ._remove_sym_nodes import remove_sym_nodes
 from ._repair_input_aliasing import repair_input_aliasing
 from .passes import apply_lowering_passes

py/torch_tensorrt/dynamo/lowering/_fusers.py

Lines changed: 0 additions & 82 deletions

This file was deleted.

py/torch_tensorrt/dynamo/lowering/_remove_sym_nodes.py

Lines changed: 30 additions & 0 deletions

@@ -0,0 +1,30 @@
+import logging
+
+import torch
+
+logger = logging.getLogger(__name__)
+
+
+def remove_sym_nodes(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
+    """Remove sym_int placeholders which get inserted due to torch.compile's
+    dynamic=True behavior
+    """
+    # Extract SymInt placeholder Tensors
+    placeholders = [
+        node
+        for node in gm.graph.nodes
+        if (
+            node.op == "placeholder"
+            and isinstance(node.type, type)
+            and issubclass(node.type, torch.SymInt)
+        )
+    ]
+
+    for node in placeholders:
+        gm.graph.erase_node(node)
+
+    gm.graph.lint()
+    gm.recompile()
+    logger.debug(f"Removed SymInt placeholders:\n{gm.graph}")
+
+    return gm
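
A quick way to see remove_sym_nodes in action is to build a small FX graph by hand with one SymInt placeholder, mimicking what torch.compile's dynamic=True tracing produces (a synthetic sketch, assuming torch_tensorrt is installed):

import torch
from torch_tensorrt.dynamo.lowering import remove_sym_nodes

# One SymInt placeholder (as inserted by dynamic=True) plus one Tensor placeholder.
graph = torch.fx.Graph()
graph.placeholder("s0", type_expr=torch.SymInt)
x = graph.placeholder("x", type_expr=torch.Tensor)
graph.output(graph.call_function(torch.mul, (x, 2.0)))

gm = torch.fx.GraphModule(torch.nn.Module(), graph)
gm = remove_sym_nodes(gm)

# Only the tensor placeholder survives, so the recompiled forward takes one argument.
print([n.name for n in gm.graph.nodes if n.op == "placeholder"])  # ['x']
print(gm(torch.randn(2)))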
