
Commit 2759a7e

Propagate buffer and parameter indices through AOT
ghstack-source-id: 90b6a05
Pull Request resolved: #130393

1 parent 2021eb8 · commit 2759a7e

6 files changed: +23 −19 lines changed


torch/_functorch/_aot_autograd/collect_metadata_analysis.py

Lines changed: 4 additions & 14 deletions
@@ -11,7 +11,7 @@
 import collections
 import logging
 from functools import wraps
-from typing import Callable, DefaultDict, Dict, List
+from typing import Callable, DefaultDict, Dict, List, Optional
 
 import torch
 import torch.utils._pytree as pytree
@@ -25,6 +25,7 @@
     is_traceable_wrapper_subclass,
     transform_subclass,
 )
+
 from .functional_utils import (
     are_all_mutations_hidden_from_autograd,
     are_all_mutations_under_no_grad_or_inference_mode,
@@ -124,6 +125,7 @@ def run_functionalized_fw_and_collect_metadata(
     keep_input_mutations: bool,
     # TODO: refactor to kill this flag
     is_train: bool = False,
+    static_input_indices: Optional[List[int]] = None,
     pre_dispatch: bool = False,
 ) -> Callable[..., ViewAndMutationMeta]:
     memo: Dict[Tensor, Tensor] = {}
@@ -666,18 +668,6 @@ def view_avoid_dupes_with_primals(t):
         )
         user_outs = pytree.tree_map(from_fun, f_output_tangents)
 
-        if (
-            torch._dynamo.config.inline_inbuilt_nn_modules
-            or torch._dynamo.compiled_autograd.in_compiled_autograd_region
-        ):
-            static_parameter_input_indices = [
-                i
-                for i, arg in enumerate(flat_args)
-                if isinstance(arg, torch.nn.Parameter)
-            ]
-        else:
-            static_parameter_input_indices = []
-
         f_mutated_inputs = [
             inp
             for inp, info in zip(flat_f_args, input_info)
@@ -729,7 +719,7 @@ def view_avoid_dupes_with_primals(t):
             subclass_tangent_meta=create_subclass_meta(traced_tangents),
             is_train=is_train,
             grad_enabled_mutation=grad_enabled_mutation,
-            static_parameter_indices=static_parameter_input_indices,
+            static_input_indices=static_input_indices if static_input_indices else [],
             tokens=mode._tokens,
         )
         return metadata
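
Aside (not part of the commit): the hunks above delete the parameter-detection logic from metadata collection, so the caller is now expected to compute the index list and pass it in as static_input_indices. Below is a minimal, self-contained sketch of what the removed detection did; the flat_args list is a toy stand-in, not code from this file.

import torch

# Previously, static indices were inferred inside metadata collection by
# scanning the flattened arguments for nn.Parameter (only when inlining nn
# modules or inside a compiled-autograd region). Toy example:
flat_args = [torch.nn.Parameter(torch.randn(2)), torch.randn(2)]
static_parameter_input_indices = [
    i for i, arg in enumerate(flat_args) if isinstance(arg, torch.nn.Parameter)
]
print(static_parameter_input_indices)  # [0]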

torch/_functorch/_aot_autograd/runtime_wrappers.py

Lines changed: 2 additions & 0 deletions
@@ -905,6 +905,7 @@ def wrapped_flat_fn(*args):
         if config.debug_assert:
             ref_fw_metadata = run_functionalized_fw_and_collect_metadata(
                 wrapped_flat_fn,
+                static_input_indices=aot_config.static_input_indices,
                 keep_input_mutations=fw_metadata.keep_input_mutations,
                 is_train=fw_metadata.is_train,
             )(*deduped_flat_args)
@@ -1094,6 +1095,7 @@ def wrapped_flat_fn(*args):
         if config.debug_assert:
             ref_fw_metadata = run_functionalized_fw_and_collect_metadata(
                 wrapped_flat_fn,
+                static_input_indices=aot_config.static_input_indices,
                 keep_input_mutations=fw_metadata.keep_input_mutations,
                 is_train=fw_metadata.is_train,
             )(*flat_args_with_synthetic_bases)

torch/_functorch/_aot_autograd/schemas.py

Lines changed: 2 additions & 1 deletion
@@ -328,7 +328,7 @@ class ViewAndMutationMeta:
     deterministic: Optional[bool] = None
 
     # Keeps track of which input indices store parameters (which we will treat as static)
-    static_parameter_indices: List[int] = field(default_factory=list)
+    static_input_indices: List[int] = field(default_factory=list)
 
     # Map of effect type (ex. _EffectType.ORDERED) to token. If there are
     # side-effectful operators, FunctionalTensorMode will populate this
@@ -802,6 +802,7 @@ class AOTConfig:
     no_tangents: bool = False
     dynamic_shapes: bool = False
    aot_autograd_arg_pos_to_source: Optional[List[Source]] = None
+    static_input_indices: Optional[List[int]] = None
    inference_compiler: Optional[Callable] = None
    enable_log: bool = True
    # this is always false outside of export.

torch/_functorch/_aot_autograd/traced_function_transforms.py

Lines changed: 1 addition & 0 deletions
@@ -729,6 +729,7 @@ def metadata_fn(*primals):
         # See Note: [Partitioner handling for Subclasses, Part 2] for more info.
         meta_updated = run_functionalized_fw_and_collect_metadata(
             metadata_fn,
+            static_input_indices=meta.static_input_indices,
             keep_input_mutations=meta.keep_input_mutations,
             is_train=meta.is_train,
         )(*primals_unwrapped)

torch/_functorch/aot_autograd.py

Lines changed: 12 additions & 2 deletions
@@ -20,6 +20,7 @@
 from torch.fx.experimental.proxy_tensor import make_fx
 from torch.fx.experimental.symbolic_shapes import ShapeEnv
 from torch.utils._python_dispatch import is_traceable_wrapper_subclass
+
 from . import config
 from ._aot_autograd.autograd_cache import (  # noqa: F401
     AOTAutogradCache,
@@ -583,11 +584,13 @@ def _dup_fake_script_obj(fake_flat_args):
     with ctx:
         fw_metadata = run_functionalized_fw_and_collect_metadata(
             flat_fn,
+            static_input_indices=aot_config.static_input_indices,
             keep_input_mutations=aot_config.keep_inference_input_mutations,
             is_train=needs_autograd,
             pre_dispatch=aot_config.pre_dispatch,
         )(*_dup_fake_script_obj(fake_flat_args))
 
+        fw_metadata.static_input_indices = aot_config.static_input_indices
         req_subclass_dispatch = requires_subclass_dispatch(
             fake_flat_args, fw_metadata
         )
@@ -631,7 +634,7 @@ def _dup_fake_script_obj(fake_flat_args):
                subclass_tangent_meta=fw_metadata.subclass_tangent_meta,
                is_train=False,
                tokens=fw_metadata.tokens,
-                static_parameter_indices=fw_metadata.static_parameter_indices,
+                static_input_indices=fw_metadata.static_input_indices,
            )
 
        if fw_metadata.num_intermediate_bases > 0:
@@ -936,9 +939,10 @@ def aot_module_simplified(
     # Next, the input args
     full_args.extend(args)
 
+    static_input_indices = []
     if hasattr(mod, "graph"):
         # Non dynamo entrypoints can get to here...
-        for node in mod.graph.find_nodes(op="placeholder"):
+        for pos, node in enumerate(mod.graph.find_nodes(op="placeholder")):
             if hasattr(node, "_dynamo_source"):
                 # ... but not here!
                 if aot_autograd_arg_pos_to_source is None:
@@ -948,6 +952,11 @@ def aot_module_simplified(
                 seen_sources.add(source)
                 aot_autograd_arg_pos_to_source.append(source)
 
+            if "tensor_dict" in node.meta and node.meta["tensor_dict"].get(
+                "_dynamo_static_input_type", None
+            ):
+                static_input_indices.append(pos)
+
     if aot_autograd_arg_pos_to_source is not None:
         assert len(full_args) == len(aot_autograd_arg_pos_to_source)
 
@@ -968,6 +977,7 @@ def aot_module_simplified(
         keep_inference_input_mutations=keep_inference_input_mutations,
         dynamic_shapes=dynamic_shapes,
         aot_autograd_arg_pos_to_source=aot_autograd_arg_pos_to_source,
+        static_input_indices=static_input_indices,
         is_export=False,
         no_tangents=False,
         cache_key=None,
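
Aside (not part of the commit): aot_module_simplified now derives static_input_indices from placeholder metadata, as added above, and forwards it via AOTConfig. A self-contained sketch of that collection logic on a toy FX graph; the manual tensor_dict assignment and the "guarded" value stand in for what Dynamo would record, and Graph.find_nodes requires a PyTorch recent enough to have it (it is used by the diff itself).

import torch
import torch.fx


def collect_static_input_indices(gm: torch.fx.GraphModule):
    # Record positions of placeholders tagged as static inputs, mirroring the
    # loop added to aot_module_simplified above.
    static_input_indices = []
    for pos, node in enumerate(gm.graph.find_nodes(op="placeholder")):
        if "tensor_dict" in node.meta and node.meta["tensor_dict"].get(
            "_dynamo_static_input_type", None
        ):
            static_input_indices.append(pos)
    return static_input_indices


def f(p, x):
    return p * x


gm = torch.fx.symbolic_trace(f)
placeholders = list(gm.graph.find_nodes(op="placeholder"))
# Pretend Dynamo tagged the first input (e.g. a parameter) as static:
placeholders[0].meta["tensor_dict"] = {"_dynamo_static_input_type": "guarded"}
print(collect_static_input_indices(gm))  # [0]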

torch/_inductor/compile_fx.py

Lines changed: 2 additions & 2 deletions
@@ -137,7 +137,7 @@ def get_static_input_idxs(num_fixed):
     if not context or not context.fw_metadata:
         return fixed
 
-    return fixed + context.fw_metadata.static_parameter_indices
+    return fixed + context.fw_metadata.static_input_indices
 
 
 @functools.lru_cache(None)
@@ -1254,7 +1254,7 @@ def fw_compiler_freezing(
                params_flat[i] = None
 
        if tracing_context.fw_metadata:
-            static_input_idxs += tracing_context.fw_metadata.static_parameter_indices
+            static_input_idxs += tracing_context.fw_metadata.static_input_indices
 
    with mock.patch.object(fake_mode, "allow_non_fake_inputs", True):
        optimized_function = inner_compile(
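
Aside (not part of the commit): after the rename, Inductor's get_static_input_idxs combines the leading num_fixed positions with whatever AOTAutograd recorded in fw_metadata.static_input_indices. A rough sketch of that combination; treating `fixed` as a range prefix and using a SimpleNamespace as a stand-in for the real fw_metadata are assumptions, not code from this file.

from types import SimpleNamespace


def get_static_input_idxs_sketch(num_fixed, fw_metadata):
    # First `num_fixed` positions are already treated as fixed by Inductor;
    # the metadata adds the parameter/buffer positions recorded by AOTAutograd.
    fixed = list(range(num_fixed))
    if not fw_metadata:
        return fixed
    return fixed + fw_metadata.static_input_indices


meta = SimpleNamespace(static_input_indices=[3, 5])
print(get_static_input_idxs_sketch(2, meta))  # [0, 1, 3, 5]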
