Skip to content

Test #578

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 6 commits into from
Closed

Test #578

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .ci/docker/ci_commit_pins/nightly.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
dev20231002
dev20231005
6 changes: 4 additions & 2 deletions backends/xnnpack/operators/node_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,7 @@ def define_nodes_tensor_inputs_outputs(
inp,
xnn_graph,
vals_to_ids,
quant_params=QuantParams.from_inputs(inp),
quant_params=QuantParams.from_inputs(inp, self._exported_program),
convert_to_nhwc=convert_to_nhwc,
)
else:
Expand All @@ -434,7 +434,9 @@ def define_nodes_tensor_inputs_outputs(
)
# Define Input Node
input_node = get_input_node(node, input_type_map.node_input)
input_quant_params = QuantParams.from_inputs(input_node)
input_quant_params = QuantParams.from_inputs(
input_node, self._exported_program
)
self.define_tensor(
input_node,
xnn_graph,
Expand Down
4 changes: 2 additions & 2 deletions backends/xnnpack/operators/op_add.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def define_node(
input1,
xnn_graph,
vals_to_ids,
quant_params=QuantParams.from_inputs(input1),
quant_params=QuantParams.from_inputs(input1, self._exported_program),
)
input1_id = vals_to_ids[input1]

Expand All @@ -53,7 +53,7 @@ def define_node(
input2,
xnn_graph,
vals_to_ids,
quant_params=QuantParams.from_inputs(input2),
quant_params=QuantParams.from_inputs(input2, self._exported_program),
)
input2_id = vals_to_ids[input2]

Expand Down
4 changes: 3 additions & 1 deletion backends/xnnpack/operators/op_cat.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,9 @@ def define_node(
tensor_input,
xnn_graph,
vals_to_ids,
quant_params=QuantParams.from_inputs(tensor_input),
quant_params=QuantParams.from_inputs(
tensor_input, self._exported_program
),
)

self.define_tensor(
Expand Down
2 changes: 1 addition & 1 deletion backends/xnnpack/operators/op_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def define_node(
kwargs = {}
# input
input_node = get_input_node(node, 0)
input_quant_params = QuantParams.from_inputs(input_node)
input_quant_params = QuantParams.from_inputs(input_node, self._exported_program)
self.define_tensor(
input_node,
xnn_graph,
Expand Down
4 changes: 2 additions & 2 deletions backends/xnnpack/operators/op_multiply.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def define_node(
input1,
xnn_graph,
vals_to_ids,
quant_params=QuantParams.from_inputs(input1),
quant_params=QuantParams.from_inputs(input1, self._exported_program),
)
input1_id = vals_to_ids[input1]

Expand All @@ -53,7 +53,7 @@ def define_node(
input2,
xnn_graph,
vals_to_ids,
quant_params=QuantParams.from_inputs(input2),
quant_params=QuantParams.from_inputs(input2, self._exported_program),
)
input2_id = vals_to_ids[input2]

Expand Down
4 changes: 2 additions & 2 deletions backends/xnnpack/operators/op_sub.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def define_node(
input1,
xnn_graph,
vals_to_ids,
quant_params=QuantParams.from_inputs(input1),
quant_params=QuantParams.from_inputs(input1, self._exported_program),
)
input1_id = vals_to_ids[input1]

Expand All @@ -53,7 +53,7 @@ def define_node(
input2,
xnn_graph,
vals_to_ids,
quant_params=QuantParams.from_inputs(input2),
quant_params=QuantParams.from_inputs(input2, self._exported_program),
)
input2_id = vals_to_ids[input2]

Expand Down
2 changes: 1 addition & 1 deletion backends/xnnpack/operators/op_to_copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def define_node(
)

input_node = get_input_node(node, 0)
input_quant_params = QuantParams.from_inputs(input_node)
input_quant_params = QuantParams.from_inputs(input_node, self._exported_program)
output_quant_params = QuantParams.from_outputs(node)

permute_order = PERM_NCHW_TO_NHWC if to_channels_last else PERM_NHWC_TO_NCHW
Expand Down
12 changes: 10 additions & 2 deletions backends/xnnpack/operators/quant_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@
import torch
from executorch.backends.xnnpack.passes.tag_implicit_q_dq_pass import TagImplicitQDqPass
from executorch.backends.xnnpack.utils.quant_utils import is_dequant, is_quant
from executorch.backends.xnnpack.utils.utils import check_or_raise
from executorch.backends.xnnpack.utils.utils import check_or_raise, is_param_node
from executorch.exir.dialects._ops import ops as exir_ops
from torch.export import ExportedProgram


class QuantParams:
Expand Down Expand Up @@ -178,11 +179,18 @@ def from_weights(cls, tensor_node: torch.fx.Node) -> Optional[QuantParams]:
return cls.from_q_dq_node(q)

@classmethod
def from_inputs(cls, tensor_node: torch.fx.Node) -> Optional[QuantParams]:
def from_inputs(
cls, tensor_node: torch.fx.Node, ep: ExportedProgram
) -> Optional[QuantParams]:
# tensor_node is quantized if it is produced by a dequant node
if is_dequant(tensor_node) and TagImplicitQDqPass.is_tagged_as_implicit_q_dq(
tensor_node
):
dq_input = cast(torch.fx.Node, tensor_node.args[0])
if is_quant(dq_input):
q_input = cast(torch.fx.Node, dq_input.args[0])
if is_param_node(ep, q_input):
return cls.from_q_dq_node(dq_input)
return cls.from_q_dq_node(tensor_node)

return None
Expand Down
2 changes: 2 additions & 0 deletions backends/xnnpack/test/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,8 @@ python_unittest(
]),
tags = ["long_running"],
deps = [
"fbsource//third-party/pypi/timm:timm",
"fbsource//third-party/pypi/torchsr:torchsr", # @manual
"//caffe2:torch",
"//executorch/backends/xnnpack/test/tester:tester",
"//pytorch/vision:torchvision",
Expand Down
42 changes: 42 additions & 0 deletions backends/xnnpack/test/models/edsr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch

from executorch.backends.xnnpack.test.tester import Tester
from torchsr.models import edsr_r16f64


class TestEDSR(unittest.TestCase):
edsr = edsr_r16f64(2, False).eval() # noqa
model_inputs = (torch.ones(1, 3, 224, 224),)

def test_fp32_edsr(self):
(
Tester(self.edsr, self.model_inputs)
.export()
.to_edge()
.partition()
.to_executorch()
.serialize()
.run_method()
.compare_outputs()
)

def test_qs8_edsr(self):
(
Tester(self.edsr, self.model_inputs)
.quantize()
.export()
.to_edge()
.partition()
.to_executorch()
.serialize()
.run_method()
.compare_outputs()
)
67 changes: 67 additions & 0 deletions backends/xnnpack/test/models/inception_v3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
import torchvision.models as models
from executorch.backends.xnnpack.test.tester import Tester


class TestInceptionV3(unittest.TestCase):
# pyre-ignore
ic3 = models.inception_v3(weights="IMAGENET1K_V1").eval() # noqa
model_inputs = (torch.ones(1, 3, 224, 224),)

all_operators = {
"executorch_exir_dialects_edge__ops_aten_addmm_default",
"executorch_exir_dialects_edge__ops_aten_add_Tensor",
"executorch_exir_dialects_edge__ops_aten_cat_default",
"executorch_exir_dialects_edge__ops_aten_convolution_default",
"executorch_exir_dialects_edge__ops_aten_max_pool2d_with_indices_default",
# "executorch.exir.dialects.edge._ops.aten.avg_pool2d.default", Currently do not have avg_pool2d partitioned
"executorch_exir_dialects_edge__ops_aten_mean_dim",
"executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default",
"executorch_exir_dialects_edge__ops_aten_permute_copy_default",
"executorch_exir_dialects_edge__ops_aten_relu_default",
}

def test_fp32_ic3(self):

(
Tester(self.ic3, self.model_inputs)
.export()
.to_edge()
.check(list(self.all_operators))
.partition()
.check(["torch.ops.executorch_call_delegate"])
.check_not(list(self.all_operators))
.to_executorch()
.serialize()
.run_method()
.compare_outputs()
)

def test_qs8_ic3(self):
# Quantization fuses away batchnorm, so it is no longer in the graph
ops_after_quantization = self.all_operators - {
"executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default",
}

(
Tester(self.ic3, self.model_inputs)
.quantize()
.export()
.to_edge()
.check(list(ops_after_quantization))
.partition()
.check(["torch.ops.executorch_call_delegate"])
.check_not(list(ops_after_quantization))
.to_executorch()
.serialize()
.run_method()
.compare_outputs()
)
65 changes: 65 additions & 0 deletions backends/xnnpack/test/models/inception_v4.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from executorch.backends.xnnpack.test.tester import Tester
from timm.models import inception_v4


class TestInceptionV4(unittest.TestCase):
ic4 = inception_v4(pretrained=False).eval()
model_inputs = (torch.ones(3, 299, 299).unsqueeze(0),)

all_operators = {
"executorch_exir_dialects_edge__ops_aten_addmm_default",
# "executorch.exir.dialects.edge._ops.aten.avg_pool2d.default", Currently do not have avg_pool2d partitioned
"executorch_exir_dialects_edge__ops_aten_cat_default",
"executorch_exir_dialects_edge__ops_aten_convolution_default",
"executorch_exir_dialects_edge__ops_aten_max_pool2d_with_indices_default",
"executorch_exir_dialects_edge__ops_aten_mean_dim",
"executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default",
"executorch_exir_dialects_edge__ops_aten_permute_copy_default",
"executorch_exir_dialects_edge__ops_aten_relu_default",
}

def test_fp32_ic4(self):

(
Tester(self.ic4, self.model_inputs)
.export()
.to_edge()
.check(list(self.all_operators))
.partition()
.check(["torch.ops.executorch_call_delegate"])
.check_not(list(self.all_operators))
.to_executorch()
.serialize()
.run_method()
.compare_outputs()
)

def test_qs8_ic4(self):
# Quantization fuses away batchnorm, so it is no longer in the graph
ops_after_quantization = self.all_operators - {
"executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default",
}

(
Tester(self.ic4, self.model_inputs)
.quantize()
.export()
.to_edge()
.check(list(ops_after_quantization))
.partition()
.check(["torch.ops.executorch_call_delegate"])
.check_not(list(ops_after_quantization))
.to_executorch()
.serialize()
.run_method()
.compare_outputs()
)
6 changes: 5 additions & 1 deletion backends/xnnpack/test/tester/tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,10 @@
get_symmetric_quantization_config,
XNNPACKQuantizer,
)
from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import QuantizationConfig
from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import (
convert_scalars_to_attrs,
QuantizationConfig,
)
from torch.testing import FileCheck
from torch.utils._pytree import tree_flatten

Expand Down Expand Up @@ -105,6 +108,7 @@ def run(
self, artifact: torch.nn.Module, inputs: Optional[Tuple[torch.Tensor]]
) -> None:
captured_graph = export.capture_pre_autograd_graph(artifact, inputs)
captured_graph = convert_scalars_to_attrs(captured_graph)
prepared = prepare_pt2e(captured_graph, self.quantizer)
converted = convert_pt2e(prepared)
self.converted_graph = converted
Expand Down
2 changes: 1 addition & 1 deletion install_requirements.sh
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ pip install .
# models in executorch/examples/models.
# The version in this file will be the correct version for the
# corresponsing version of the repo.
NIGHTLY_VERSION=dev20231002
NIGHTLY_VERSION=dev20231005

TORCH_VERSION=2.2.0.${NIGHTLY_VERSION}
pip install --force-reinstall --pre torch=="${TORCH_VERSION}" -i https://download.pytorch.org/whl/nightly/cpu
Expand Down
2 changes: 1 addition & 1 deletion third-party/pytorch
Submodule pytorch updated 308 files