
Commit ce69b53

ydwu4 authored and facebook-github-bot committed
Replace node.meta source_fn with source_fn_stack (#108595)
Summary:
X-link: pytorch/executorch#210

A resubmit of #108447. Copying over the description: this is a follow-up to the discussion in #108356, where we want to replace `source_fn` with `source_fn_stack`.

Before this PR, for the following example:

```python
backend = EagerAndRecordGraphs()

@torch.compile(backend=backend, fullgraph=True)
def cond_f(pred, pred2, x, y):
    def true_fn(pred2, x, y):
        return x + y

    def false_fn(pred2, x, y):
        def true_fn2(x, y):
            return x.sin() - y.cos()

        def false_fn2(x, y):
            return x.cos() - y.sin()

        return control_flow.cond(pred2, true_fn2, false_fn2, (x, y))

    return control_flow.cond(pred, true_fn, false_fn, (pred2, x, y))
```

the graph captured is shown below:

```python
class GraphModule(torch.nn.Module):
    def forward(self, L_pred_ : torch.Tensor, L_pred2_ : torch.Tensor, L_x_ : torch.Tensor, L_y_ : torch.Tensor):
        l_pred_ = L_pred_
        l_pred2_ = L_pred2_
        l_x_ = L_x_
        l_y_ = L_y_
        cond_true_1 = self.cond_true_1
        cond_false_1 = self.cond_false_1
        cond = torch.ops.higher_order.cond(l_pred_, cond_true_1, cond_false_1, [l_pred2_, l_x_, l_y_]);  l_pred_ = cond_true_1 = cond_false_1 = l_pred2_ = l_x_ = l_y_ = None
        return (cond,)

class GraphModule(torch.nn.Module):
    def forward(self, l_pred2_, l_x_, l_y_):
        add = l_x_ + l_y_;  l_x_ = l_y_ = None
        return add

class GraphModule(torch.nn.Module):
    def forward(self, l_pred2_, l_x_, l_y_):
        cond_true_0 = self.cond_true_0
        cond_false_0 = self.cond_false_0
        cond = torch.ops.higher_order.cond(l_pred2_, cond_true_0, cond_false_0, [l_x_, l_y_]);  l_pred2_ = cond_true_0 = cond_false_0 = l_x_ = l_y_ = None
        return cond

class GraphModule(torch.nn.Module):
    def forward(self, l_x_, l_y_):
        sin = l_x_.sin();  l_x_ = None
        cos = l_y_.cos();  l_y_ = None
        sub = sin - cos;  sin = cos = None
        return sub

class GraphModule(torch.nn.Module):
    def forward(self, l_x_, l_y_):
        cos = l_x_.cos();  l_x_ = None
        sin = l_y_.sin();  l_y_ = None
        sub = cos - sin;  cos = sin = None
        return sub
```

The `source_fn` for the inner cond, sin, cos, and sub nodes is a single (name, target) tuple:

```
('cond', <torch._ops.HigherOrderOperator object at xxx>)
('sin', 'sin')
('cos', 'cos')
('sub', <built-in function sub>)
```

After this PR, `source_fn_stack` is a list of (name, target) tuples, with the bottom of the stack at the end of the list:

```
[('cond', <torch._ops.HigherOrderOperator object at xxx>), ('cond', <torch._ops.HigherOrderOperator object at xxx>)],
[('cond', <torch._ops.HigherOrderOperator object at xxx>), ('cond', <torch._ops.HigherOrderOperator object at xxx>), ('sin', 'sin')],
[('cond', <torch._ops.HigherOrderOperator object at xxx>), ('cond', <torch._ops.HigherOrderOperator object at xxx>), ('cos', 'cos')],
[('cond', <torch._ops.HigherOrderOperator object at xxx>), ('cond', <torch._ops.HigherOrderOperator object at xxx>), ('sub', <built-in function sub>)]
```

Test Plan: see the added tests in test_higher_order_ops.py and the modified existing tests. Also updated the bin by running: "buck2 run @//mode/dev-nosan fbcode//aibench/api:gen_test_files --config client.id=nuclide"

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ipiszy ngimel yf225 chenyang78 kadeng muchulee8 aakhundov

Reviewed By: angelayi

Differential Revision: D48984986

Pulled By: ydwu4
1 parent 40b83d9 commit ce69b53
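
As a quick orientation for readers of this commit, here is a minimal sketch (not part of the commit itself) of consuming the new metadata; `gm` is assumed to be a GraphModule recorded by a backend such as EagerAndRecordGraphs, whose submodules (e.g. cond branches) are themselves GraphModules:

```python
# Sketch: walk a captured GraphModule and print each node's source_fn_stack.
def print_source_fn_stacks(gm):
    for mod in gm.modules():
        for node in mod.graph.nodes:
            stack = node.meta.get("source_fn_stack", [])
            if stack:
                # Each entry is a (node_name, target) tuple; the innermost
                # source fn sits at the end of the list.
                print(node.name, "->", [name for name, _ in stack])
```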

14 files changed (+248, −46)


test/dynamo/test_aot_autograd.py

Lines changed: 4 additions & 3 deletions
```diff
@@ -798,10 +798,11 @@ def _prepare_model_args():
                 continue
             if min_seq_nr < 0:
                 min_seq_nr = seq_nr
-            mod_name = node.meta.get("source_fn", "")
+            source_fn_stack = node.meta.get("source_fn_stack", [])
             orig_aten = node.meta.get("original_aten", "")
-            if isinstance(mod_name, tuple):
-                mod_name = mod_name[0]
+            mod_name = ""
+            if len(source_fn_stack) > 0:
+                mod_name = source_fn_stack[-1][0]
             # Make all seq_nr relative so it starts at 0
             seq_nr = seq_nr - min_seq_nr
             seq_table = seq_table + f"{seq_nr}|{orig_aten}|{mod_name}\n"
```
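
The pattern above (take the name from the last stack entry, defaulting to an empty string) is the natural migration from the old tuple-valued `source_fn`; as a sketch, downstream code could consolidate it in a hypothetical helper like this:

```python
# Hypothetical helper (not part of this commit): recover the old
# source_fn-style name from the new source_fn_stack metadata.
def innermost_source_fn_name(node) -> str:
    source_fn_stack = node.meta.get("source_fn_stack", [])
    # The innermost (name, target) entry sits at the end of the stack.
    return source_fn_stack[-1][0] if source_fn_stack else ""
```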

test/dynamo/test_export.py

Lines changed: 2 additions & 3 deletions
```diff
@@ -964,7 +964,7 @@ def forward(self, x):
             if node.op not in {"placeholder", "output"}:
                 self.assertTrue(node.stack_trace is not None)
                 self.assertTrue(node.meta["nn_module_stack"] is not None)
-                self.assertTrue(node.meta["source_fn"] is not None)
+                self.assertTrue(node.meta["source_fn_stack"] is not None)
 
         torch._dynamo.reset()

@@ -974,7 +974,7 @@ def forward(self, x):
             if node.op == "call_function":
                 self.assertTrue(node.stack_trace is not None)
                 self.assertTrue(node.meta["nn_module_stack"] is not None)
-                self.assertTrue(node.meta["source_fn"] is not None)
+                self.assertTrue(node.meta["source_fn_stack"] is not None)
                 self.assertTrue(node.meta["val"] is not None)
                 self.assertTrue(node.meta["original_aten"] is not None)

@@ -4014,7 +4014,6 @@ def fn(x):
             self.assertEqual(
                 nd1.meta["nn_module_stack"], nd2.meta["nn_module_stack"]
             )
-            self.assertEqual(nd1.meta["source_fn"], nd2.meta["source_fn"])
             self.assertEqual(nd1.meta["stack_trace"], nd2.meta["stack_trace"])
 
     def test_preserve_fx_node_metadata_recompile(self):
```

test/dynamo/test_higher_order_ops.py

Lines changed: 148 additions & 0 deletions
```diff
@@ -1,5 +1,6 @@
 # Owner(s): ["module: dynamo"]
 import functools
+import pprint
 import re
 import unittest

@@ -1809,6 +1810,153 @@ def fn(x):
 
         self.assertTrue(activations.keys() == forward_handles.keys())
 
+    def _get_source_fn_stack(self, gm, node_names):
+        ret = {}
+        for mod in gm.modules():
+            for node in mod.graph.nodes:
+                if node.name in node_names:
+                    actual_stack = [
+                        name for name, _ in node.meta.get("source_fn_stack", [])
+                    ]
+                    ret[node.name] = actual_stack
+        return ret
+
+    def test_wrap_source_fn_stack(self):
+        class MockModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(4, 4)
+
+            def forward(self, x):
+                return self.linear(x)
+
+        mod = MockModule()
+
+        def gn(x):
+            return torch.cos(x) + wrap(mod, x)
+
+        def fn(x):
+            return wrap(gn, x)
+
+        backend = EagerAndRecordGraphs()
+        inp = torch.randn((4, 4))
+        torch.compile(fn, backend=backend, fullgraph=True)(inp)
+
+        gm = backend.graphs[0]
+        actual_stack = self._get_source_fn_stack(gm, {"cos", "add", "linear"})
+        self.assertExpectedInline(
+            pprint.pformat(actual_stack),
+            """\
+{'add': ['wrap', 'add'],
+ 'cos': ['wrap', 'cos'],
+ 'linear': ['wrap', 'wrap', 'linear']}""",
+        )
+
+    def test_cond_source_fn_stack(self):
+        backend = EagerAndRecordGraphs()
+
+        @torch.compile(backend=backend, fullgraph=True)
+        def cond_f(pred, pred2, x, y):
+            def true_fn(pred2, x, y):
+                return x + y
+
+            def false_fn(pred2, x, y):
+                def true_fn2(x, y):
+                    return x.sin() - y.cos()
+
+                def false_fn2(x, y):
+                    return x.cos() - y.sin()
+
+                return control_flow.cond(pred2, true_fn2, false_fn2, [x, y])
+
+            return control_flow.cond(pred, true_fn, false_fn, [pred2, x, y])
+
+        pred = torch.tensor(True)
+        pred2 = torch.tensor(False)
+        xs = torch.randn(2, 3, 3)
+        y = torch.randn(3, 3)
+        cond_f(pred, pred2, xs, y)
+
+        gm = backend.graphs[0]
+        actual_stack = self._get_source_fn_stack(gm, {"cos", "add", "sin", "sub"})
+        self.assertExpectedInline(
+            pprint.pformat(actual_stack),
+            """\
+{'add': ['cond', 'add'],
+ 'cos': ['cond', 'cond', 'cos'],
+ 'sin': ['cond', 'cond', 'sin'],
+ 'sub': ['cond', 'cond', 'sub']}""",
+        )
+
+    def test_map_source_fn_stack(self):
+        backend = EagerAndRecordGraphs()
+
+        xs = torch.randn(2, 3, 3)
+        y = torch.randn(3)
+
+        @torch.compile(backend=backend, fullgraph=True)
+        def map_f(xs, y):
+            def inner(x, y):
+                def inner2(x, y):
+                    return x + y
+
+                return control_flow.map(inner2, x, y) * y.cos()
+
+            return control_flow.map(inner, xs, y).sin()
+
+        result = map_f(xs, y)
+
+        gm = backend.graphs[0]
+        actual_stack = self._get_source_fn_stack(gm, {"cos", "add", "sin"})
+        self.assertExpectedInline(
+            pprint.pformat(actual_stack),
+            """{'add': ['map', 'map', 'add'], 'cos': ['map', 'cos'], 'sin': ['sin']}""",
+        )
+
+    def test_grad_source_fn_stack(self):
+        backend = EagerAndRecordGraphs()
+
+        def fn(x):
+            return x.sin().sum()
+
+        @torch.compile(backend=backend, fullgraph=False)
+        def wrapper_fn(x):
+            return torch.func.grad(torch.func.grad(fn))(x)
+
+        x = torch.randn(())
+
+        wrapper_fn(x)
+        gm = backend.graphs[0]
+        actual_stack = self._get_source_fn_stack(gm, {"sum_1", "sin"})
+        self.assertExpectedInline(
+            pprint.pformat(actual_stack),
+            """\
+{'sin': ['grad_impl', 'grad_impl', 'sin'],
+ 'sum_1': ['grad_impl', 'grad_impl', 'sum_1']}""",
+        )
+
+    def test_vmap_source_fn_stack(self):
+        backend = EagerAndRecordGraphs()
+
+        def inner_fn(x):
+            return torch.func.vmap(lambda x: x.sum(0) + x.sum(1))(x)
+
+        @torch.compile(backend=backend, fullgraph=True)
+        def fn(x):
+            return torch.func.vmap(lambda x: inner_fn(x.cos()))(x)
+
+        x = torch.randn(3, 3, 3, 3)
+        fn(x)
+        gm = backend.graphs[0]
+        actual_stack = self._get_source_fn_stack(gm, {"sum_1", "sum_2", "add"})
+        self.assertExpectedInline(
+            pprint.pformat(actual_stack),
+            """\
+{'add': ['vmap_impl', 'vmap_impl', 'add'],
+ 'sum_1': ['vmap_impl', 'vmap_impl', 'sum_1'],
+ 'sum_2': ['vmap_impl', 'vmap_impl', 'sum_2']}""",
+        )
 
 class FuncTorchHigherOrderOpTests(torch._dynamo.test_case.TestCase):
     def run(self, result=None):
```

test/export/test_export.py

Lines changed: 8 additions & 3 deletions
```diff
@@ -346,7 +346,7 @@ def forward(self, x):
                 node.name in ep.graph_signature.inputs_to_buffers or
                 node.name in ep.graph_signature.inputs_to_parameters
             ):
-                self.assertTrue("source_fn" in node.meta)
+                self.assertTrue("source_fn_stack" in node.meta)
                 self.assertTrue("nn_module_stack" in node.meta)
 
     def test_export_api_with_dynamic_shapes(self):

@@ -1339,8 +1339,13 @@ def forward(self, x):
         for mod in gm.modules():
             for node in mod.graph.nodes:
                 if node.name in {"sin", "cos"}:
-                    actual_source_fns.append(node.meta.get("source_fn", None))
-        exp_source_fns = [("cos", "cos"), ("sin", "sin")]
+                    source_fn_st = node.meta.get("source_fn_stack", None)
+                    if source_fn_st is not None:
+                        source_names = []
+                        for source_fn in source_fn_st:
+                            source_names.append(source_fn[0])
+                        actual_source_fns.append(source_names)
+        exp_source_fns = [["cond", "cos"], ["cond", "sin"]]
         self.assertEqual(actual_source_fns, exp_source_fns)
 
     def test_lift_constants(self) -> None:
```
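
As an editorial aside, the name-collection loop in the second hunk is equivalent to a comprehension over the stack's (name, target) entries; a self-contained sketch:

```python
# Sketch (not part of the diff): collect names from a source_fn_stack-style
# list of (name, target) tuples with a comprehension.
source_fn_st = [("cond", "cond"), ("cos", "cos")]  # example stack
source_names = [name for name, _ in source_fn_st]
assert source_names == ["cond", "cos"]
```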

test/export/test_serialize.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -273,10 +273,10 @@ def _check_graph_nodes(gm1, gm2, _check_meta=True):
             #     node1.meta.get("nn_module_stack", None),
             #     node2.meta.get("nn_module_stack", None),
             # )
-            # Check "source_fn" metadata
+            # Check "source_fn_stack" metadata
             self.assertEqual(
-                node1.meta.get("source_fn", None),
-                node2.meta.get("source_fn", None),
+                node1.meta.get("source_fn_stack", None),
+                node2.meta.get("source_fn_stack", None),
             )
 
         _check_graph_nodes(ep.graph_module, deserialized_ep.graph_module, _check_meta)
```

test/functorch/test_control_flow.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -1501,7 +1501,7 @@ def false_fn(x):
             return x * x.sin()
 
         def foo(x):
-            return cond(x.shape[0] == 4, true_fn, false_fn, [x])
+            return cond(x.shape[0] == 4, true_fn, false_fn, (x,))
         inp = torch.randn([4, 3])
         gm, _ = torch._dynamo.export(foo)(inp)

@@ -1512,7 +1512,7 @@ def run_with_interpreter(*args):
 
 
         checked_ops = {"add", "mul", "sin", "cos"}
-        checked_meta = ["source_fn", "stack_trace"]
+        checked_meta = ["source_fn_stack", "stack_trace"]
         all_source_fns = collect_meta_for_filtered_nodes(gm, checked_ops, checked_meta)
         new_source_fns = collect_meta_for_filtered_nodes(new_gm, checked_ops, checked_meta)
         self.assertEqual(all_source_fns, new_source_fns)
```

test/test_fx.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -1777,13 +1777,13 @@ def forward(self, x):
             if node.op == 'get_attr':
                 node.meta["nn_module_stack"] = "self"
                 node.meta["stack_trace"] = "stack_trace"
-                node.meta["source_fn"] = "source_fn"
+                node.meta["source_fn_stack"] = "source_fn_stack"
         new_gm = Transformer(gm).transform()
         for node in new_gm.graph.nodes:
             if node.op == 'get_attr':
                 self.assertEqual(node.meta["nn_module_stack"], "self")
                 self.assertEqual(node.meta["stack_trace"], "stack_trace")
-                self.assertEqual(node.meta["source_fn"], "source_fn")
+                self.assertEqual(node.meta["source_fn_stack"], "source_fn_stack")
 
 
     def test_interpreter(self):
```

torch/_dynamo/output_graph.py

Lines changed: 47 additions & 11 deletions
```diff
@@ -421,11 +421,13 @@ def remove_node(self, *args, **kwargs):
         return self.current_tracer.remove_node(*args, **kwargs)
 
     @contextlib.contextmanager
-    def new_subtracer(self):
+    def new_subtracer(self, source_target):
         new_scope_ctx = enter_new_scope()
         try:
             new_scope_ctx.__enter__()
-            tracer = SubgraphTracer(self, parent=self.current_tracer)
+            tracer = SubgraphTracer(
+                self, parent=self.current_tracer, source_target=source_target
+            )
             self.tracers.append(tracer)
             yield tracer
         finally:

@@ -1171,7 +1173,9 @@ class SubgraphTracer(fx.Tracer):
     compiling and executing the graph.
     """
 
-    def __init__(self, output_graph, parent=None, export_root=False):
+    def __init__(
+        self, output_graph, parent=None, export_root=False, source_target=None
+    ):
         super().__init__()
         self.output_graph = weakref.proxy(output_graph)
         self.graph = torch.fx.Graph()

@@ -1210,6 +1214,16 @@ def __init__(self, output_graph, parent=None, export_root=False):
         self._orig_gm_meta = None
         self._orig_gm_lineno_map = None
         self._orig_gm_firstlineno = None
+        # Each SubgraphTracer is associated with a source target, which indicates
+        # which operator this subgraph is attached to. We compute a source_fn_stack
+        # based on the source target. For the root tracer, it's set to [].
+        # This is useful for debugging and transforming the exported graph.
+        if self.parent is None:
+            self.source_fn_stack = []
+        else:
+            self.source_fn_stack = self.parent.source_fn_stack + [
+                (self.graph._target_to_str(source_target), source_target)
+            ]
 
     def create_proxy(
         self,

@@ -1305,6 +1319,24 @@ def get_trace_call_log_str():
             self._orig_gm_meta = None
             self._orig_gm_lineno_map = None
             self._orig_gm_firstlineno = None
+            nn_module_stack = tx.nn_module_stack
+            if nn_module_stack:
+                rv.node.meta["nn_module_stack"] = nn_module_stack.copy()
+
+            if kind in {"call_function", "call_method"}:
+                rv.node.meta["source_fn_stack"] = self.source_fn_stack + [
+                    (rv.node.name, target)
+                ]
+            elif kind == "call_module":
+                if self.parent is not None:
+                    unimplemented("Invoking an nn.Module inside HigherOrderOperator")
+                # For modules we store the class
+                rv.node.meta["source_fn_stack"] = self.source_fn_stack + [
+                    (
+                        rv.node.name,
+                        rv.node.meta["nn_module_stack"][target][1],
+                    )
+                ]
 
         # preserve original meta if it is available
         if (

@@ -1322,26 +1354,30 @@ def get_trace_call_log_str():
                 meta = self._orig_gm_meta[node_idx]
                 if "stack_trace" in meta:
                     rv.node.meta["stack_trace"] = meta["stack_trace"]
-                if "nn_module_stack" in meta and "source_fn" in meta:
+                if "nn_module_stack" in meta and "source_fn_stack" in meta:
                     rv.node.meta["nn_module_stack"] = meta["nn_module_stack"]
-                    rv.node.meta["source_fn"] = meta["source_fn"]
+                    rv.node.meta["source_fn_stack"] = meta["source_fn_stack"]
 
         if "nn_module_stack" not in rv.node.meta:
             nn_module_stack = tx.nn_module_stack
             if nn_module_stack:
                 rv.node.meta["nn_module_stack"] = nn_module_stack.copy()
 
-        if "source_fn" not in rv.node.meta:
+        if "source_fn_stack" not in rv.node.meta:
             if kind in {"call_function", "call_method"}:
-                rv.node.meta["source_fn"] = (rv.node.name, target)
+                rv.node.meta["source_fn_stack"] = self.source_fn_stack + [
+                    (rv.node.name, target)
+                ]
             elif kind == "call_module":
                 if self.parent is not None:
                     unimplemented("Invoking an nn.Module inside HigherOrderOperator")
                 # For modules we store the class
-                rv.node.meta["source_fn"] = (
-                    rv.node.name,
-                    rv.node.meta["nn_module_stack"][target][1],
-                )
+                rv.node.meta["source_fn_stack"] = self.source_fn_stack + [
+                    (
+                        rv.node.name,
+                        rv.node.meta["nn_module_stack"][target][1],
+                    )
+                ]
 
         if "stack_trace" not in rv.node.meta:
             frame_summaries: List[traceback.FrameSummary] = []
```
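
To see how the composition above behaves, here is a minimal standalone model (not the actual Dynamo classes; string targets stand in for real operator objects) of how nested subtracers accumulate a source_fn_stack:

```python
# Toy model of SubgraphTracer's stack building: each child tracer extends
# its parent's stack, and each node appends its own (name, target) entry.
class ToyTracer:
    def __init__(self, parent=None, source_target=None):
        if parent is None:
            self.source_fn_stack = []  # root tracer starts empty
        else:
            self.source_fn_stack = parent.source_fn_stack + [
                (str(source_target), source_target)
            ]

    def node_meta(self, node_name, target):
        # A node's stack is the tracer's stack plus the node itself.
        return {"source_fn_stack": self.source_fn_stack + [(node_name, target)]}

root = ToyTracer()
outer = ToyTracer(parent=root, source_target="cond")
inner = ToyTracer(parent=outer, source_target="cond")
print(inner.node_meta("sin", "sin")["source_fn_stack"])
# [('cond', 'cond'), ('cond', 'cond'), ('sin', 'sin')]
```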
