
Commit 69a7738

mlazos authored and pytorchmergebot committed
Propagate buffer and parameter indices through AOT (pytorch#130393)
Pull Request resolved: pytorch#130393
Approved by: https://github.com/bdhirsh
ghstack dependencies: pytorch#130391, pytorch#130392, pytorch#130503
1 parent 200d3d0 commit 69a7738
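
For orientation, a minimal sketch of the user-visible behavior this change targets (not part of the commit; the TwoTensor import path and the bare torch.compile call are assumptions here -- the authoritative version is the new test in test/dynamo/test_subclasses.py below):

# Sketch only. Assumes TwoTensor lives at torch.testing._internal.two_tensor
# and that no extra config patches are needed for illustration.
import torch
from torch._dynamo.decorators import mark_static_address
from torch.testing._internal.two_tensor import TwoTensor  # assumed import path

inner = torch.ones(4)
x = TwoTensor(inner, inner)          # a traceable wrapper subclass: desugars into two plain tensors
mark_static_address(x, guard=False)  # mark the wrapped input as a static address

@torch.compile
def fn(t0, t1, t2):
    return t0 + t1 + t2 + 2

# x sits at wrapped position 1; after desugaring its two inner tensors occupy
# unwrapped positions 1 and 2, so those indices should be reported as static
# input indices to the inductor backend.
fn(torch.ones(4), x, torch.ones(4))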

File tree: 8 files changed (+91, −15 lines)


test/dynamo/test_subclasses.py

Lines changed: 39 additions & 0 deletions
@@ -1526,6 +1526,45 @@ def f(x):
         out_test = compiled_f(view)
         self.assertEqual(out_ref, out_test)
 
+    @torch._dynamo.config.patch("inline_inbuilt_nn_modules", True)
+    def test_mark_static_with_subclass_desugaring(self):
+        from typing import Any, Callable, Dict, List, Optional
+
+        from torch._dynamo.decorators import mark_static_address
+        from torch._inductor.compile_fx import compile_fx
+        from torch._inductor.cudagraph_utils import BoxedDeviceIndex
+        from torch._inductor.utils import BoxedBool
+
+        x_inner = torch.ones(4)
+        x = TwoTensor(x_inner, x_inner)
+        mark_static_address(x, guard=False)
+
+        def inner_compile(
+            gm: torch.fx.GraphModule,
+            example_inputs: List[torch.Tensor],
+            cudagraphs: Optional[BoxedBool] = None,
+            static_input_idxs: Optional[List[int]] = None,
+            is_backward: bool = False,
+            graph_id: Optional[int] = None,
+            cpp_wrapper: bool = False,
+            aot_mode: bool = False,
+            is_inference: bool = False,
+            boxed_forward_device_index: Optional[BoxedDeviceIndex] = None,
+            user_visible_outputs: Optional[Dict[str, None]] = None,
+            layout_opt: Optional[bool] = None,
+            extern_node_serializer: Optional[Callable[[List[Any]], Any]] = None,
+        ):
+            self.assertEqual(static_input_idxs, [1, 2])
+            return gm
+
+        compiler = functools.partial(compile_fx, inner_compile=inner_compile)
+
+        @torch.compile(backend=compiler)
+        def fn(t0, t1, t2):
+            return t0 + t1 + t2 + 2
+
+        fn(torch.ones(4), x, torch.ones(4))
+
 
 instantiate_parametrized_tests(SubclassTests)
 

torch/_functorch/_aot_autograd/collect_metadata_analysis.py

Lines changed: 11 additions & 10 deletions
@@ -11,7 +11,7 @@
 import collections
 import logging
 from functools import wraps
-from typing import Callable, DefaultDict, Dict, List
+from typing import Callable, DefaultDict, Dict, List, Optional
 
 import torch
 import torch.utils._pytree as pytree
@@ -25,6 +25,7 @@
     is_traceable_wrapper_subclass,
     transform_subclass,
 )
+
 from .functional_utils import (
     are_all_mutations_hidden_from_autograd,
     are_all_mutations_under_no_grad_or_inference_mode,
@@ -124,6 +125,8 @@ def run_functionalized_fw_and_collect_metadata(
     keep_input_mutations: bool,
     # TODO: refactor to kill this flag
     is_train: bool = False,
+    # Note: this is guaranteed to be set when running under dynamo
+    static_input_indices: Optional[List[int]] = None,
     pre_dispatch: bool = False,
 ) -> Callable[..., ViewAndMutationMeta]:
     memo: Dict[Tensor, Tensor] = {}
@@ -666,17 +669,15 @@ def view_avoid_dupes_with_primals(t):
         )
         user_outs = pytree.tree_map(from_fun, f_output_tangents)
 
-        if (
-            torch._dynamo.config.inline_inbuilt_nn_modules
-            or torch._dynamo.compiled_autograd.in_compiled_autograd_region
-        ):
-            static_parameter_input_indices = [
+        nonlocal static_input_indices
+        static_input_indices = static_input_indices or []
+        if torch._dynamo.compiled_autograd.in_compiled_autograd_region:
+            passed_indices = set(static_input_indices)
+            static_input_indices = [
                 i
                 for i, arg in enumerate(flat_args)
-                if isinstance(arg, torch.nn.Parameter)
+                if (isinstance(arg, torch.nn.Parameter) or i in passed_indices)
             ]
-        else:
-            static_parameter_input_indices = []
 
         f_mutated_inputs = [
             inp
@@ -729,7 +730,7 @@ def view_avoid_dupes_with_primals(t):
             subclass_tangent_meta=create_subclass_meta(traced_tangents),
             is_train=is_train,
             grad_enabled_mutation=grad_enabled_mutation,
-            static_parameter_indices=static_parameter_input_indices,
+            static_input_indices=static_input_indices,
             tokens=mode._tokens,
         )
         return metadata
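
For reference, a standalone sketch of what the rewritten block above computes (local names, plain Python, not the AOT-internal code): the indices passed in from dynamo are kept, and inside the compiled-autograd region every nn.Parameter input is additionally treated as static.

import torch
from typing import List, Optional

def resolve_static_indices(
    flat_args,
    static_input_indices: Optional[List[int]],
    in_compiled_autograd_region: bool,
) -> List[int]:
    # Mirrors the hunk above: keep what was passed in; in the compiled-autograd
    # region also mark every nn.Parameter input as static.
    static_input_indices = static_input_indices or []
    if in_compiled_autograd_region:
        passed_indices = set(static_input_indices)
        static_input_indices = [
            i
            for i, arg in enumerate(flat_args)
            if isinstance(arg, torch.nn.Parameter) or i in passed_indices
        ]
    return static_input_indices

p = torch.nn.Parameter(torch.ones(2))
print(resolve_static_indices([torch.ones(2), p], [0], True))   # [0, 1]
print(resolve_static_indices([torch.ones(2), p], [0], False))  # [0]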

torch/_functorch/_aot_autograd/runtime_wrappers.py

Lines changed: 2 additions & 0 deletions
@@ -905,6 +905,7 @@ def wrapped_flat_fn(*args):
         if config.debug_assert:
             ref_fw_metadata = run_functionalized_fw_and_collect_metadata(
                 wrapped_flat_fn,
+                static_input_indices=aot_config.static_input_indices,
                 keep_input_mutations=fw_metadata.keep_input_mutations,
                 is_train=fw_metadata.is_train,
             )(*deduped_flat_args)
@@ -1094,6 +1095,7 @@ def wrapped_flat_fn(*args):
         if config.debug_assert:
             ref_fw_metadata = run_functionalized_fw_and_collect_metadata(
                 wrapped_flat_fn,
+                static_input_indices=aot_config.static_input_indices,
                 keep_input_mutations=fw_metadata.keep_input_mutations,
                 is_train=fw_metadata.is_train,
             )(*flat_args_with_synthetic_bases)

torch/_functorch/_aot_autograd/schemas.py

Lines changed: 2 additions & 1 deletion
@@ -329,7 +329,7 @@ class ViewAndMutationMeta:
     deterministic: Optional[bool] = None
 
     # Keeps track of which input indices store parameters (which we will treat as static)
-    static_parameter_indices: List[int] = field(default_factory=list)
+    static_input_indices: List[int] = field(default_factory=list)
 
     # Map of effect type (ex. _EffectType.ORDERED) to token. If there are
     # side-effectful operators, FunctionalTensorMode will populate this
@@ -803,6 +803,7 @@ class AOTConfig:
     no_tangents: bool = False
     dynamic_shapes: bool = False
     aot_autograd_arg_pos_to_source: Optional[List[Source]] = None
+    static_input_indices: Optional[List[int]] = None
     inference_compiler: Optional[Callable] = None
     enable_log: bool = True
     # this is always false outside of export.

torch/_functorch/_aot_autograd/subclass_utils.py

Lines changed: 18 additions & 0 deletions
@@ -136,6 +136,24 @@ def concat_inner_tensors_from_subclasses(xs):
     return unwrapped_args
 
 
+def remap_unwrapped_subclass_arg_indices(wrapped_args, static_input_indices):
+    static_input_indices = set(static_input_indices)
+    new_ind = 0
+    remapped_static_indices = []
+    for i, arg in enumerate(wrapped_args):
+        num_indices = 1
+        if is_traceable_wrapper_subclass(arg):
+            num_indices = len(get_plain_tensors(arg))
+
+        for _ in range(num_indices):
+            if i in static_input_indices:
+                remapped_static_indices.append(new_ind)
+
+            new_ind += 1
+
+    return remapped_static_indices
+
+
 # Turns a flattened list of tensor arguments into (maybe) subclass tensors.
 # This function is used both at trace time and runtime, so we have an is_runtime flag telling us which context we're in.
 def wrap_tensor_subclasses(
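
A small worked example of the remapping above (standalone sketch; the real subclass detection via is_traceable_wrapper_subclass/get_plain_tensors is replaced by a hypothetical per-argument count): a wrapped slot that desugars into N plain tensors occupies N consecutive unwrapped slots, and a static wrapped slot makes all of its unwrapped slots static.

def remap(wrapped_sizes, static_input_indices):
    # wrapped_sizes[i] = how many plain tensors wrapped arg i desugars into
    # (1 for a plain tensor, len(get_plain_tensors(arg)) for a subclass).
    static = set(static_input_indices)
    remapped, new_ind = [], 0
    for i, n in enumerate(wrapped_sizes):
        for _ in range(n):
            if i in static:
                remapped.append(new_ind)
            new_ind += 1
    return remapped

# Matches the new test: [plain tensor, TwoTensor (2 inner tensors), plain tensor]
# with the TwoTensor at wrapped index 1 marked static -> unwrapped indices [1, 2].
print(remap([1, 2, 1], [1]))  # [1, 2]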

torch/_functorch/_aot_autograd/traced_function_transforms.py

Lines changed: 5 additions & 0 deletions
@@ -53,6 +53,7 @@
 )
 from .subclass_utils import (
     create_subclass_meta,
+    remap_unwrapped_subclass_arg_indices,
     requires_subclass_dispatch,
     unwrap_tensor_subclasses,
     wrap_tensor_subclasses_maybe_joint,
@@ -702,6 +703,9 @@ def metadata_fn(*primals):
         args_unwrapped = unwrap_tensor_subclasses(
             args, is_joint_structure=is_joint_structure
         )
+        remapped_static_indices = remap_unwrapped_subclass_arg_indices(
+            args, meta.static_input_indices
+        )
 
         if is_joint_structure:
             primals_unwrapped = args_unwrapped[0]
@@ -729,6 +733,7 @@ def metadata_fn(*primals):
     # See Note: [Partitioner handling for Subclasses, Part 2] for more info.
     meta_updated = run_functionalized_fw_and_collect_metadata(
         metadata_fn,
+        static_input_indices=remapped_static_indices,
         keep_input_mutations=meta.keep_input_mutations,
         is_train=meta.is_train,
     )(*primals_unwrapped)

torch/_functorch/aot_autograd.py

Lines changed: 12 additions & 2 deletions
@@ -20,6 +20,7 @@
 from torch.fx.experimental.proxy_tensor import make_fx
 from torch.fx.experimental.symbolic_shapes import ShapeEnv
 from torch.utils._python_dispatch import is_traceable_wrapper_subclass
+
 from . import config
 from ._aot_autograd.autograd_cache import (  # noqa: F401
     AOTAutogradCache,
@@ -588,6 +589,7 @@ def _dup_fake_script_obj(fake_flat_args):
     with ctx:
         fw_metadata = run_functionalized_fw_and_collect_metadata(
             flat_fn,
+            static_input_indices=aot_config.static_input_indices,
             keep_input_mutations=aot_config.keep_inference_input_mutations,
             is_train=needs_autograd,
             pre_dispatch=aot_config.pre_dispatch,
@@ -623,6 +625,7 @@ def _dup_fake_script_obj(fake_flat_args):
                     keep_input_mutations=aot_config.keep_inference_input_mutations,
                     is_train=False,
                     pre_dispatch=aot_config.pre_dispatch,
+                    static_input_indices=aot_config.static_input_indices,
                 )(*fake_flat_args)
             else:
                 fw_metadata = ViewAndMutationMeta(
@@ -636,7 +639,7 @@ def _dup_fake_script_obj(fake_flat_args):
                     subclass_tangent_meta=fw_metadata.subclass_tangent_meta,
                     is_train=False,
                     tokens=fw_metadata.tokens,
-                    static_parameter_indices=fw_metadata.static_parameter_indices,
+                    static_input_indices=fw_metadata.static_input_indices,
                 )
 
         if fw_metadata.num_intermediate_bases > 0:
@@ -941,9 +944,10 @@ def aot_module_simplified(
     # Next, the input args
     full_args.extend(args)
 
+    static_input_indices = []
     if hasattr(mod, "graph"):
         # Non dynamo entrypoints can get to here...
-        for node in mod.graph.find_nodes(op="placeholder"):
+        for pos, node in enumerate(mod.graph.find_nodes(op="placeholder")):
             if hasattr(node, "_dynamo_source"):
                 # ... but not here!
                 if aot_autograd_arg_pos_to_source is None:
@@ -953,6 +957,11 @@ def aot_module_simplified(
                 seen_sources.add(source)
                 aot_autograd_arg_pos_to_source.append(source)
 
+            if "tensor_dict" in node.meta and node.meta["tensor_dict"].get(
+                "_dynamo_static_input_type", None
+            ):
+                static_input_indices.append(pos)
+
     if aot_autograd_arg_pos_to_source is not None:
         assert len(full_args) == len(aot_autograd_arg_pos_to_source)
 
@@ -973,6 +982,7 @@ def aot_module_simplified(
         keep_inference_input_mutations=keep_inference_input_mutations,
         dynamic_shapes=dynamic_shapes,
        aot_autograd_arg_pos_to_source=aot_autograd_arg_pos_to_source,
+        static_input_indices=static_input_indices,
        is_export=False,
        no_tangents=False,
        cache_key=None,

torch/_inductor/compile_fx.py

Lines changed: 2 additions & 2 deletions
@@ -136,7 +136,7 @@ def get_static_input_idxs(num_fixed):
     if not context or not context.fw_metadata:
         return fixed
 
-    return fixed + context.fw_metadata.static_parameter_indices
+    return fixed + context.fw_metadata.static_input_indices
 
 
 @functools.lru_cache(None)
@@ -1252,7 +1252,7 @@ def fw_compiler_freezing(
                 params_flat[i] = None
 
     if tracing_context.fw_metadata:
-        static_input_idxs += tracing_context.fw_metadata.static_parameter_indices
+        static_input_idxs += tracing_context.fw_metadata.static_input_indices
 
     with mock.patch.object(fake_mode, "allow_non_fake_inputs", True):
        optimized_function = inner_compile(

0 commit comments
