Skip to content

Add sub tests #94

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backends/xnnpack/test/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ python_unittest(
name = "test_xnnpack_ops",
srcs = [
"ops/add.py",
"ops/sub.py",
],
deps = [
"//caffe2:torch",
Expand Down
92 changes: 92 additions & 0 deletions backends/xnnpack/test/ops/sub.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
XnnpackQuantizedPartitioner2,
)
from executorch.backends.xnnpack.test.tester import Partition, Tester


class TestXNNPACKSub(unittest.TestCase):
class SubModule(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x, y):
z = x - y
z = z - y
z = x - z
w = z - z
z = z - w
return z

def test_sub(self):
sub_module = self.SubModule()
model_inputs = (torch.randn(2, 3), torch.randn(2, 3))

(
Tester(sub_module, model_inputs)
.export()
.check_count({"torch.ops.aten.sub.Tensor": 5})
.to_edge()
.check_count({"executorch_exir_dialects_edge__ops_aten_sub_Tensor": 5})
.partition()
.check_count({"torch.ops.executorch_call_delegate": 1})
.check_not(["executorch_exir_dialects_edge__ops_aten_sub_Tensor"])
.to_executorch()
.serialize()
.run_method()
.compare_outputs()
)

def test_sub_quantized(self):
sub_module = self.SubModule()
model_inputs = (torch.randn(2, 3), torch.randn(2, 3))

(
Tester(sub_module, model_inputs)
.quantize()
.check(["torch.ops.quantized_decomposed"])
.export()
.check_count({"torch.ops.aten.sub.Tensor": 5})
.to_edge()
.check_count({"executorch_exir_dialects_edge__ops_aten_sub_Tensor": 5})
.partition(Partition(partitioner=XnnpackQuantizedPartitioner2))
.check_count({"torch.ops.executorch_call_delegate": 1})
.check_not(["executorch_exir_dialects_edge__ops_aten_sub_Tensor"])
.check_not(["torch.ops.quantized_decomposed"])
.to_executorch()
.serialize()
.run_method()
.compare_outputs()
)

# Skipping since annotate patterns for sub are missing
@unittest.expectedFailure
def test_sub_quantized_pt2e(self):
sub_module = self.SubModule()
model_inputs = (torch.randn(2, 3), torch.randn(2, 3))

(
Tester(sub_module, model_inputs)
.export()
.check_count({"torch.ops.aten.sub.Tensor": 5})
.quantize2()
.check(["torch.ops.quantized_decomposed"])
.to_edge()
.check_count({"executorch_exir_dialects_edge__ops_aten_sub_Tensor": 5})
.partition(Partition(partitioner=XnnpackQuantizedPartitioner2))
.check_count({"torch.ops.executorch_call_delegate": 1})
.check_not(["executorch_exir_dialects_edge__ops_aten_sub_Tensor"])
.check_not(["torch.ops.quantized_decomposed"])
.to_executorch()
.serialize()
.run_method()
.compare_outputs()
)
18 changes: 0 additions & 18 deletions backends/xnnpack/test/test_xnnpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -623,24 +623,6 @@ def forward(self, x, y):

self.lower_and_test_with_partitioner(mul_module, model_inputs)

def test_xnnpack_backend_sub(self):
class Sub(torch.nn.Module):
def __init__(self):
super().__init__()
self.sub = torch.sub

def forward(self, x, y):
return self.sub(x, y)

module = Sub()
M = torch.randn(2, 3)
N = torch.randn(2, 3)
model_inputs = (
M,
N,
)
self.lower_and_test_with_partitioner(module, model_inputs)

def test_xnnpack_backend_floor(self):
model_inputs = (torch.randn(1, 3, 3),)
self.lower_and_test_with_partitioner(torch.floor, model_inputs)
Expand Down