Skip to content

Add unit tests for old lowering flow for op_cat.py #6847

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 14, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 57 additions & 31 deletions backends/xnnpack/test/ops/cat.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,39 +36,45 @@ def forward(self, arg1, arg2, arg3, arg4, arg5):
return x + x # Quantize by propagation.

def _test_cat(self, module, inputs, cat_num=1, quant=False, quant_ops=2):
tester = Tester(module, inputs)

if quant:
tester.quantize()

tester.export().check_count({"torch.ops.aten.cat": 1})
tester.dump_artifact()

if quant:
# Expect multiple quantize ops - one per input, cat, and add.
tester.check_node_count(
{
# Q/DQ pair for each input and quantized op. For most tests, there are
# two quantized ops - cat and add.
torch.ops.quantized_decomposed.quantize_per_tensor.default: (
cat_num + quant_ops
)
}
for legacy_mode in (True, False):
tester = Tester(module, inputs)

if quant:
tester.quantize()

tester.export().check_count({"torch.ops.aten.cat": 1})
tester.dump_artifact()

if quant:
# Expect multiple quantize ops - one per input, cat, and add.
tester.check_node_count(
{
# Q/DQ pair for each input and quantized op. For most tests, there are
# two quantized ops - cat and add.
torch.ops.quantized_decomposed.quantize_per_tensor.default: (
cat_num + quant_ops
)
}
)


if legacy_mode:
tester.to_edge()
tester.partition()
else:
tester.to_edge_transform_and_lower()

if quant:
tester.check_not(["torch.ops.quantized_decomposed"])

(
tester.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
.check_not(["executorch_exir_dialects_edge__ops_aten_cat"])
.to_executorch()
.serialize()
.run_method_and_compare_outputs()
)

tester.to_edge_transform_and_lower()

if quant:
tester.check_not(["torch.ops.quantized_decomposed"])

(
tester.check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
.check_not(["executorch_exir_dialects_edge__ops_aten_cat"])
.to_executorch()
.serialize()
.run_method_and_compare_outputs()
)

def test_fp16_cat2(self):
"""
Using Clamp2 because fp16 add is done in fp32 ATM. Need to fix that first.
Expand Down Expand Up @@ -155,6 +161,26 @@ def test_fp32_cat_unsupported(self):
.check_count({"executorch_exir_dialects_edge__ops_aten_cat": 1})
)

def test_fp32_cat_unsupported_legacy_mode(self):
"""
XNNPACK only supports concatenating up to 4 values, so it should not delegate here.
"""
inputs = (
torch.randn(1, 2, 3),
torch.randn(3, 2, 3),
torch.randn(2, 2, 3),
torch.randn(5, 2, 3),
torch.randn(1, 2, 3),
)
(
Tester(self.Cat5(), inputs)
.export()
.check_count({"torch.ops.aten.cat": 1})
.to_edge()
.partition()
.check_count({"executorch_exir_dialects_edge__ops_aten_cat": 1})
)

class CatNegativeDim(torch.nn.Module):
def __init__(self):
super().__init__()
Expand Down