Skip to content

Commit d022f4a

Browse files
committed
fix: Refactor tensor freezing in Dynamo
1 parent e4df382 commit d022f4a

File tree

3 files changed

+22
-231
lines changed

3 files changed

+22
-231
lines changed

py/torch_tensorrt/dynamo/backend/aot_module.py

Lines changed: 0 additions & 128 deletions
This file was deleted.

py/torch_tensorrt/dynamo/backend/backends.py

Lines changed: 22 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,19 @@
11
from __future__ import annotations
22

33
import logging
4-
from functools import partial
4+
import unittest
55
from typing import Any, Callable, Sequence
66

77
import torch
88
import torch._dynamo as td
9-
from torch._functorch.aot_autograd import make_boxed_compiler
10-
from torch._guards import TracingContext
9+
from torch._dynamo.utils import detect_fake_mode
10+
from torch._functorch.aot_autograd import aot_export_joint_simple
1111
from torch_tensorrt.dynamo import CompilationSettings
1212
from torch_tensorrt.dynamo.compile import compile_module
1313
from torch_tensorrt.dynamo.lowering._decompositions import get_decompositions
14-
from torch_tensorrt.dynamo.lowering._freeze_aot_graph import freeze_autograd_gm
1514
from torch_tensorrt.dynamo.lowering._pre_aot_lowering import pre_aot_substitutions
1615
from torch_tensorrt.dynamo.utils import parse_dynamo_kwargs
1716

18-
from .aot_module import aot_module
19-
2017
logger = logging.getLogger(__name__)
2118

2219

@@ -37,8 +34,6 @@ def torch_tensorrt_backend(
3734

3835
DEFAULT_BACKEND = aot_torch_tensorrt_aten_backend
3936

40-
TracingContext.get().fake_mode.allow_non_fake_inputs = True
41-
4237
return DEFAULT_BACKEND(gm, sample_inputs, **kwargs)
4338

4439

@@ -48,21 +43,26 @@ def aot_torch_tensorrt_aten_backend(
4843
) -> torch.nn.Module:
4944
settings = parse_dynamo_kwargs(kwargs)
5045

51-
custom_backend = partial(
52-
_pretraced_backend,
53-
settings=settings,
54-
)
55-
5646
# Perform Pre-AOT Lowering for Module-Level Replacement
5747
gm = pre_aot_substitutions(gm)
5848

59-
# Invoke AOTAutograd to translate operators to aten
60-
return aot_module(
61-
gm,
62-
sample_inputs,
63-
fw_compiler=make_boxed_compiler(custom_backend),
64-
decompositions=get_decompositions(settings.enable_experimental_decompositions),
65-
)
49+
fake_mode = detect_fake_mode(sample_inputs)
50+
51+
# Place backend tracing within FakeTensor context allowing nonfake Tensors
52+
with unittest.mock.patch.object(
53+
fake_mode, "allow_non_fake_inputs", True
54+
), fake_mode:
55+
# Invoke AOTAutograd to translate operators to aten
56+
graph_module = aot_export_joint_simple(
57+
gm,
58+
sample_inputs,
59+
trace_joint=False,
60+
decompositions=get_decompositions(
61+
settings.enable_experimental_decompositions
62+
),
63+
)
64+
65+
return _pretraced_backend(graph_module, sample_inputs, settings)
6666

6767

6868
def _pretraced_backend(
@@ -82,16 +82,9 @@ def _pretraced_backend(
8282
try:
8383
logger.debug("Post-AOT Autograd graph:\n" + str(gm.graph))
8484

85-
frozen_gm, unfrozen_indices = freeze_autograd_gm(gm, sample_inputs)
86-
nonfrozen_inputs = [sample_inputs[idx] for idx in unfrozen_indices]
87-
88-
frozen_gm.graph.eliminate_dead_code()
89-
frozen_gm.graph.lint()
90-
frozen_gm.recompile()
91-
9285
trt_compiled = compile_module(
93-
frozen_gm,
94-
nonfrozen_inputs,
86+
gm,
87+
sample_inputs,
9588
settings=settings,
9689
)
9790
return trt_compiled

py/torch_tensorrt/dynamo/lowering/_freeze_aot_graph.py

Lines changed: 0 additions & 74 deletions
This file was deleted.

0 commit comments

Comments
 (0)