Skip to content

Commit 645594e

Browse files
ydwu4facebook-github-bot
authored andcommitted
Replace node.meta source_fn with source_fn_stack (#210)
Summary: Pull Request resolved: #210 A resubmit of pytorch/pytorch#108447. Copy over the descriptions: This is a follow-up of the discussion in pytorch/pytorch#108356, where we want to repalce source_fn with source_fn_stack Before this PR, for the following example: ```python backend = EagerAndRecordGraphs() torch.compile(backend=backend, fullgraph=True) def cond_f(pred, pred2, x, y): def true_fn(pred2, x, y): return x + y def false_fn(pred2, x, y): def true_fn2(x, y): return x.sin() - y.cos() def false_fn2(x, y): return x.cos() - y.sin() return control_flow.cond(pred2, true_fn2, false_fn2, (x, y)) return control_flow.cond(pred, true_fn, false_fn, (pred2, x, y)) ``` The graph captured is shown below: ```python class GraphModule(torch.nn.Module): def forward(self, L_pred_ : torch.Tensor, L_pred2_ : torch.Tensor, L_x_ : torch.Tensor, L_y_ : torch.Tensor): l_pred_ = L_pred_ l_pred2_ = L_pred2_ l_x_ = L_x_ l_y_ = L_y_ cond_true_1 = self.cond_true_1 cond_false_1 = self.cond_false_1 cond = torch.ops.higher_order.cond(l_pred_, cond_true_1, cond_false_1, [l_pred2_, l_x_, l_y_]); l_pred_ = cond_true_1 = cond_false_1 = l_pred2_ = l_x_ = l_y_ = None return (cond,) class GraphModule(torch.nn.Module): def forward(self, l_pred2_, l_x_, l_y_): add = l_x_ + l_y_; l_x_ = l_y_ = None return add class GraphModule(torch.nn.Module): def forward(self, l_pred2_, l_x_, l_y_): cond_true_0 = self.cond_true_0 cond_false_0 = self.cond_false_0 cond = torch.ops.higher_order.cond(l_pred2_, cond_true_0, cond_false_0, [l_x_, l_y_]); l_pred2_ = cond_true_0 = cond_false_0 = l_x_ = l_y_ = None return cond class GraphModule(torch.nn.Module): def forward(self, l_x_, l_y_): sin = l_x_.sin(); l_x_ = None cos = l_y_.cos(); l_y_ = None sub = sin - cos; sin = cos = None return sub class GraphModule(torch.nn.Module): def forward(self, l_x_, l_y_): cos = l_x_.cos(); l_x_ = None sin = l_y_.sin(); l_y_ = None sub = cos - sin; cos = sin = None return sub ``` the source_fn for inner cond, sin, cos will be a (name, target) tuple: ``` ('cond', <torch._ops.HigherOrderOperator object at xxx>) ('sin', 'sin') ('cos', 'cos') ('sub'. <built-in function sub>) ``` After this pr, the source_fn_stack will be a list of (name, target) tuple. The bottom of stack is the end of the list. ``` [('cond', <torch._ops.HigherOrderOperator object at xxx>), ('cond', <torch._ops.HigherOrderOperator object at xxx>)], [('cond', <torch._ops.HigherOrderOperator object at xxx>), ('cond', <torch._ops.HigherOrderOperator object at xxx>), ('sin', 'sin')], [('cond', <torch._ops.HigherOrderOperator object at xxx>), ('cond', <torch._ops.HigherOrderOperator object at xxx>), ('cos', 'cos')] [('cond', <torch._ops.HigherOrderOperator object at xxx>), ('cond', <torch._ops.HigherOrderOperator object at xxx>), ('sub', <built-in function sub>)] ``` X-link: pytorch/pytorch#108595 Test Plan: See added tests in test_higher_order_ops.py and modify existing test. Also updated bin by running: "buck2 run @//mode/dev-nosan fbcode//aibench/api:gen_test_files --config client.id=nuclide" cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ipiszy ngimel yf225 chenyang78 kadeng muchulee8 aakhundov Reviewed By: angelayi Differential Revision: D48984986 Pulled By: ydwu4 fbshipit-source-id: 12beb1cee5c6a5b7dc976eda54baaa8f564a0f1e
1 parent 61077c4 commit 645594e

File tree

6 files changed

+15
-11
lines changed

6 files changed

+15
-11
lines changed

backends/xnnpack/partition/xnnpack_partitioner.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,8 +94,8 @@ def __init__(
9494
assert len(self.constraints)
9595

9696
def check_common_constraints(self, node) -> bool:
97-
if self.unsupported_modules and "source_fn" in node.meta:
98-
return not node.meta["source_fn"][1] in self.unsupported_modules
97+
if self.unsupported_modules and "source_fn_stack" in node.meta:
98+
return not node.meta["source_fn_stack"][-1][1] in self.unsupported_modules
9999

100100
return True
101101

backends/xnnpack/passes/tag_implicit_q_dq_pass.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,9 @@ def is_supported_quant_op(self, node: torch.fx.Node) -> bool:
8989

9090
def is_supported_quant_module(self, node: torch.fx.Node) -> bool:
9191
is_supported = (
92-
"source_fn" in node.meta
93-
and node.meta["source_fn"][1] in SUPPORTED_IMPLICIT_Q_DQ_MODULES_SET
92+
"source_fn_stack" in node.meta
93+
and node.meta["source_fn_stack"][-1][1]
94+
in SUPPORTED_IMPLICIT_Q_DQ_MODULES_SET
9495
)
9596
if is_supported and self.is_supported_quant_op(node):
9697
raise RuntimeError(

docs/website/docs/ir_spec/00_exir.md

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -247,12 +247,14 @@ the following metadata fields:
247247
{'self_linear': ('self.linear', <class 'torch.nn.Linear'>), 'self_sequential': ('self.sequential', <class 'torch.nn.Sequential'>)}
248248
```
249249

250-
* `node.meta["source_fn"]` contains the torch function or the leaf
250+
* `node.meta["source_fn_stack"]` contains the stack of torch function or the leaf
251251
`torch.nn.Module` class this node was called from before decomposition. For
252252
example, a node containing the `addmm` op from a `torch.nn.Linear` module call
253253
would contain `torch.nn.Linear` in their `source_fn`, and a node containing
254254
the `addmm` op from a `torch.nn.functional.Linear` module call would contain
255-
`torch.nn.functional.Linear` in their `source_fn`.
255+
`torch.nn.functional.Linear` in their `source_fn`. The stack records the higher order
256+
operator stack that this source_fn belongs to. For example, if a `torch.nn.Linear` module
257+
call is within the true branch of `cond`, then the stack will contain `['cond', 'torch.nn.Linear']`.
256258

257259

258260
### placeholder

exir/backend/test/test_backends.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -819,7 +819,7 @@ def forward(self, a, x, b):
819819
self.assertTrue(
820820
user.op == "call_function" and user.target == operator.getitem
821821
)
822-
self.assertTrue(user.meta.get("source_fn", None) is None)
822+
self.assertTrue(user.meta.get("source_fn_stack", None) is None)
823823
self.assertTrue(user.meta.get("nn_module_stack", None) is None)
824824

825825
executorch_prog = executorch_prog.to_executorch(
@@ -886,7 +886,7 @@ def forward(self, x, y):
886886
self.assertTrue(
887887
user.op == "call_function" and user.target == operator.getitem
888888
)
889-
self.assertTrue(user.meta.get("source_fn", None) is None)
889+
self.assertTrue(user.meta.get("source_fn_stack", None) is None)
890890
self.assertTrue(user.meta.get("nn_module_stack", None) is None)
891891

892892
executorch_prog = executorch_prog.to_executorch(

exir/lowered_backend_module.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -489,7 +489,7 @@ def create_submodule_from_nodes(
489489
# (ex. source_fn, # nn_module_stack)
490490
for user_node in submodule_node.users:
491491
user_node.meta.pop("nn_module_stack", None)
492-
user_node.meta.pop("source_fn", None)
492+
user_node.meta.pop("source_fn_stack", None)
493493

494494
erase_nodes(gm, sorted_nodes)
495495

sdk/edir/et_schema.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
class RESERVED_METADATA_ARG(Enum):
3535
DEBUG_HANDLE = "debug_handle"
3636
MODULE_STACK = "nn_module_stack"
37-
SOURCE_FN = "source_fn"
37+
SOURCE_FN_STACK = "source_fn_stack"
3838
MODULE_TYPE = "module_type"
3939
PROFILE_START_TIME = "profile_start_time"
4040
PROFILE_END_TIME = "profile_end_time"
@@ -463,9 +463,10 @@ def _update_module_mapping(
463463
metadata: Dict[str, Any],
464464
):
465465
if (
466-
source_fn := metadata.get("source_fn")
466+
source_fn_stack := metadata.get("source_fn_stack")
467467
) is not None and "nn_module_stack" in metadata:
468468
# (module name, module type)
469+
source_fn = source_fn_stack[-1]
469470
module_type = (
470471
source_fn[1] if isinstance(source_fn[1], str) else source_fn[1].__name__
471472
)

0 commit comments

Comments
 (0)