feat: cherry-pick of torch.compile dynamic shapes #2750

Merged
merged 5 commits into from Apr 26, 2024

2 changes: 2 additions & 0 deletions .github/workflows/build-test.yml
@@ -21,6 +21,7 @@ jobs:
os: linux
test-infra-repository: pytorch/test-infra
test-infra-ref: main
channel: test
with-rocm: false
with-cpu: false

@@ -208,6 +209,7 @@ jobs:
${CONDA_RUN} python -m pip install --pre pytest-xdist timm transformers parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver
${CONDA_RUN} python -m pytest -n 10 --junitxml=${RUNNER_TEST_RESULTS_DIR}/torch_compile_be_test_results.xml backend/
${CONDA_RUN} python -m pytest -n 4 --junitxml=${RUNNER_TEST_RESULTS_DIR}/torch_comple_be_e2e_test_results.xml --ir torch_compile models/test_models.py
${CONDA_RUN} python -m pytest --junitxml=${RUNNER_TEST_RESULTS_DIR}/torch_compile_dyn_models_export.xml --ir torch_compile models/test_dyn_models.py
popd

tests-py-dynamo-core:
12 changes: 5 additions & 7 deletions py/torch_tensorrt/dynamo/_compiler.py
@@ -273,14 +273,12 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
return False
return True

# Check if the module has metadata (shape, dtype). If not, run symbolic shape propagation.
# Check if the module has metadata (shape, dtype).
if not contains_metadata(gm):
from torch._inductor.compile_fx import fake_tensor_prop

torch_inputs = get_torch_inputs(sample_inputs, settings.device)
with torch.no_grad():
# This fails if the module has data-dependent shape operators.
fake_tensor_prop(gm, torch_inputs)
# TODO: For future, explore when nodes don't have metadata and if fake_tensor_prop can resolve this.
logger.warning(
"Some nodes do not have metadata (shape and dtype information). This could lead to problems sometimes if the graph has PyTorch and TensorRT segments."
)

# Partition module into components that can be TRT-accelerated
fast_partitioner_failed = False
20 changes: 14 additions & 6 deletions py/torch_tensorrt/dynamo/backend/backends.py
@@ -13,6 +13,7 @@
from torch_tensorrt.dynamo.lowering import (
apply_lowering_passes,
get_decompositions,
remove_sym_nodes,
repair_input_aliasing,
)
from torch_tensorrt.dynamo.utils import (
@@ -27,7 +28,7 @@
@td.register_backend(name="tensorrt") # type: ignore[misc]
@td.register_backend(name="torch_tensorrt") # type: ignore[misc]
def torch_tensorrt_backend(
gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor], **kwargs: Any
gm: torch.fx.GraphModule, sample_inputs: Sequence[Any], **kwargs: Any
) -> torch.nn.Module:
# Set log level at the top of compilation (torch_tensorrt.dynamo)
if (
@@ -44,15 +45,15 @@ def torch_tensorrt_backend(

@td.register_backend(name="aot_torch_tensorrt_aten") # type: ignore[misc]
def aot_torch_tensorrt_aten_backend(
gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor], **kwargs: Any
gm: torch.fx.GraphModule, sample_inputs: Sequence[Any], **kwargs: Any
) -> torch.nn.Module:
settings = parse_dynamo_kwargs(kwargs)
return _pretraced_backend(gm, sample_inputs, settings)


def _pretraced_backend(
gm: torch.fx.GraphModule,
sample_inputs: Sequence[torch.Tensor],
sample_inputs: Sequence[Any],
settings: CompilationSettings = CompilationSettings(),
) -> torch.fx.GraphModule | Callable[..., Any]:
"""Helper function to manage translation of traced FX module to TRT engines
@@ -74,10 +75,17 @@ def _pretraced_backend(
fake_mode, "allow_non_fake_inputs", True
), fake_mode:
repair_input_aliasing(gm)

# Remove sym_int placeholders and inputs
remove_sym_nodes(gm)
torch_inputs = [
input for input in sample_inputs if isinstance(input, torch.Tensor)
]

# Invoke AOTAutograd to translate operators to aten
gm = aot_export_joint_simple(
gm,
sample_inputs,
torch_inputs,
trace_joint=False,
decompositions=get_decompositions(
settings.enable_experimental_decompositions
@@ -86,10 +94,10 @@

logger.debug("Post-AOT Autograd graph:\n" + str(gm.graph))

gm = apply_lowering_passes(gm, sample_inputs)
gm = apply_lowering_passes(gm, torch_inputs)

torchtrt_inputs = prepare_inputs(
sample_inputs, disable_memory_format_check=True
torch_inputs, disable_memory_format_check=True
)
trt_compiled = compile_module(
gm,
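
For reference, a minimal sketch of how this backend path is exercised end to end (illustrative only: the toy model, device, and input shapes are assumptions; the "torch_tensorrt" backend name and the dynamic=True behavior come from the changes above):

import torch
import torch_tensorrt  # noqa: F401  (assumption: importing the package exposes the backends registered above)

model = torch.nn.Linear(16, 32).eval().cuda()

# dynamic=True makes torch.compile trace with symbolic shapes, which is the case
# the SymInt handling above (remove_sym_nodes plus Tensor-only inputs) is meant to cover.
compiled = torch.compile(model, backend="torch_tensorrt", dynamic=True)

compiled(torch.randn(4, 16).cuda())
compiled(torch.randn(8, 16).cuda())  # a different batch size exercises the dynamic-shape path
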
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/shape.py
@@ -104,6 +104,7 @@ def get_shape_with_dynamic_shape(
scale_res = scale_layer.get_output(0)

length = input_shape.shape[0]

zero_layer = ctx.net.add_constant(
input_shape.shape, np.zeros((length), dtype=np.int32)
)
2 changes: 1 addition & 1 deletion py/torch_tensorrt/dynamo/lowering/__init__.py
@@ -3,6 +3,6 @@
torch_enabled_decompositions,
)
from ._decompositions import get_decompositions # noqa: F401
from ._fusers import * # noqa: F401
from ._remove_sym_nodes import remove_sym_nodes
from ._repair_input_aliasing import repair_input_aliasing
from .passes import apply_lowering_passes
82 changes: 0 additions & 82 deletions py/torch_tensorrt/dynamo/lowering/_fusers.py

This file was deleted.

30 changes: 30 additions & 0 deletions py/torch_tensorrt/dynamo/lowering/_remove_sym_nodes.py
@@ -0,0 +1,30 @@
import logging

import torch

logger = logging.getLogger(__name__)


def remove_sym_nodes(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
"""Remove sym_int placeholders which get inserted due to torch.compile's
dynamic=True behavior
"""
# Extract SymInt placeholder Tensors
placeholders = [
node
for node in gm.graph.nodes
if (
node.op == "placeholder"
and isinstance(node.type, type)
and issubclass(node.type, torch.SymInt)
)
]

for node in placeholders:
gm.graph.erase_node(node)

gm.graph.lint()
gm.recompile()
logger.debug(f"Removed SymInt placeholders:\n{gm.graph}")

return gm
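
As a rough illustration (not part of the change; the toy function and the debugging backend below are assumptions), the placeholders this pass targets can be observed with a plain no-op backend:

import torch

def inspect_backend(gm: torch.fx.GraphModule, example_inputs):
    # Mirrors the check in remove_sym_nodes: with dynamic=True, torch.compile may lift
    # symbolic sizes into the graph as placeholders typed torch.SymInt.
    sym_placeholders = [
        node.name
        for node in gm.graph.nodes
        if node.op == "placeholder"
        and isinstance(node.type, type)
        and issubclass(node.type, torch.SymInt)
    ]
    print("SymInt placeholders:", sym_placeholders)
    return gm.forward  # run the captured graph unchanged

def fn(x):
    return torch.relu(x) + x.shape[0]

compiled = torch.compile(fn, backend=inspect_backend, dynamic=True)
compiled(torch.randn(4, 8))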