Skip to content

Commit b014f1f

Browse files
ydwu4facebook-github-bot
authored andcommitted
Replace node.meta source_fn with source_fn_stack (#210)
Summary: A resubmit of pytorch/pytorch#108447. Copy over the descriptions: This is a follow-up of the discussion in pytorch/pytorch#108356, where we want to repalce source_fn with source_fn_stack X-link: pytorch/pytorch#108595 Test Plan: See added tests in test_higher_order_ops.py and modify existing test. Differential Revision: D48984986 Pulled By: ydwu4
1 parent d2e6750 commit b014f1f

File tree

4 files changed

+12
-8
lines changed

4 files changed

+12
-8
lines changed

backends/xnnpack/partition/xnnpack_partitioner.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,8 @@ def __init__(
8787
assert len(self.constraints)
8888

8989
def check_common_constraints(self, node) -> bool:
90-
if self.unsupported_modules and "source_fn" in node.meta:
91-
return not node.meta["source_fn"][1] in self.unsupported_modules
90+
if self.unsupported_modules and "source_fn_stack" in node.meta:
91+
return not node.meta["source_fn_stack"][-1][1] in self.unsupported_modules
9292

9393
return True
9494

backends/xnnpack/passes/tag_implicit_q_dq_pass.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,8 +90,9 @@ def is_supported_quant_op(self, node: torch.fx.Node) -> bool:
9090

9191
def is_supported_quant_module(self, node: torch.fx.Node) -> bool:
9292
is_supported = (
93-
"source_fn" in node.meta
94-
and node.meta["source_fn"][1] in SUPPORTED_IMPLICIT_Q_DQ_MODULES_SET
93+
"source_fn_stack" in node.meta
94+
and node.meta["source_fn_stack"][-1][1]
95+
in SUPPORTED_IMPLICIT_Q_DQ_MODULES_SET
9596
)
9697
if is_supported and self.is_supported_quant_op(node):
9798
raise RuntimeError(

docs/website/docs/ir_spec/00_exir.md

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -247,12 +247,14 @@ the following metadata fields:
247247
{'self_linear': ('self.linear', <class 'torch.nn.Linear'>), 'self_sequential': ('self.sequential', <class 'torch.nn.Sequential'>)}
248248
```
249249

250-
* `node.meta["source_fn"]` contains the torch function or the leaf
250+
* `node.meta["source_fn_stack"]` contains the stack of torch function or the leaf
251251
`torch.nn.Module` class this node was called from before decomposition. For
252252
example, a node containing the `addmm` op from a `torch.nn.Linear` module call
253253
would contain `torch.nn.Linear` in their `source_fn`, and a node containing
254254
the `addmm` op from a `torch.nn.functional.Linear` module call would contain
255-
`torch.nn.functional.Linear` in their `source_fn`.
255+
`torch.nn.functional.Linear` in their `source_fn`. The stack records the higher order
256+
operator stack that this source_fn belongs to. For example, if a `torch.nn.Linear` module
257+
call is within the true branch of `cond`, then the stack will contain `['cond', 'torch.nn.Linear']`.
256258

257259

258260
### placeholder

sdk/edir/et_schema.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
class RESERVED_METADATA_ARG(Enum):
3535
DEBUG_HANDLE = "debug_handle"
3636
MODULE_STACK = "module_stack"
37-
SOURCE_FN = "source_fn"
37+
SOURCE_FN_STACK = "source_fn_stack"
3838
MODULE_TYPE = "module_type"
3939
PROFILE_START_TIME = "profile_start_time"
4040
PROFILE_END_TIME = "profile_end_time"
@@ -463,9 +463,10 @@ def _update_module_mapping(
463463
metadata: Dict[str, Any],
464464
):
465465
if (
466-
source_fn := metadata.get("source_fn")
466+
source_fn_stack := metadata.get("source_fn_stack")
467467
) is not None and "nn_module_stack" in metadata:
468468
# (module name, module type)
469+
source_fn = source_fn_stack[-1]
469470
module_type = (
470471
source_fn[1] if isinstance(source_fn[1], str) else source_fn[1].__name__
471472
)

0 commit comments

Comments
 (0)