Skip to content

Commit 78f000f

Browse files
committed
fix: Refactor tensor freezing in Dynamo
1 parent beecf35 commit 78f000f

File tree

4 files changed

+21
-228
lines changed

4 files changed

+21
-228
lines changed

py/torch_tensorrt/dynamo/backend/aot_module.py

Lines changed: 0 additions & 127 deletions
This file was deleted.

py/torch_tensorrt/dynamo/backend/backends.py

Lines changed: 21 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import logging
22
from typing import Sequence
33
import torch
4-
from functools import partial
4+
from torch._dynamo.utils import detect_fake_mode
5+
import unittest
56
import torch._dynamo as td
67
from torch._guards import TracingContext
78

@@ -16,15 +17,13 @@
1617
partition,
1718
get_submod_inputs,
1819
)
19-
from torch_tensorrt.dynamo.lowering._freeze_aot_graph import freeze_autograd_gm
2020
from torch_tensorrt.dynamo.utils import parse_dynamo_kwargs
2121
from torch_tensorrt.dynamo.conversion import (
2222
convert_module,
2323
repair_long_or_double_inputs,
2424
)
2525

26-
from torch._functorch.aot_autograd import make_boxed_compiler
27-
from .aot_module import aot_module
26+
from torch._functorch.aot_autograd import aot_export_joint_simple
2827

2928

3029
logger = logging.getLogger(__name__)
@@ -36,8 +35,6 @@ def torch_tensorrt_backend(
3635
):
3736
DEFAULT_BACKEND = aot_torch_tensorrt_aten_backend
3837

39-
TracingContext.get().fake_mode.allow_non_fake_inputs = True
40-
4138
return DEFAULT_BACKEND(gm, sample_inputs, **kwargs)
4239

4340

@@ -47,21 +44,25 @@ def aot_torch_tensorrt_aten_backend(
4744
):
4845
settings = parse_dynamo_kwargs(kwargs)
4946

50-
custom_backend = partial(
51-
_pretraced_backend,
52-
settings=settings,
53-
)
54-
5547
# Perform Pre-AOT Lowering for Module-Level Replacement
5648
gm = pre_aot_substitutions(gm)
5749

58-
# Invoke AOTAutograd to translate operators to aten
59-
return aot_module(
60-
gm,
61-
sample_inputs,
62-
fw_compiler=make_boxed_compiler(custom_backend),
63-
decompositions=get_decompositions(),
64-
)
50+
fake_mode = detect_fake_mode(sample_inputs)
51+
52+
# Place backend tracing within FakeTensor context allowing nonfake Tensors
53+
with unittest.mock.patch.object(
54+
fake_mode, "allow_non_fake_inputs", True
55+
), fake_mode:
56+
57+
# Invoke AOTAutograd to translate operators to aten
58+
graph_module = aot_export_joint_simple(
59+
gm,
60+
sample_inputs,
61+
trace_joint=False,
62+
decompositions=get_decompositions(),
63+
)
64+
65+
return _pretraced_backend(graph_module, sample_inputs, settings)
6566

6667

6768
def _pretraced_backend(
@@ -81,16 +82,9 @@ def _pretraced_backend(
8182
try:
8283
logger.debug("Post-AOT Autograd graph:\n" + str(gm.graph))
8384

84-
frozen_gm, unfrozen_indices = freeze_autograd_gm(gm, sample_inputs)
85-
nonfrozen_inputs = [sample_inputs[idx] for idx in unfrozen_indices]
86-
87-
frozen_gm.graph.eliminate_dead_code()
88-
frozen_gm.graph.lint()
89-
frozen_gm.recompile()
90-
9185
trt_compiled = _compile_module(
92-
frozen_gm,
93-
nonfrozen_inputs,
86+
gm,
87+
sample_inputs,
9488
settings=settings,
9589
)
9690
return trt_compiled

py/torch_tensorrt/dynamo/lowering/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,3 @@
88
from ._partition import partition, get_submod_inputs, DEFAULT_SINGLE_NODE_PARTITIONS
99
from .substitutions import *
1010
from ._fusers import *
11-
from ._freeze_aot_graph import *

py/torch_tensorrt/dynamo/lowering/_freeze_aot_graph.py

Lines changed: 0 additions & 73 deletions
This file was deleted.

0 commit comments

Comments
 (0)