Skip to content

Commit 4b44ff2

Browse files
committed
fix: Add constant folding utility to freezing
1 parent 399f929 commit 4b44ff2

File tree

5 files changed

+29
-8
lines changed

5 files changed

+29
-8
lines changed

py/torch_tensorrt/dynamo/backend/backends.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import torch._dynamo as td
99
from torch._dynamo.utils import detect_fake_mode
1010
from torch._functorch.aot_autograd import aot_export_joint_simple
11+
from torch._inductor.freezing import ConstantFolder, replace_node_with_constant
1112
from torch_tensorrt.dynamo import CompilationSettings
1213
from torch_tensorrt.dynamo.compile import compile_module
1314
from torch_tensorrt.dynamo.lowering._decompositions import get_decompositions
@@ -62,6 +63,8 @@ def aot_torch_tensorrt_aten_backend(
6263
),
6364
)
6465

66+
constant_fold(graph_module)
67+
6568
return _pretraced_backend(graph_module, sample_inputs, settings)
6669

6770

@@ -105,3 +108,25 @@ def _pretraced_backend(
105108
+ "specify pass_through_build_failures=False."
106109
)
107110
raise
111+
112+
113+
@torch.utils._python_dispatch._disable_current_modes()  # type: ignore
def constant_fold(gm: torch.fx.GraphModule) -> None:
    """Fold constant subgraphs of ``gm`` in place and prune dead parameters.

    Runs Inductor's ``ConstantFolder`` over the graph, replaces each
    foldable node with a precomputed constant attribute, then removes
    ``get_attr`` nodes (and the attributes backing them) that are left
    with no users. Mutates ``gm`` and returns ``None``.

    Args:
        gm: The FX graph module to fold; modified in place.
    """
    cf = ConstantFolder(gm, skip_constructors=False)
    cf.run()

    for node, constant in cf.node_replacements.items():
        replace_node_with_constant(gm, node, constant)

    erased_params = []
    for node in gm.graph.nodes:
        if node.op == "get_attr" and len(node.users) == 0:
            # get_attr targets may be dotted (e.g. "submod.weight");
            # a bare delattr(gm, node.target) would raise AttributeError,
            # so walk to the owning submodule before deleting. The hasattr
            # guard tolerates multiple dead nodes sharing one target.
            *owner_path, attr_name = node.target.split(".")
            owner: torch.nn.Module = gm
            for segment in owner_path:
                owner = getattr(owner, segment)
            if hasattr(owner, attr_name):
                delattr(owner, attr_name)
            erased_params.append(node)

    # Erase after the scan so the node list is not mutated mid-iteration.
    for node in erased_params:
        gm.graph.erase_node(node)

    gm.graph.eliminate_dead_code()
    gm.graph.lint()
    gm.recompile()

py/torch_tensorrt/dynamo/lowering/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
from ._decompositions import get_decompositions # noqa: F401
2-
from ._freeze_aot_graph import * # noqa: F401
32
from ._fusers import * # noqa: F401
43
from ._pre_aot_lowering import SUBSTITUTION_REGISTRY # noqa: F401
54
from ._pre_aot_lowering import register_substitution # noqa: F401

py/torch_tensorrt/dynamo/lowering/_pre_aot_lowering.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -81,10 +81,6 @@ def pre_aot_substitutions(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
8181
"""
8282
logger.debug("Pre-module replacement graph:\n" + str(gm.graph))
8383

84-
# Ensure all parameters are in inference mode
85-
for param in gm.parameters():
86-
param.requires_grad = False
87-
8884
# Iterate over graph nodes, extracting module calls, to check for interceptions
8985
for n in gm.graph.nodes:
9086
exists_in_registry = False

py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,9 @@ def is_node_supported(
4343
) -> bool:
4444
node_name = ConverterRegistry.qualified_name_or_str(node.target)
4545

46-
if node in CONVERTERS and node_name not in self.torch_executed_ops:
46+
if (
47+
node in CONVERTERS or (node.op == "get_attr" and "constant" in node_name)
48+
) and node_name not in self.torch_executed_ops:
4749
# If node is a proper, supported computational node, store the operator
4850
if not node.is_impure():
4951
if node_name not in self.supported_operators:

py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -154,8 +154,7 @@ def is_node_supported(
154154
node_name = ConverterRegistry.qualified_name_or_str(node.target)
155155

156156
if (
157-
node.target in CONVERTERS.keys()
158-
or (node.op == "get_attr" and "constant" in node_name)
157+
node in CONVERTERS or (node.op == "get_attr" and "constant" in node_name)
159158
) and node_name not in self.torch_executed_ops:
160159
# If node is a proper, supported computational node, store the operator
161160
if not node.is_impure():

0 commit comments

Comments
 (0)