Skip to content

Commit 7154176

Browse files
authored
Add basic ambiguity check in the tests (#9371)
### Summary This is also an example of how to use the new dim_order APIs to detect ambiguity when coming from PyTorch Tensor to ExecuTorch. ### Test plan ``` python -m unittest exir.tests.test_memory_format_ops_pass.TestMemoryFormatOpsPass ```
1 parent e0235f0 commit 7154176

File tree

2 files changed

+91
-1
lines changed

2 files changed

+91
-1
lines changed

exir/tests/test_memory_format_ops_pass.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from executorch.exir.pass_base import ExportPass, ProxyValue
2525

2626
from executorch.exir.tests.test_memory_format_ops_pass_utils import (
27+
AmbiguousDimOrderError,
2728
MemoryFormatOpsPassTestUtils,
2829
MemoryFormatTestSet,
2930
PropagateToCopyChannalsLastModule,
@@ -124,8 +125,34 @@ def test_op_dim_order_propagation(self) -> None:
124125
target_memory_format=torch.channels_last,
125126
_load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
126127
),
128+
check_unambiguous_dim_order=True,
127129
)
128130

131+
def test_op_dim_order_propagation_ambiguous(self) -> None:
132+
try:
133+
MemoryFormatOpsPassTestUtils.memory_format_test_runner(
134+
self,
135+
MemoryFormatTestSet(
136+
module=PropagateToCopyChannalsLastModule().eval(),
137+
op=torch.ops.aten._to_copy.default,
138+
sample_input=(
139+
torch.rand_like(
140+
torch.zeros(
141+
[2, 1, 2, 2]
142+
), # Ambiguous shape should trigger AmbiguousDimOrderError!
143+
dtype=torch.float32,
144+
memory_format=torch.contiguous_format,
145+
),
146+
),
147+
target_memory_format=torch.channels_last,
148+
_load_for_executorch_from_buffer=_load_for_executorch_from_buffer,
149+
),
150+
check_unambiguous_dim_order=True,
151+
)
152+
AssertionError("Should have raised AmbiguousDimOrderError")
153+
except AmbiguousDimOrderError:
154+
pass # Expected error
155+
129156
# Only test dim order replacement result in lean mode test.
130157
# This test is irrelevant with operator mode.
131158
def test_dim_order_replacement(self) -> None:

exir/tests/test_memory_format_ops_pass_utils.py

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,11 @@
2020
is_channel_last_dim_order,
2121
is_contiguous_dim_order,
2222
)
23+
from executorch.exir.pass_base import ExportPass
2324

2425
from torch.export import export
26+
27+
from torch.fx.passes.infra.pass_manager import PassManager
2528
from torch.testing import FileCheck
2629
from torch.utils._pytree import tree_flatten
2730

@@ -99,10 +102,67 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
99102
return t1 * t2
100103

101104

105+
class AmbiguousDimOrderError(RuntimeError):
106+
pass
107+
108+
109+
def assert_unambiguous_dim_order(gm):
110+
class ExampleNOPPass(ExportPass):
111+
"""
112+
Does nothing!
113+
"""
114+
115+
def call_operator(self, op, args, kwargs, meta):
116+
return super().call_operator(
117+
op,
118+
args,
119+
kwargs,
120+
meta,
121+
)
122+
123+
# This is an example of how one can detect ambiguous dim_order anywhere in the graph.
124+
# You can be surgical and only detect it in the nodes you are interested in or something else.
125+
def detect_ambiguity(gm):
126+
"""
127+
Check every node's output tensor dim_order and raise if it is ambiguous for a list of formats.
128+
"""
129+
130+
def get_tensors(node: torch.fx.Node) -> torch.Tensor:
131+
val = node.meta["val"]
132+
if isinstance(val, torch.Tensor):
133+
return [val]
134+
elif isinstance(val, (list, tuple)):
135+
return [tensor for tensor in val if isinstance(tensor, torch.Tensor)]
136+
return []
137+
138+
for node in gm.graph.nodes:
139+
if node.op == "call_function":
140+
for tensor in get_tensors(node):
141+
# Let's make sure dim_order is not ambiguous, raise otherwise.
142+
# This is raising because we can't do anything about it.
143+
# The right course of follow up action is to ask user to try with a different example input.
144+
try:
145+
_ = tensor.dim_order(
146+
ambiguity_check=[
147+
torch.contiguous_format,
148+
torch.channels_last,
149+
]
150+
)
151+
except Exception:
152+
raise AmbiguousDimOrderError
153+
154+
# any pass or passes, just using MemoryFormatOpsPass as an example
155+
dim_order_pass_manager = PassManager(passes=[ExampleNOPPass()])
156+
dim_order_pass_manager.add_checks(detect_ambiguity)
157+
dim_order_pass_manager(gm)
158+
159+
102160
class MemoryFormatOpsPassTestUtils:
103161
@staticmethod
104162
def memory_format_test_runner(
105-
test_class: unittest.TestCase, test_set: MemoryFormatTestSet
163+
test_class: unittest.TestCase,
164+
test_set: MemoryFormatTestSet,
165+
check_unambiguous_dim_order: bool = False,
106166
):
107167
before = export(
108168
test_set.module, test_set.sample_input, strict=True
@@ -121,6 +181,9 @@ def memory_format_test_runner(
121181
before, compile_config=EdgeCompileConfig(_skip_dim_order=False)
122182
)
123183

184+
if check_unambiguous_dim_order:
185+
assert_unambiguous_dim_order(epm.exported_program().graph_module)
186+
124187
# check memory format ops, if needed
125188
if test_set.op_level_check:
126189
aten_op_str, edge_op_str = MemoryFormatOps2Str[test_set.op]

0 commit comments

Comments
 (0)