
Commit 91c5910

ydwu4 authored and facebook-github-bot committed
Replace node.meta source_fn with source_fn_stack (#108595)
Summary:
X-link: pytorch/executorch#210

A resubmit of #108447. Copying over the description: this is a follow-up of the discussion in #108356, where we want to replace source_fn with source_fn_stack.

Test Plan: see the added tests in test_higher_order_ops.py and the modified existing tests.

Differential Revision: D48984986

Pulled By: ydwu4
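For orientation, here is a minimal sketch (not part of the commit; the targets are illustrative stand-in strings) of the metadata shape before and after this change. Previously a node carried a single (name, target) pair under source_fn; it now carries a list under source_fn_stack, with one entry per enclosing higher-order op and the node's own entry last, so consumers that only need the innermost name read the last element.

# Illustrative sketch only; targets are stand-in strings rather than real ops/classes.
# Before: a single (name, target) pair describing just the node itself.
node_meta_before = {"source_fn": ("linear", "Linear")}

# After: a stack, outermost higher-order op first and the node's own entry last,
# e.g. a linear call traced inside two nested `wrap` higher-order ops.
node_meta_after = {
    "source_fn_stack": [("wrap", "wrap"), ("wrap", "wrap"), ("linear", "Linear")],
}

# Consumers that only need the innermost name (as test_aot_autograd.py does below)
# read the last entry:
stack = node_meta_after["source_fn_stack"]
mod_name = stack[-1][0] if stack else ""
assert mod_name == "linear"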
1 parent c75aec9 commit 91c5910

File tree

13 files changed (+218, -40 lines)


test/dynamo/test_aot_autograd.py

Lines changed: 4 additions & 3 deletions
@@ -798,10 +798,11 @@ def _prepare_model_args():
                 continue
             if min_seq_nr < 0:
                 min_seq_nr = seq_nr
-            mod_name = node.meta.get("source_fn", "")
+            source_fn_stack = node.meta.get("source_fn_stack", [])
             orig_aten = node.meta.get("original_aten", "")
-            if isinstance(mod_name, tuple):
-                mod_name = mod_name[0]
+            mod_name = ""
+            if len(source_fn_stack) > 0:
+                mod_name = source_fn_stack[-1][0]
             # Make all seq_nr relative so it starts at 0
             seq_nr = seq_nr - min_seq_nr
             seq_table = seq_table + f"{seq_nr}|{orig_aten}|{mod_name}\n"

test/dynamo/test_export.py

Lines changed: 2 additions & 2 deletions
@@ -957,7 +957,7 @@ def forward(self, x):
             if node.op not in {"placeholder", "output"}:
                 self.assertTrue(node.stack_trace is not None)
                 self.assertTrue(node.meta["nn_module_stack"] is not None)
-                self.assertTrue(node.meta["source_fn"] is not None)
+                self.assertTrue(node.meta["source_fn_stack"] is not None)

         torch._dynamo.reset()

@@ -967,7 +967,7 @@ def forward(self, x):
             if node.op == "call_function":
                 self.assertTrue(node.stack_trace is not None)
                 self.assertTrue(node.meta["nn_module_stack"] is not None)
-                self.assertTrue(node.meta["source_fn"] is not None)
+                self.assertTrue(node.meta["source_fn_stack"] is not None)
                 self.assertTrue(node.meta["val"] is not None)
                 self.assertTrue(node.meta["original_aten"] is not None)

test/dynamo/test_higher_order_ops.py

Lines changed: 143 additions & 0 deletions
@@ -1605,6 +1605,149 @@ def fn(x):

         self.assertTrue(activations.keys() == forward_handles.keys())

+    def _check_source_fn_stack(self, gm, exp_stack_dict):
+        for mod in gm.modules():
+            for node in mod.graph.nodes:
+                print(node.name)
+                if node.name in exp_stack_dict:
+                    exp_stack = exp_stack_dict[node.name]
+                    actual_stack = [
+                        name for name, _ in node.meta.get("source_fn_stack", [])
+                    ]
+                    print(f"{exp_stack}, {actual_stack}")
+                    self.assertEqual(actual_stack, exp_stack)
+
+    def test_wrap_source_fn_stack(self):
+        class MockModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(4, 4)
+
+            def forward(self, x):
+                return self.linear(x)
+
+        mod = MockModule()
+
+        def gn(x):
+            return torch.cos(x) + wrap(mod, x)
+
+        def fn(x):
+            return wrap(gn, x)
+
+        backend = EagerAndRecordGraphs()
+        inp = torch.randn((4, 4))
+        torch.compile(fn, backend=backend, fullgraph=True)(inp)
+
+        gm = backend.graphs[0]
+        exp_stack = {
+            "cos": ["wrap", "cos"],
+            "add": ["wrap", "add"],
+            "linear": ["wrap", "wrap", "linear"],
+        }
+        self._check_source_fn_stack(gm, exp_stack)
+
+    def test_cond_source_fn_stack(self):
+        backend = EagerAndRecordGraphs()
+
+        @torch.compile(backend=backend, fullgraph=True)
+        def cond_f(pred, pred2, x, y):
+            def true_fn(pred2, x, y):
+                return x + y
+
+            def false_fn(pred2, x, y):
+                def true_fn2(x, y):
+                    return x.sin() - y.cos()
+
+                def false_fn2(x, y):
+                    return x.cos() - y.sin()
+
+                return control_flow.cond(pred2, true_fn2, false_fn2, [x, y])
+
+            return control_flow.cond(pred, true_fn, false_fn, [pred2, x, y])
+
+        pred = torch.tensor(True)
+        pred2 = torch.tensor(False)
+        xs = torch.randn(2, 3, 3)
+        y = torch.randn(3, 3)
+        cond_f(pred, pred2, xs, y)
+
+        gm = backend.graphs[0]
+        exp_stack = {
+            "add": ["cond", "add"],
+            "cos": ["cond", "cond", "cos"],
+            "sin": ["cond", "cond", "sin"],
+            "sub": ["cond", "cond", "sub"],
+        }
+        self._check_source_fn_stack(gm, exp_stack)
+
+    def test_map_source_fn_stack(self):
+        backend = EagerAndRecordGraphs()
+
+        xs = torch.randn(2, 3, 3)
+        y = torch.randn(3)
+
+        @torch.compile(backend=backend, fullgraph=True)
+        def map_f(xs, y):
+            def inner(x, y):
+                def inner2(x, y):
+                    return x + y
+
+                return control_flow.map(inner2, x, y) * y.cos()
+
+            return control_flow.map(inner, xs, y).sin()
+
+        result = map_f(xs, y)
+
+        gm = backend.graphs[0]
+        exp_stack = {
+            "sin": ["sin"],
+            "cos": ["map", "cos"],
+            "mul": ["map", "mul"],
+            "add": ["map", "map", "add"],
+        }
+        self._check_source_fn_stack(gm, exp_stack)
+
+    def test_grad_source_fn_stack(self):
+        backend = EagerAndRecordGraphs()
+
+        def fn(x):
+            return x.sin().sum()
+
+        @torch.compile(backend=backend, fullgraph=False)
+        def wrapper_fn(x):
+            return torch.func.grad(torch.func.grad(fn))(x)
+
+        x = torch.randn(())
+
+        wrapper_fn(x)
+        gm = backend.graphs[0]
+        exp_stack = {
+            "sin": ["grad_impl", "grad_impl", "sin"],
+            "sum": ["grad_impl", "grad_impl", "sum"],
+        }
+        self._check_source_fn_stack(gm, exp_stack)
+
+    def test_vmap_source_fn_stack(self):
+        backend = EagerAndRecordGraphs()
+
+        def inner_fn(x):
+            return torch.func.vmap(lambda x: x.sum(0) + x.sum(1))(x)
+
+        @torch.compile(backend=backend, fullgraph=True)
+        def fn(x):
+            return torch.func.vmap(lambda x: inner_fn(x.cos()))(x)
+
+        x = torch.randn(3, 3, 3, 3)
+        fn(x)
+        gm = backend.graphs[0]
+        exp_stack = {
+            "cos": ["vmap_impl", "cos"],
+            "sum_1": ["vmap_impl", "vmap_impl", "sum_1"],
+            "sum_2": ["vmap_impl", "vmap_impl", "sum_2"],
+            "add": ["vmap_impl", "vmap_impl", "add"],
+        }
+        self._check_source_fn_stack(gm, exp_stack)
+

 class FuncTorchHigherOrderOpTests(torch._dynamo.test_case.TestCase):
     def run(self, result=None):
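Outside the test suite, the new metadata can be inspected with a small graph-recording backend. The following is a hedged sketch, not part of this commit: recording_backend is an illustrative helper, and it assumes torch.func.vmap traces under torch.compile the same way the tests above exercise it.

# Minimal sketch: record the Dynamo-captured graph and print the
# source_fn_stack names of each call_function node.
import torch

graphs = []

def recording_backend(gm, example_inputs):
    graphs.append(gm)   # keep the captured torch.fx.GraphModule
    return gm.forward   # run the captured graph as-is

@torch.compile(backend=recording_backend, fullgraph=True)
def fn(x):
    # vmap is a higher-order op, so ops traced inside its body get a deeper stack
    return torch.func.vmap(lambda t: t.sin() + 1.0)(x)

fn(torch.randn(4, 4))

for mod in graphs[0].modules():
    if not isinstance(mod, torch.fx.GraphModule):
        continue
    for node in mod.graph.nodes:
        if node.op == "call_function":
            names = [name for name, _ in node.meta.get("source_fn_stack", [])]
            print(node.name, names)   # e.g. sin ['vmap_impl', 'sin']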

test/export/test_export.py

Lines changed: 8 additions & 3 deletions
@@ -346,7 +346,7 @@ def forward(self, x):
                 node.name in ep.graph_signature.inputs_to_buffers or
                 node.name in ep.graph_signature.inputs_to_parameters
             ):
-                self.assertTrue("source_fn" in node.meta)
+                self.assertTrue("source_fn_stack" in node.meta)
                 self.assertTrue("nn_module_stack" in node.meta)


@@ -1071,8 +1071,13 @@ def forward(self, x):
         for mod in gm.modules():
             for node in mod.graph.nodes:
                 if node.name in {"sin", "cos"}:
-                    actual_source_fns.append(node.meta.get("source_fn", None))
-        exp_source_fns = [("cos", "cos"), ("sin", "sin")]
+                    source_fn_st = node.meta.get("source_fn_stack", None)
+                    if source_fn_st is not None:
+                        source_names = []
+                        for source_fn in source_fn_st:
+                            source_names.append(source_fn[0])
+                        actual_source_fns.append(source_names)
+        exp_source_fns = [["cond", "cos"], ["cond", "sin"]]
         self.assertEqual(actual_source_fns, exp_source_fns)

     def test_lift_constants(self) -> None:

test/export/test_serialize.py

Lines changed: 3 additions & 3 deletions
@@ -273,10 +273,10 @@ def _check_graph_nodes(gm1, gm2, _check_meta=True):
                 #     node1.meta.get("nn_module_stack", None),
                 #     node2.meta.get("nn_module_stack", None),
                 # )
-                # Check "source_fn" metadata
+                # Check "source_fn_stack" metadata
                 self.assertEqual(
-                    node1.meta.get("source_fn", None),
-                    node2.meta.get("source_fn", None),
+                    node1.meta.get("source_fn_stack", None),
+                    node2.meta.get("source_fn_stack", None),
                 )

         _check_graph_nodes(ep.graph_module, deserialized_ep.graph_module, _check_meta)

test/functorch/test_control_flow.py

Lines changed: 2 additions & 2 deletions
@@ -1450,7 +1450,7 @@ def false_fn(x):
             return x * x.sin()

         def foo(x):
-            return cond(x.shape[0] == 4, true_fn, false_fn, [x])
+            return cond(x.shape[0] == 4, true_fn, false_fn, (x,))
         inp = torch.randn([4, 3])
         gm, _ = torch._dynamo.export(foo)(inp)

@@ -1461,7 +1461,7 @@ def run_with_interpreter(*args):


         checked_ops = {"add", "mul", "sin", "cos"}
-        checked_meta = ["source_fn", "stack_trace"]
+        checked_meta = ["source_fn_stack", "stack_trace"]
         all_source_fns = collect_meta_for_filtered_nodes(gm, checked_ops, checked_meta)
         new_source_fns = collect_meta_for_filtered_nodes(new_gm, checked_ops, checked_meta)
         self.assertEqual(all_source_fns, new_source_fns)

test/test_fx.py

Lines changed: 2 additions & 2 deletions
@@ -1757,13 +1757,13 @@ def forward(self, x):
             if node.op == 'get_attr':
                 node.meta["nn_module_stack"] = "self"
                 node.meta["stack_trace"] = "stack_trace"
-                node.meta["source_fn"] = "source_fn"
+                node.meta["source_fn_stack"] = "source_fn_stack"
         new_gm = Transformer(gm).transform()
         for node in new_gm.graph.nodes:
             if node.op == 'get_attr':
                 self.assertEqual(node.meta["nn_module_stack"], "self")
                 self.assertEqual(node.meta["stack_trace"], "stack_trace")
-                self.assertEqual(node.meta["source_fn"], "source_fn")
+                self.assertEqual(node.meta["source_fn_stack"], "source_fn_stack")


     def test_interpreter(self):

torch/_dynamo/output_graph.py

Lines changed: 27 additions & 8 deletions
@@ -409,11 +409,13 @@ def remove_node(self, *args, **kwargs):
         return self.current_tracer.remove_node(*args, **kwargs)

     @contextlib.contextmanager
-    def new_subtracer(self):
+    def new_subtracer(self, source_target):
         new_scope_ctx = enter_new_scope()
         try:
             new_scope_ctx.__enter__()
-            tracer = SubgraphTracer(self, parent=self.current_tracer)
+            tracer = SubgraphTracer(
+                self, parent=self.current_tracer, source_target=source_target
+            )
             self.tracers.append(tracer)
             yield tracer
         finally:

@@ -1185,7 +1187,9 @@ class SubgraphTracer(fx.Tracer):
     compiling and executing the graph.
     """

-    def __init__(self, output_graph, parent=None, export_root=False):
+    def __init__(
+        self, output_graph, parent=None, export_root=False, source_target=None
+    ):
         super().__init__()
         self.output_graph = weakref.proxy(output_graph)
         self.graph = torch.fx.Graph()

@@ -1220,6 +1224,17 @@ def __init__(self, output_graph, parent=None, export_root=False):
         self.lifted_freevars = collections.OrderedDict()
         self.prev_inst = None

+        # Each SubgraphTracer is associated with a source target, which indicates
+        # which operator this subgraph is attached to. We compute a source_fn_stack
+        # based on the source target. For the root tracer, it's set to [].
+        # This is useful for debugging and transforming the exported graph.
+        if self.parent is None:
+            self.source_fn_stack = []
+        else:
+            self.source_fn_stack = self.parent.source_fn_stack + [
+                (self.graph._target_to_str(source_target), source_target)
+            ]
+
     def create_proxy(
         self,
         kind,

@@ -1302,15 +1317,19 @@ def get_trace_call_log_str():
             rv.node.meta["nn_module_stack"] = nn_module_stack.copy()

         if kind in {"call_function", "call_method"}:
-            rv.node.meta["source_fn"] = (rv.node.name, target)
+            rv.node.meta["source_fn_stack"] = self.source_fn_stack + [
+                (rv.node.name, target)
+            ]
         elif kind == "call_module":
             if self.parent is not None:
                 unimplemented("Invoking an nn.Module inside HigherOrderOperator")
             # For modules we store the class
-            rv.node.meta["source_fn"] = (
-                rv.node.name,
-                rv.node.meta["nn_module_stack"][target][1],
-            )
+            rv.node.meta["source_fn_stack"] = self.source_fn_stack + [
+                (
+                    rv.node.name,
+                    rv.node.meta["nn_module_stack"][target][1],
+                )
+            ]

         frame_summaries: List[traceback.FrameSummary] = []
         while tx:
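The rule implemented above is plain list concatenation: each nested SubgraphTracer extends its parent's source_fn_stack with the higher-order op it is tracing, and create_proxy appends the node's own (name, target) entry. A tiny illustrative sketch follows, using stand-in string targets instead of the real op objects.

# Sketch of how the stack deepens with nesting; matches the expectation
# ["wrap", "wrap", "linear"] in test_wrap_source_fn_stack above.
root_stack = []                                 # root tracer: source_fn_stack == []
outer_wrap = root_stack + [("wrap", "wrap")]    # tracer for the outer wrap body
inner_wrap = outer_wrap + [("wrap", "wrap")]    # tracer for the nested wrap body

# A "linear" call_module node created while tracing the inner body ends up with:
source_fn_stack = inner_wrap + [("linear", "Linear")]
assert [name for name, _ in source_fn_stack] == ["wrap", "wrap", "linear"]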

torch/_dynamo/variables/higher_order_ops.py

Lines changed: 8 additions & 1 deletion
@@ -118,6 +118,7 @@ def speculate_subgraph(
     graph_checkpoint,
     checkpoint,
     description,
+    source_target,
     *,
     always_restore=False,
     enable_grad=False,

@@ -141,7 +142,7 @@ def speculate_subgraph(
     )

     try:
-        with tx.output.new_subtracer() as tracer:
+        with tx.output.new_subtracer(source_target) as tracer:
             args = validate_args_and_maybe_create_graph_inputs(
                 sub_args, tracer, tx, manually_set_subgraph_inputs
             )

@@ -397,6 +398,7 @@ def speculate_branch(branch):
                 graph_checkpoint,
                 checkpoint,
                 "cond",
+                self.value,
             )
         # Reraise because we want to suggest workarounds
         except Unsupported as e:

@@ -546,6 +548,7 @@ def call_function(
             tx.output.graph,
             checkpoint,
             "torch.ops.higher_order.map",
+            self.value,
         )

         body_nn_modules = tx.copy_graphstate().output.nn_modules

@@ -683,6 +686,7 @@ def call_function(
             graph_checkpoint,
             checkpoint,
             "torch.func.grad",
+            self.value,
             # See NOTE [HACK: Enable autograd while tracing function]
             enable_grad=True,
         )

@@ -872,6 +876,7 @@ def call_function(
             graph_checkpoint,
             checkpoint,
             "torch.vmap",
+            self.value,
         )

         body_name = add_subgraph(

@@ -982,6 +987,7 @@ def call_function(
             graph_checkpoint,
             checkpoint,
             "the user-defined autograd.Function",
+            self.value,
             # Backwards should never, ever be stored!
             always_restore=always_restore,
             restore_side_effects=False,

@@ -1039,6 +1045,7 @@ def create_wrapped_node(self, tx, args, kwargs, description):
             graph_checkpoint,
             checkpoint,
             description,
+            self.value,
             manually_set_subgraph_inputs=False,
         )
