
Commit 1be038d

suchir1 authored and facebook-github-bot committed
Add unit tests for old lowering flow for op_cat.py (pytorch#6847)
Summary:
The team moved to a new API that improves the reliability of our lowering infra. Lowering here refers to converting a PyTorch model into a format that's recognizable by the underlying hardware. This diff makes sure there are still unit tests for the older APIs.

Reviewed By: mcr229

Differential Revision: D65914291
1 parent: ecdc007 · commit: 1be038d
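For context, here is a minimal sketch contrasting the two lowering flows as driven through the XNNPACK Tester harness this file uses. The toy CatModule and the import path are illustrative assumptions; the stage methods (export, to_edge, partition, to_edge_transform_and_lower, to_executorch) are exactly the ones exercised in the diff below.

import torch

# Assumed import path, mirroring other XNNPACK op tests; adjust if needed.
from executorch.backends.xnnpack.test.tester import Tester


class CatModule(torch.nn.Module):
    # Hypothetical toy module; the real tests use cat modules defined
    # inside the test class.
    def forward(self, a, b):
        return torch.cat([a, b])


inputs = (torch.randn(2, 3), torch.randn(2, 3))

# Legacy flow: convert to the edge dialect, then partition/delegate
# to the backend as a separate step.
(
    Tester(CatModule(), inputs)
    .export()
    .to_edge()
    .partition()
    .to_executorch()
)

# Newer flow: one combined call performs the edge conversion,
# transformations, and lowering together.
(
    Tester(CatModule(), inputs)
    .export()
    .to_edge_transform_and_lower()
    .to_executorch()
)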


backends/xnnpack/test/ops/cat.py

Lines changed: 57 additions & 31 deletions
@@ -36,39 +36,45 @@ def forward(self, arg1, arg2, arg3, arg4, arg5):
         return x + x # Quantize by propagation.

     def _test_cat(self, module, inputs, cat_num=1, quant=False, quant_ops=2):
-        tester = Tester(module, inputs)
-
-        if quant:
-            tester.quantize()
-
-        tester.export().check_count({"torch.ops.aten.cat": 1})
-        tester.dump_artifact()
-
-        if quant:
-            # Expect multiple quantize ops - one per input, cat, and add.
-            tester.check_node_count(
-                {
-                    # Q/DQ pair for each input and quantized op. For most tests, there are
-                    # two quantized ops - cat and add.
-                    torch.ops.quantized_decomposed.quantize_per_tensor.default: (
-                        cat_num + quant_ops
-                    )
-                }
+        for legacy_mode in (True, False):
+            tester = Tester(module, inputs)
+
+            if quant:
+                tester.quantize()
+
+            tester.export().check_count({"torch.ops.aten.cat": 1})
+            tester.dump_artifact()
+
+            if quant:
+                # Expect multiple quantize ops - one per input, cat, and add.
+                tester.check_node_count(
+                    {
+                        # Q/DQ pair for each input and quantized op. For most tests, there are
+                        # two quantized ops - cat and add.
+                        torch.ops.quantized_decomposed.quantize_per_tensor.default: (
+                            cat_num + quant_ops
+                        )
+                    }
+                )
+
+
+            if legacy_mode:
+                tester.to_edge()
+                tester.partition()
+            else:
+                tester.to_edge_transform_and_lower()
+
+            if quant:
+                tester.check_not(["torch.ops.quantized_decomposed"])
+
+            (
+                tester.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
+                .check_not(["executorch_exir_dialects_edge__ops_aten_cat"])
+                .to_executorch()
+                .serialize()
+                .run_method_and_compare_outputs()
             )

-        tester.to_edge_transform_and_lower()
-
-        if quant:
-            tester.check_not(["torch.ops.quantized_decomposed"])
-
-        (
-            tester.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
-            .check_not(["executorch_exir_dialects_edge__ops_aten_cat"])
-            .to_executorch()
-            .serialize()
-            .run_method_and_compare_outputs()
-        )
-
     def test_fp16_cat2(self):
         """
         Using Clamp2 because fp16 add is done in fp32 ATM. Need to fix that first.
@@ -155,6 +161,26 @@ def test_fp32_cat_unsupported(self):
             .check_count({"executorch_exir_dialects_edge__ops_aten_cat": 1})
         )

+    def test_fp32_cat_unsupported_legacy_mode(self):
+        """
+        XNNPACK only supports concatenating up to 4 values, so it should not delegate here.
+        """
+        inputs = (
+            torch.randn(1, 2, 3),
+            torch.randn(3, 2, 3),
+            torch.randn(2, 2, 3),
+            torch.randn(5, 2, 3),
+            torch.randn(1, 2, 3),
+        )
+        (
+            Tester(self.Cat5(), inputs)
+            .export()
+            .check_count({"torch.ops.aten.cat": 1})
+            .to_edge()
+            .partition()
+            .check_count({"executorch_exir_dialects_edge__ops_aten_cat": 1})
+        )
+
     class CatNegativeDim(torch.nn.Module):
         def __init__(self):
             super().__init__()
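To check these tests locally, a plain pytest run over the file should work, assuming a development install of ExecuTorch (the exact runner used internally may differ):

python -m pytest backends/xnnpack/test/ops/cat.py -v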
