Commit cbda8be

Revert "Propagate buffer and parameter indices through AOT (pytorch#130393)"
This reverts commit 69a7738. Reverted pytorch#130393 on behalf of https://github.com/clee2000 because it broke lint for torch/_functorch/_aot_autograd/subclass_utils.py (https://github.com/pytorch/pytorch/actions/runs/9948630877/job/27483551649, https://hud.pytorch.org/pytorch/pytorch/commit/80236dca90b0874cb2b6f9c9fa5f159c55726401). Lint was green on the PR, probably a land race (see the comment on pytorch#130393).
1 parent 9cb23ba commit cbda8be

8 files changed: +15 additions, -91 deletions

test/dynamo/test_subclasses.py

Lines changed: 0 additions & 39 deletions
@@ -1526,45 +1526,6 @@ def f(x):
         out_test = compiled_f(view)
         self.assertEqual(out_ref, out_test)
 
-    @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
-    def test_mark_static_with_subclass_desugaring(self):
-        from typing import Any, Callable, Dict, List, Optional
-
-        from torch._dynamo.decorators import mark_static_address
-        from torch._inductor.compile_fx import compile_fx
-        from torch._inductor.cudagraph_utils import BoxedDeviceIndex
-        from torch._inductor.utils import BoxedBool
-
-        x_inner = torch.ones(4)
-        x = TwoTensor(x_inner, x_inner)
-        mark_static_address(x, guard=False)
-
-        def inner_compile(
-            gm: torch.fx.GraphModule,
-            example_inputs: List[torch.Tensor],
-            cudagraphs: Optional[BoxedBool] = None,
-            static_input_idxs: Optional[List[int]] = None,
-            is_backward: bool = False,
-            graph_id: Optional[int] = None,
-            cpp_wrapper: bool = False,
-            aot_mode: bool = False,
-            is_inference: bool = False,
-            boxed_forward_device_index: Optional[BoxedDeviceIndex] = None,
-            user_visible_outputs: Optional[Dict[str, None]] = None,
-            layout_opt: Optional[bool] = None,
-            extern_node_serializer: Optional[Callable[[List[Any]], Any]] = None,
-        ):
-            self.assertEqual(static_input_idxs, [1, 2])
-            return gm
-
-        compiler = functools.partial(compile_fx, inner_compile=inner_compile)
-
-        @torch.compile(backend=compiler)
-        def fn(t0, t1, t2):
-            return t0 + t1 + t2 + 2
-
-        fn(torch.ones(4), x, torch.ones(4))
-
 
 instantiate_parametrized_tests(SubclassTests)

torch/_functorch/_aot_autograd/collect_metadata_analysis.py

Lines changed: 10 additions & 11 deletions
@@ -11,7 +11,7 @@
 import collections
 import logging
 from functools import wraps
-from typing import Callable, DefaultDict, Dict, List, Optional
+from typing import Callable, DefaultDict, Dict, List
 
 import torch
 import torch.utils._pytree as pytree
@@ -25,7 +25,6 @@
     is_traceable_wrapper_subclass,
     transform_subclass,
 )
-
 from .functional_utils import (
     are_all_mutations_hidden_from_autograd,
     are_all_mutations_under_no_grad_or_inference_mode,
@@ -125,8 +124,6 @@ def run_functionalized_fw_and_collect_metadata(
     keep_input_mutations: bool,
     # TODO: refactor to kill this flag
     is_train: bool = False,
-    # Note: this is guaranteed to be set when running under dynamo
-    static_input_indices: Optional[List[int]] = None,
     pre_dispatch: bool = False,
 ) -> Callable[..., ViewAndMutationMeta]:
     memo: Dict[Tensor, Tensor] = {}
@@ -669,15 +666,17 @@ def view_avoid_dupes_with_primals(t):
         )
         user_outs = pytree.tree_map(from_fun, f_output_tangents)
 
-        nonlocal static_input_indices
-        static_input_indices = static_input_indices or []
-        if torch._dynamo.compiled_autograd.in_compiled_autograd_region:
-            passed_indices = set(static_input_indices)
-            static_input_indices = [
+        if (
+            torch._dynamo.config.inline_inbuilt_nn_modules
+            or torch._dynamo.compiled_autograd.in_compiled_autograd_region
+        ):
+            static_parameter_input_indices = [
                 i
                 for i, arg in enumerate(flat_args)
-                if (isinstance(arg, torch.nn.Parameter) or i in passed_indices)
+                if isinstance(arg, torch.nn.Parameter)
             ]
+        else:
+            static_parameter_input_indices = []
 
         f_mutated_inputs = [
             inp
@@ -730,7 +729,7 @@ def view_avoid_dupes_with_primals(t):
             subclass_tangent_meta=create_subclass_meta(traced_tangents),
             is_train=is_train,
             grad_enabled_mutation=grad_enabled_mutation,
-            static_input_indices=static_input_indices,
+            static_parameter_indices=static_parameter_input_indices,
             tokens=mode._tokens,
         )
         return metadata
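
For context, the hunk above restores deriving static indices purely from whether each flat argument is an nn.Parameter, instead of consuming indices propagated from Dynamo. A minimal, self-contained sketch of that restored selection (the standalone function and its arguments are illustrative, not the PyTorch source):

# Illustrative sketch only: mirrors the restored list comprehension above.
# `flat_args` stands in for AOTAutograd's flattened inputs.
from typing import List

import torch


def restored_static_parameter_indices(
    flat_args: List[torch.Tensor],
    inline_inbuilt_nn_modules: bool,
    in_compiled_autograd_region: bool,
) -> List[int]:
    if inline_inbuilt_nn_modules or in_compiled_autograd_region:
        return [
            i
            for i, arg in enumerate(flat_args)
            if isinstance(arg, torch.nn.Parameter)
        ]
    return []


# Example: only the nn.Parameter at position 1 is reported as static.
args = [torch.ones(2), torch.nn.Parameter(torch.ones(2)), torch.ones(2)]
assert restored_static_parameter_indices(args, True, False) == [1]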

torch/_functorch/_aot_autograd/runtime_wrappers.py

Lines changed: 0 additions & 2 deletions
@@ -905,7 +905,6 @@ def wrapped_flat_fn(*args):
         if config.debug_assert:
             ref_fw_metadata = run_functionalized_fw_and_collect_metadata(
                 wrapped_flat_fn,
-                static_input_indices=aot_config.static_input_indices,
                 keep_input_mutations=fw_metadata.keep_input_mutations,
                 is_train=fw_metadata.is_train,
             )(*deduped_flat_args)
@@ -1095,7 +1094,6 @@ def wrapped_flat_fn(*args):
         if config.debug_assert:
             ref_fw_metadata = run_functionalized_fw_and_collect_metadata(
                 wrapped_flat_fn,
-                static_input_indices=aot_config.static_input_indices,
                 keep_input_mutations=fw_metadata.keep_input_mutations,
                 is_train=fw_metadata.is_train,
             )(*flat_args_with_synthetic_bases)

torch/_functorch/_aot_autograd/schemas.py

Lines changed: 1 addition & 2 deletions
@@ -329,7 +329,7 @@ class ViewAndMutationMeta:
     deterministic: Optional[bool] = None
 
     # Keeps track of which input indices store parameters (which we will treat as static)
-    static_input_indices: List[int] = field(default_factory=list)
+    static_parameter_indices: List[int] = field(default_factory=list)
 
     # Map of effect type (ex. _EffectType.ORDERED) to token. If there are
     # side-effectful operators, FunctionalTensorMode will populate this
@@ -803,7 +803,6 @@ class AOTConfig:
     no_tangents: bool = False
     dynamic_shapes: bool = False
    aot_autograd_arg_pos_to_source: Optional[List[Source]] = None
-    static_input_indices: Optional[List[int]] = None
     inference_compiler: Optional[Callable] = None
     enable_log: bool = True
     # this is always false outside of export.
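
For reference, a minimal sketch of what the reverted field looks like on the metadata dataclass (a simplified excerpt under an assumed class name, not the full ViewAndMutationMeta definition):

# Simplified excerpt for illustration; the real dataclass carries many more fields.
from dataclasses import dataclass, field
from typing import List


@dataclass
class ViewAndMutationMetaSketch:
    # Keeps track of which input indices store parameters (treated as static).
    static_parameter_indices: List[int] = field(default_factory=list)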

torch/_functorch/_aot_autograd/subclass_utils.py

Lines changed: 0 additions & 18 deletions
@@ -136,24 +136,6 @@ def concat_inner_tensors_from_subclasses(xs):
     return unwrapped_args
 
 
-def remap_unwrapped_subclass_arg_indices(wrapped_args, static_input_indices):
-    static_input_indices = set(static_input_indices)
-    new_ind = 0
-    remapped_static_indices = []
-    for i, arg in enumerate(wrapped_args):
-        num_indices = 1
-        if is_traceable_wrapper_subclass(arg):
-            num_indices = len(get_plain_tensors(arg))
-
-        for _ in range(num_indices):
-            if i in static_input_indices:
-                remapped_static_indices.append(new_ind)
-
-            new_ind += 1
-
-    return remapped_static_indices
-
-
 # Turns a flattened list of tensor arguments into (maybe) subclass tensors.
 # This function is used both at trace time and runtime, so we have an is_runtime flag telling us which context we're in.
 def wrap_tensor_subclasses(
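
The deleted helper expands each wrapper-tensor argument into its inner plain tensors and carries wrapper-level static indices onto the flattened positions. A self-contained sketch of that remapping, decoupled from tensor subclasses by taking precomputed inner-tensor counts instead of calling get_plain_tensors (the sketch's names are illustrative):

# Sketch of the removed remapping: each entry of `num_inner_tensors` is how
# many plain tensors the i-th wrapper argument flattens into (1 for ordinary
# tensors); wrapper-level static indices map onto flattened positions.
from typing import List, Sequence


def remap_static_indices_sketch(
    num_inner_tensors: Sequence[int], static_input_indices: List[int]
) -> List[int]:
    static = set(static_input_indices)
    remapped, new_ind = [], 0
    for i, n in enumerate(num_inner_tensors):
        for _ in range(n):
            if i in static:
                remapped.append(new_ind)
            new_ind += 1
    return remapped


# Matches the deleted test above: a TwoTensor at position 1 (two inner
# tensors) marked static maps to flattened positions [1, 2].
assert remap_static_indices_sketch([1, 2, 1], [1]) == [1, 2]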

torch/_functorch/_aot_autograd/traced_function_transforms.py

Lines changed: 0 additions & 5 deletions
@@ -53,7 +53,6 @@
 )
 from .subclass_utils import (
     create_subclass_meta,
-    remap_unwrapped_subclass_arg_indices,
     requires_subclass_dispatch,
     unwrap_tensor_subclasses,
     wrap_tensor_subclasses_maybe_joint,
@@ -703,9 +702,6 @@ def metadata_fn(*primals):
         args_unwrapped = unwrap_tensor_subclasses(
             args, is_joint_structure=is_joint_structure
         )
-        remapped_static_indices = remap_unwrapped_subclass_arg_indices(
-            args, meta.static_input_indices
-        )
 
         if is_joint_structure:
             primals_unwrapped = args_unwrapped[0]
@@ -733,7 +729,6 @@ def metadata_fn(*primals):
         # See Note: [Partitioner handling for Subclasses, Part 2] for more info.
         meta_updated = run_functionalized_fw_and_collect_metadata(
             metadata_fn,
-            static_input_indices=remapped_static_indices,
             keep_input_mutations=meta.keep_input_mutations,
             is_train=meta.is_train,
         )(*primals_unwrapped)

torch/_functorch/aot_autograd.py

Lines changed: 2 additions & 12 deletions
@@ -20,7 +20,6 @@
 from torch.fx.experimental.proxy_tensor import make_fx
 from torch.fx.experimental.symbolic_shapes import ShapeEnv
 from torch.utils._python_dispatch import is_traceable_wrapper_subclass
-
 from . import config
 from ._aot_autograd.autograd_cache import (  # noqa: F401
     AOTAutogradCache,
@@ -589,7 +588,6 @@ def _dup_fake_script_obj(fake_flat_args):
     with ctx:
         fw_metadata = run_functionalized_fw_and_collect_metadata(
             flat_fn,
-            static_input_indices=aot_config.static_input_indices,
             keep_input_mutations=aot_config.keep_inference_input_mutations,
             is_train=needs_autograd,
             pre_dispatch=aot_config.pre_dispatch,
@@ -625,7 +623,6 @@ def _dup_fake_script_obj(fake_flat_args):
                 keep_input_mutations=aot_config.keep_inference_input_mutations,
                 is_train=False,
                 pre_dispatch=aot_config.pre_dispatch,
-                static_input_indices=aot_config.static_input_indices,
             )(*fake_flat_args)
         else:
             fw_metadata = ViewAndMutationMeta(
@@ -639,7 +636,7 @@ def _dup_fake_script_obj(fake_flat_args):
                subclass_tangent_meta=fw_metadata.subclass_tangent_meta,
                 is_train=False,
                 tokens=fw_metadata.tokens,
-                static_input_indices=fw_metadata.static_input_indices,
+                static_parameter_indices=fw_metadata.static_parameter_indices,
             )
 
     if fw_metadata.num_intermediate_bases > 0:
@@ -944,10 +941,9 @@ def aot_module_simplified(
     # Next, the input args
     full_args.extend(args)
 
-    static_input_indices = []
     if hasattr(mod, "graph"):
         # Non dynamo entrypoints can get to here...
-        for pos, node in enumerate(mod.graph.find_nodes(op="placeholder")):
+        for node in mod.graph.find_nodes(op="placeholder"):
             if hasattr(node, "_dynamo_source"):
                 # ... but not here!
                 if aot_autograd_arg_pos_to_source is None:
@@ -957,11 +953,6 @@ def aot_module_simplified(
                 seen_sources.add(source)
                 aot_autograd_arg_pos_to_source.append(source)
 
-                if "tensor_dict" in node.meta and node.meta["tensor_dict"].get(
-                    "_dynamo_static_input_type", None
-                ):
-                    static_input_indices.append(pos)
-
     if aot_autograd_arg_pos_to_source is not None:
         assert len(full_args) == len(aot_autograd_arg_pos_to_source)
 
@@ -982,7 +973,6 @@ def aot_module_simplified(
        keep_inference_input_mutations=keep_inference_input_mutations,
        dynamic_shapes=dynamic_shapes,
        aot_autograd_arg_pos_to_source=aot_autograd_arg_pos_to_source,
-        static_input_indices=static_input_indices,
        is_export=False,
        no_tangents=False,
        cache_key=None,
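
The lines removed from aot_module_simplified are the ones that collected static indices from Dynamo's placeholder metadata. A hedged sketch of that scan, operating on plain meta dicts so it stays self-contained (the helper name and the example marker value are illustrative assumptions, not the PyTorch API):

# Illustrative sketch: collect positions whose placeholder metadata carries
# Dynamo's static-input marker, as the removed loop did via node.meta.
from typing import Any, Dict, List


def collect_static_input_indices(placeholder_metas: List[Dict[str, Any]]) -> List[int]:
    static_input_indices = []
    for pos, meta in enumerate(placeholder_metas):
        tensor_dict = meta.get("tensor_dict", {})
        if tensor_dict.get("_dynamo_static_input_type", None):
            static_input_indices.append(pos)
    return static_input_indices


# Example: only the second placeholder carries the marker (value is illustrative).
metas = [{}, {"tensor_dict": {"_dynamo_static_input_type": "unguarded"}}, {"tensor_dict": {}}]
assert collect_static_input_indices(metas) == [1]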

torch/_inductor/compile_fx.py

Lines changed: 2 additions & 2 deletions
@@ -136,7 +136,7 @@ def get_static_input_idxs(num_fixed):
     if not context or not context.fw_metadata:
         return fixed
 
-    return fixed + context.fw_metadata.static_input_indices
+    return fixed + context.fw_metadata.static_parameter_indices
 
 
 @functools.lru_cache(None)
@@ -1246,7 +1246,7 @@ def fw_compiler_freezing(
            params_flat[i] = None
 
     if tracing_context.fw_metadata:
-        static_input_idxs += tracing_context.fw_metadata.static_input_indices
+        static_input_idxs += tracing_context.fw_metadata.static_parameter_indices
 
     with mock.patch.object(fake_mode, "allow_non_fake_inputs", True):
         optimized_function = inner_compile(
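
After the revert, Inductor composes its static-input list from the leading fixed inputs plus the parameter indices recorded in the forward metadata, as the first hunk above shows. A tiny sketch of that composition (standalone, with the tracing context replaced by an explicit argument; names are assumptions for illustration):

# Sketch of the post-revert composition in get_static_input_idxs: the first
# `num_fixed` inputs are always static; parameter indices from the AOT
# metadata (if any) are appended after them.
from typing import List, Optional


def static_input_idxs_sketch(
    num_fixed: int, static_parameter_indices: Optional[List[int]]
) -> List[int]:
    fixed = list(range(num_fixed))
    if static_parameter_indices is None:
        return fixed
    return fixed + static_parameter_indices


assert static_input_idxs_sketch(2, [3, 5]) == [0, 1, 3, 5]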
