
Commit 3efcea0

fix: Add constant folding utility to freezing
1 parent ef470da commit 3efcea0

5 files changed, 29 insertions(+), 8 deletions(-)

py/torch_tensorrt/dynamo/backend/backends.py

Lines changed: 25 additions & 0 deletions
@@ -8,6 +8,7 @@
 import torch._dynamo as td
 from torch._dynamo.utils import detect_fake_mode
 from torch._functorch.aot_autograd import aot_export_joint_simple
+from torch._inductor.freezing import ConstantFolder, replace_node_with_constant
 from torch_tensorrt.dynamo import CompilationSettings
 from torch_tensorrt.dynamo.compile import compile_module
 from torch_tensorrt.dynamo.lowering._decompositions import get_decompositions
@@ -62,6 +63,8 @@ def aot_torch_tensorrt_aten_backend(
         ),
     )
 
+    constant_fold(graph_module)
+
     return _pretraced_backend(graph_module, sample_inputs, settings)
 
 
@@ -105,3 +108,25 @@ def _pretraced_backend(
                 + "specify pass_through_build_failures=False."
             )
             raise
+
+
+@torch.utils._python_dispatch._disable_current_modes()  # type: ignore
+def constant_fold(gm: torch.fx.GraphModule) -> Any:
+    cf = ConstantFolder(gm, skip_constructors=False)
+    cf.run()
+
+    for node, constant in cf.node_replacements.items():
+        replace_node_with_constant(gm, node, constant)
+
+    erased_params = []
+    for node in gm.graph.nodes:
+        if node.op == "get_attr" and len(node.users) == 0:
+            delattr(gm, node.target)
+            erased_params.append(node)
+
+    for node in erased_params:
+        gm.graph.erase_node(node)
+
+    gm.graph.eliminate_dead_code()
+    gm.graph.lint()
+    gm.recompile()
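
For illustration only, a minimal sketch of what the new constant_fold pass does to a traced graph, assuming ConstantFolder can evaluate the traced ops; ToyModel and its shapes are hypothetical and not part of this commit:

import torch
import torch.fx

class ToyModel(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # This subtree depends only on constants, so it can be
        # precomputed once at freeze time.
        scale = torch.ones(3) * 2.0
        return x * scale

gm = torch.fx.symbolic_trace(ToyModel())
constant_fold(gm)  # the utility added in backends.py above
# After folding, the ones/mul subtree is replaced by a single get_attr
# node holding the precomputed tensor [2.0, 2.0, 2.0]; unused attributes
# are erased and the graph is recompiled.
print(gm.graph)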

py/torch_tensorrt/dynamo/lowering/__init__.py

Lines changed: 0 additions & 1 deletion
@@ -1,5 +1,4 @@
 from ._decompositions import get_decompositions  # noqa: F401
-from ._freeze_aot_graph import *  # noqa: F401
 from ._fusers import *  # noqa: F401
 from ._pre_aot_lowering import SUBSTITUTION_REGISTRY  # noqa: F401
 from ._pre_aot_lowering import register_substitution  # noqa: F401

py/torch_tensorrt/dynamo/lowering/_pre_aot_lowering.py

Lines changed: 0 additions & 4 deletions
@@ -81,10 +81,6 @@ def pre_aot_substitutions(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
     """
     logger.debug("Pre-module replacement graph:\n" + str(gm.graph))
 
-    # Ensure all parameters are in inference mode
-    for param in gm.parameters():
-        param.requires_grad = False
-
     # Iterate over graph nodes, extracting module calls, to check for interceptions
     for n in gm.graph.nodes:
         exists_in_registry = False

py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py

Lines changed: 3 additions & 1 deletion
@@ -39,7 +39,9 @@ def is_node_supported(
     ) -> bool:
         node_name = ConverterRegistry.qualified_name_or_str(node.target)
 
-        if node in CONVERTERS and node_name not in self.torch_executed_ops:
+        if (
+            node in CONVERTERS or (node.op == "get_attr" and "constant" in node_name)
+        ) and node_name not in self.torch_executed_ops:
             # If node is a proper, supported computational node, store the operator
             if not node.is_impure():
                 if node_name not in self.supported_operators:
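
As a stand-in sketch (not the commit's code), the updated predicate can be read in isolation like this; is_supported, converters, and torch_executed_ops are hypothetical placeholders for the real ConverterRegistry state:

from typing import Set
import torch.fx

def is_supported(
    node: torch.fx.Node, converters: Set[str], torch_executed_ops: Set[str]
) -> bool:
    name = str(node.target)
    # Constants produced by folding surface as get_attr nodes; treating any
    # get_attr whose name mentions "constant" as supported lets them stay
    # inside TensorRT-bound subgraphs instead of forcing a partition break.
    is_folded_constant = node.op == "get_attr" and "constant" in name
    has_converter = name in converters
    return (has_converter or is_folded_constant) and name not in torch_executed_ops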

py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py

Lines changed: 1 addition & 2 deletions
@@ -122,8 +122,7 @@ def is_node_supported(
         node_name = ConverterRegistry.qualified_name_or_str(node.target)
 
         if (
-            node.target in CONVERTERS.keys()
-            or (node.op == "get_attr" and "constant" in node_name)
+            node in CONVERTERS or (node.op == "get_attr" and "constant" in node_name)
         ) and node_name not in self.torch_executed_ops:
             # If node is a proper, supported computational node, store the operator
             if not node.is_impure():
