Commit e4df382

feat: Add preliminary support for freezing tensors in Dynamo
1 parent 64ce49b commit e4df382

6 files changed: +254, -9 lines
Lines changed: 128 additions & 0 deletions
@@ -0,0 +1,128 @@
from typing import Callable, Dict, Optional

import torch
import torch.utils._pytree as pytree
from torch import nn
from torch._functorch.aot_autograd import (
    AOT_COUNTER,
    AOTConfig,
    create_aot_dispatcher_function,
    create_functional_call,
)
from torch._functorch.partitioners import default_partition
from torch._subclasses import FakeTensor


def aot_module(
    mod: nn.Module,
    args,
    fw_compiler: Callable,
    partition_fn: Callable = default_partition,
    decompositions: Optional[Dict] = None,
    keep_inference_input_mutations=False,
) -> nn.Module:
    """
    Adapted from:
    https://github.com/pytorch/pytorch/blob/cce2b7e3c95a7505b41bdfc53939d84d56e31260/torch/_functorch/aot_autograd.py#L3656-L3776

    This is the simplified or low overhead version of aot_module. For frontends
    like TorchDynamo, the input functions/modules to AOT are static and have
    unpacked inputs/outputs. This gives us an opportunity to remove the
    (1) pytree overhead to parse inputs/outputs,
    (2) AOT Autograd cache,
    (3) Reading of params/buffers in every forward call


    :func:`aot_module_simplified` removes these overheads.
    """

    params = {
        **dict(mod.named_parameters(remove_duplicate=False)),
        **dict(mod.named_buffers(remove_duplicate=False)),
    }
    params_flat, params_spec = pytree.tree_flatten(params)
    params_flat = list(params_flat)
    params_len = len(params_flat)

    functional_call = create_functional_call(mod, params_spec, params_len)

    seen_sources = set()

    full_args = []
    # First, the params
    full_args.extend(params_flat)

    if torch._guards.TracingContext.get():
        torch._guards.TracingContext.get().params_flat = params_flat

    aot_autograd_arg_pos_to_source = None
    # Then, the params 1:1 mapped sources, if relevant.
    if hasattr(mod, "_param_name_to_source"):
        aot_autograd_arg_pos_to_source = []
        # We now know this came from dynamo, and (1) we care about guards,
        # so setting up aot_autograd_arg_pos_to_source for downstream dedup guards
        # can now be done safely. (2) Dynamo logic protects the 1:1 sizing below.
        for name in params.keys():
            assert name in mod._param_name_to_source, f"{name} not found."
            source = mod._param_name_to_source[name]
            assert source not in seen_sources, source
            seen_sources.add(source)
            aot_autograd_arg_pos_to_source.append(source)

    # Next, the input args
    full_args.extend(args)

    if hasattr(mod, "graph"):
        # Non dynamo entrypoints can get to here...
        for i, node in enumerate(mod.graph.nodes):
            if node.op == "placeholder":
                if hasattr(node, "_dynamo_source"):
                    # ... but not here!
                    if aot_autograd_arg_pos_to_source is None:
                        aot_autograd_arg_pos_to_source = []
                    source = node._dynamo_source
                    assert source not in seen_sources, source
                    seen_sources.add(source)
                    aot_autograd_arg_pos_to_source.append(source)

    if aot_autograd_arg_pos_to_source is not None:
        assert len(full_args) == len(aot_autograd_arg_pos_to_source)

    dynamic_shapes = False
    for x in full_args:
        if isinstance(x, FakeTensor):
            dynamic_shapes = x.fake_mode.shape_env is not None
            break

    aot_config = AOTConfig(
        fw_compiler=fw_compiler,
        bw_compiler=fw_compiler,
        inference_compiler=fw_compiler,
        partition_fn=partition_fn,
        decompositions=decompositions,
        num_params_buffers=params_len,
        aot_id=next(AOT_COUNTER),
        keep_inference_input_mutations=keep_inference_input_mutations,
        dynamic_shapes=dynamic_shapes,
        aot_autograd_arg_pos_to_source=aot_autograd_arg_pos_to_source,
        is_export=False,
        no_tangents=False,
    )

    compiled_fn = create_aot_dispatcher_function(
        functional_call,
        full_args,
        aot_config,
    )

    def forward(*runtime_args):
        full_args = []
        full_args.extend(runtime_args)
        return compiled_fn(full_args)

    # Just for convenience
    forward.zero_grad = mod.zero_grad
    forward.named_parameters = mod.named_parameters
    forward.named_buffers = mod.named_buffers

    return forward
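
Note: unlike the upstream aot_module_simplified, the forward returned above does not re-pack params_flat on every call; instead the flattened parameters are published on TracingContext.params_flat so that the freezing pass further down in this commit can inline them as constants. As a minimal illustration of the helper's API, here is a hypothetical sketch (not part of the commit) using a parameter-free module and a simple boxed debug compiler:

import torch
from torch._functorch.aot_autograd import make_boxed_compiler


@make_boxed_compiler
def debug_compiler(gm: torch.fx.GraphModule, example_inputs):
    # Print the aten-level graph produced by AOTAutograd, then run it as-is.
    print(gm.graph)
    return gm


class PointwiseOps(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) * 2.0


# No parameters, so the returned forward can be called with the runtime args alone.
compiled = aot_module(PointwiseOps(), [torch.randn(8)], fw_compiler=debug_compiler)
print(compiled(torch.randn(8)))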

py/torch_tensorrt/dynamo/backend/backends.py

Lines changed: 18 additions & 6 deletions
@@ -6,13 +6,17 @@

 import torch
 import torch._dynamo as td
-from torch._functorch.aot_autograd import aot_module_simplified, make_boxed_compiler
+from torch._functorch.aot_autograd import make_boxed_compiler
+from torch._guards import TracingContext
 from torch_tensorrt.dynamo import CompilationSettings
 from torch_tensorrt.dynamo.compile import compile_module
 from torch_tensorrt.dynamo.lowering._decompositions import get_decompositions
+from torch_tensorrt.dynamo.lowering._freeze_aot_graph import freeze_autograd_gm
 from torch_tensorrt.dynamo.lowering._pre_aot_lowering import pre_aot_substitutions
 from torch_tensorrt.dynamo.utils import parse_dynamo_kwargs

+from .aot_module import aot_module
+
 logger = logging.getLogger(__name__)


@@ -33,8 +37,9 @@ def torch_tensorrt_backend(

     DEFAULT_BACKEND = aot_torch_tensorrt_aten_backend

-    compiled_mod: torch.nn.Module = DEFAULT_BACKEND(gm, sample_inputs, **kwargs)
-    return compiled_mod
+    TracingContext.get().fake_mode.allow_non_fake_inputs = True
+
+    return DEFAULT_BACKEND(gm, sample_inputs, **kwargs)


 @td.register_backend(name="aot_torch_tensorrt_aten")  # type: ignore[misc]
@@ -52,7 +57,7 @@ def aot_torch_tensorrt_aten_backend(
     gm = pre_aot_substitutions(gm)

     # Invoke AOTAutograd to translate operators to aten
-    return aot_module_simplified(
+    return aot_module(
         gm,
         sample_inputs,
         fw_compiler=make_boxed_compiler(custom_backend),
@@ -77,9 +82,16 @@ def _pretraced_backend(
     try:
         logger.debug("Post-AOT Autograd graph:\n" + str(gm.graph))

+        frozen_gm, unfrozen_indices = freeze_autograd_gm(gm, sample_inputs)
+        nonfrozen_inputs = [sample_inputs[idx] for idx in unfrozen_indices]
+
+        frozen_gm.graph.eliminate_dead_code()
+        frozen_gm.graph.lint()
+        frozen_gm.recompile()
+
         trt_compiled = compile_module(
-            gm,
-            sample_inputs,
+            frozen_gm,
+            nonfrozen_inputs,
             settings=settings,
         )
         return trt_compiled
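
For context, a hypothetical end-to-end sketch of exercising this path (not part of the commit). It assumes torch_tensorrt and TensorRT are installed, a CUDA device is available, and that torch_tensorrt_backend above is registered under the "torch_tensorrt" backend name (only the "aot_torch_tensorrt_aten" registration is visible in this hunk):

import torch
import torch_tensorrt  # noqa: F401  importing registers the Dynamo backends

model = torch.nn.Sequential(torch.nn.Linear(32, 64), torch.nn.ReLU()).cuda().eval()
inputs = [torch.randn(8, 32).cuda()]

# Dynamo traces the module, aot_module lowers it to aten, freeze_autograd_gm
# inlines the Linear weight/bias as constants, and compile_module builds the
# TensorRT engine over the frozen graph with only the activations as inputs.
compiled = torch.compile(model, backend="torch_tensorrt")
out = compiled(*inputs)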

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

Lines changed: 29 additions & 2 deletions
@@ -3,14 +3,15 @@
 from datetime import datetime
 from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Set

-import numpy
+import numpy as np

 # @manual=//deeplearning/trt/python:py_tensorrt
 import tensorrt as trt
 import torch
 import torch.fx
 from torch.fx.node import _get_qualified_name
 from torch.fx.passes.shape_prop import TensorMetadata
+from torch.utils._python_dispatch import _disable_current_modes
 from torch_tensorrt._Input import Input
 from torch_tensorrt.dynamo.conversion.converter_utils import get_node_name
 from torch_tensorrt.fx.observer import Observer
@@ -169,7 +170,7 @@ def run(

         cache = None
         if timing_cache:
-            cache_file = numpy.array(timing_cache)
+            cache_file = np.array(timing_cache)
             cache = builder_config.create_timing_cache(cache_file.tobytes())
         else:
             cache = builder_config.create_timing_cache(b"")
@@ -323,6 +324,21 @@ def call_function(self, target: str, args: Any, kwargs: Any) -> Any:
         assert self._cur_node_name is not None
         return converter(self.network, target, args, kwargs, self._cur_node_name)

+    def get_attr(self, target: str, args: Any, kwargs: Any) -> np.ndarray:
+        with _disable_current_modes():
+            from torch_tensorrt.fx.converters import to_numpy
+
+            frozen_attr = self.fetch_attr(target)
+
+            if isinstance(frozen_attr, torch.nn.Parameter):
+                constant_tensor = frozen_attr.data
+            else:
+                constant_tensor = frozen_attr
+
+            network_constant = to_numpy(constant_tensor)
+
+            return network_constant
+
     def call_method(self, target: str, args: Any, kwargs: Any) -> Any:
         assert isinstance(target, str)
         converter = CONVERTERS.get(self._cur_node)
@@ -344,6 +360,17 @@ def output(self, target: str, args: Any, kwargs: Any) -> List[Any]:
         else:
             outputs = (args[0],)

+        for output_idx in range(len(outputs)):
+            from torch_tensorrt.fx.converters import get_trt_tensor
+
+            output = outputs[output_idx]
+
+            if not isinstance(output, trt.tensorrt.ITensor):
+                new_output = get_trt_tensor(self.network, output, target)
+                outputs = (
+                    outputs[:output_idx] + (new_output,) + outputs[output_idx + 1 :]
+                )
+
         if not all(isinstance(output, trt.tensorrt.ITensor) for output in outputs):
             raise RuntimeError("TensorRT requires all outputs to be Tensor!")


py/torch_tensorrt/dynamo/lowering/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -1,4 +1,5 @@
 from ._decompositions import get_decompositions  # noqa: F401
+from ._freeze_aot_graph import *  # noqa: F401
 from ._fusers import *  # noqa: F401
 from ._pre_aot_lowering import SUBSTITUTION_REGISTRY  # noqa: F401
 from ._pre_aot_lowering import register_substitution  # noqa: F401

Lines changed: 74 additions & 0 deletions
@@ -0,0 +1,74 @@
import unittest
from typing import List, Sequence, Tuple

import torch
import torch.fx.traceback as fx_traceback
from torch._dynamo.utils import detect_fake_mode
from torch._functorch.compile_utils import fx_graph_cse
from torch._inductor.compile_fx import fake_tensor_prop
from torch._inductor.freezing import constant_fold, replace_params_with_constants
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.passes.tools_common import legalize_graph


def freeze_autograd_gm(
    aot_autograd_gm: torch.fx.GraphModule,
    example_inputs: Sequence[torch._subclasses.FakeTensor],
) -> Tuple[torch.fx.GraphModule, List[int]]:
    """
    Adapted from:
    https://github.com/pytorch/pytorch/blob/750b9b359f06cb8b8c2d5b6118bba636e2112cbb/torch/_inductor/freezing.py#L186-L243

    Inlines parameters that are not mutated into constants and optimizes the graph
    through constant propagation and other techniques. If enabled, the function also
    discards the original parameters of the module for memory efficiency.

    Assumes that this function is run in dynamo tracing post aot_autograd.

    Args:
        aot_autograd_gm (torch.fx.GraphModule): The aot_autograd-constructed GraphModule to be frozen.
        example_inputs (List[torch.Tensor]): A list of example input tensors to be used in the freezing process.

    Returns:
        Tuple[torch.fx.GraphModule, List[int]]: A tuple containing the frozen GraphModule and a list of indices
        of the inputs that were preserved (not turned into constants).
    """
    # Extract necessary metadata and parameters
    fw_metadata = torch._guards.TracingContext.get().fw_metadata
    params_flat = torch._guards.TracingContext.get().params_flat
    assert fw_metadata is not None and params_flat is not None

    # Replace placeholders with get_attr nodes
    preserved_arg_indices = replace_params_with_constants(
        aot_autograd_gm, params_flat, fw_metadata
    )

    constant_fold(aot_autograd_gm)

    fake_mode = detect_fake_mode(example_inputs)

    # constant params will be real tensors, not fake
    # TODO: fake_mode should enable the py dispatcher if it is symbolic?
    with unittest.mock.patch.object(
        fake_mode, "allow_non_fake_inputs", True
    ), fake_mode:
        args = [e for i, e in enumerate(example_inputs) if i in preserved_arg_indices]
        with fx_traceback.preserve_node_meta():
            aot_autograd_gm = make_fx(aot_autograd_gm, _allow_non_fake_inputs=True)(
                *args
            )

    # TODO - further restrict cse? right now it is needed to dedup aliasing ops
    cse_graph = fx_graph_cse(aot_autograd_gm.graph)
    aot_autograd_gm.graph = cse_graph
    aot_autograd_gm.recompile()

    # Make sure meta['val'] is properly set up (weight conversion
    # or decompose_unfused_batchnorms lost meta['val']).
    aot_example_inputs = [example_inputs[ind] for ind in preserved_arg_indices]
    fake_tensor_prop(aot_autograd_gm, aot_example_inputs, True)

    # TODO - apply legalization in pattern matcher
    legalize_graph(aot_autograd_gm)
    constant_fold(aot_autograd_gm)

    return aot_autograd_gm, preserved_arg_indices
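
For reference, a small sketch (not part of the commit, names illustrative) of the contract this helper exposes, mirroring how _pretraced_backend consumes it above; it assumes an active TracingContext whose fw_metadata and params_flat were populated during AOT tracing:

# Inside the Dynamo backend, after AOTAutograd has produced aot_gm:
frozen_gm, preserved_arg_indices = freeze_autograd_gm(aot_gm, example_inputs)

# Parameters and buffers are no longer graph placeholders; only the preserved
# indices (the real runtime inputs) are forwarded to TensorRT conversion.
nonfrozen_inputs = [example_inputs[i] for i in preserved_arg_indices]

# The frozen values now live on the GraphModule as attributes reached through
# get_attr nodes, which the interpreter's new get_attr handler converts to
# numpy constants for the TensorRT network.
num_placeholders = sum(1 for n in frozen_gm.graph.nodes if n.op == "placeholder")
assert num_placeholders == len(preserved_arg_indices)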

py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py

Lines changed: 4 additions & 1 deletion
@@ -121,7 +121,10 @@ def is_node_supported(
     ) -> bool:
         node_name = ConverterRegistry.qualified_name_or_str(node.target)

-        if node in CONVERTERS and node_name not in self.torch_executed_ops:
+        if (
+            node.target in CONVERTERS.keys()
+            or (node.op == "get_attr" and "constant" in node_name)
+        ) and node_name not in self.torch_executed_ops:
             # If node is a proper, supported computational node, store the operator
             if not node.is_impure():
                 if node_name not in self.supported_operators:
