Commit da1f29b

kirklandsign authored and facebook-github-bot committed

example with mv2 (#64)

Summary:
Pull Request resolved: #64

Adds an export example for XNNPACK-delegated models, and adds an executor runner example to run them.

Reviewed By: guangy10

Differential Revision: D48371417

fbshipit-source-id: 836e49c020aec880799fdd635b6c71f6145a0536
1 parent 236675d commit da1f29b

File tree: 11 files changed, +234 -34 lines

.ci/scripts/gather_test_models.py

Lines changed: 5 additions & 3 deletions

@@ -9,8 +9,8 @@
 import os
 from typing import Any

-from examples.models import MODEL_NAME_TO_MODEL
-from examples.quantization.example import QUANT_MODEL_NAME_TO_MODEL
+from examples.models import MODEL_NAME_TO_MODEL, MODEL_NAME_TO_OPTIONS
+from executorch.examples.models.models import MODEL_NAME_TO_OPTIONS

 BUILD_TOOLS = [
     "buck2",
@@ -39,7 +39,9 @@ def export_models_for_ci() -> None:
     # https://docs.github.com/en/actions/using-jobs/using-a-matrix-for-your-jobs
     models = {"include": []}
     for name in MODEL_NAME_TO_MODEL.keys():
-        quantization = name in QUANT_MODEL_NAME_TO_MODEL
+        quantization = (
+            name in MODEL_NAME_TO_OPTIONS and MODEL_NAME_TO_OPTIONS[name].quantization
+        )
         for build_tool in BUILD_TOOLS:
             models["include"].append(
                 {"build-tool": build_tool, "model": name, "quantization": quantization}

examples/README.md

Lines changed: 4 additions & 0 deletions

@@ -7,6 +7,7 @@ It also includes a list of modules, from a simple `Add` to a full model like `Mo
 ## Directory structure
 ```bash
 examples
+|── backend          # Contains examples for exporting delegate models and running them using custom executor runners
 ├── custom_ops       # Contains examples to register custom operators into PyTorch as well as register its kernels into Executorch runtime
 ├── executor_runner  # This is an example C++ wrapper around the ET runtime
 ├── export           # Python helper scripts to illustrate export workflow
@@ -71,6 +72,9 @@ you can also find the valid quantized example models by running:
 buck2 run executorch/examples/quantization:example -- --help
 ```

+## XNNPACK Backend
+Please see [Backend README](backend/README) for the XNNPACK quantization, export, and run workflow.
+
 ## Dependencies

 Various models listed in this directory have dependencies on some other packages, e.g. torchvision, torchaudio.

examples/backend/README

Lines changed: 22 additions & 0 deletions

@@ -0,0 +1,22 @@
+This README gives some examples of backend-specific model workflows.
+
+# XNNPACK Backend
+
+[XNNPACK](https://github.com/google/XNNPACK) is a library of optimized neural network inference operators for ARM and x86 platforms. Our delegate
+lowers models to run using these highly optimized CPU operators. You can try out lowering and running some example
+models using the following command:
+
+```
+python3 -m examples.backend.xnnpack_examples --model_name="mv2" --delegate
+# For quantized model
+python3 -m examples.backend.xnnpack_examples --model_name="mv2" --quantize --delegate
+```
+
+This will produce an xnnpack_mv2.pte model that can be run using XNNPACK's operators. It will also print out
+the lowered graph, showing which parts of the model have been lowered to XNNPACK via executorch_call_delegate.
+
+You can then run the model with:
+
+```
+buck2 run examples/backend:xnn_executor_runner --model_name="mv2"
+```
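For readers who want to see what the export step does, the following is a rough sketch of the flow that examples/backend/xnnpack_examples.py (added later in this diff) implements for a floating-point model. Treat it as a sketch: the CaptureConfig flags, partitioner class, and output file name are tied to this snapshot of the tree and may change.

```python
import executorch.exir as exir
from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
    XnnpackFloatingPointPartitioner,
)
from executorch.exir.backend.backend_api import to_backend
from executorch.examples.models.models import MODEL_NAME_TO_MODEL

# Build the example model and trace it to the Edge dialect.
model, example_inputs = MODEL_NAME_TO_MODEL["mv2"]()
edge = exir.capture(
    model.eval(), example_inputs, exir.CaptureConfig(enable_aot=True, _unlift=True)
).to_edge(exir.EdgeCompileConfig(_check_ir_validity=False))

# Hand the supported subgraphs over to the XNNPACK delegate.
edge.exported_program = to_backend(edge.exported_program, XnnpackFloatingPointPartitioner)

# Serialize the lowered program to a .pte flatbuffer (file name is illustrative).
with open("xnnpack_mv2.pte", "wb") as f:
    f.write(edge.to_executorch().buffer)
```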

examples/backend/TARGETS

Lines changed: 19 additions & 0 deletions

@@ -0,0 +1,19 @@
+load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
+load(":targets.bzl", "define_common_targets")
+
+oncall("executorch")
+
+define_common_targets()
+
+runtime.python_binary(
+    name = "xnnpack_examples",
+    main_src = "xnnpack_examples.py",
+    deps = [
+        "//caffe2:torch",
+        "//executorch/backends/xnnpack:xnnpack_preprocess",
+        "//executorch/backends/xnnpack/partition:xnnpack_partitioner",
+        "//executorch/examples/models:models",
+        "//executorch/examples/quantization:quant_utils",
+        "//executorch/exir/backend:backend_api",
+    ],
+)

examples/backend/targets.bzl

Lines changed: 21 additions & 0 deletions

@@ -0,0 +1,21 @@
+load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "get_oss_build_kwargs", "runtime")
+
+def define_common_targets():
+    """Defines targets that should be shared between fbcode and xplat.
+
+    The directory containing this targets.bzl file should also contain both
+    TARGETS and BUCK files that call this function.
+    """
+
+    # executor runner for XNNPACK Backend and portable kernels.
+    runtime.cxx_binary(
+        name = "xnn_executor_runner",
+        srcs = [],
+        deps = [
+            "//executorch/examples/executor_runner:executor_runner_lib",
+            "//executorch/backends/xnnpack:xnnpack_backend",
+            "//executorch/kernels/portable:generated_lib_all_ops",
+        ],
+        define_static_target = True,
+        **get_oss_build_kwargs()
+    )

examples/backend/xnnpack_examples.py

Lines changed: 92 additions & 0 deletions

@@ -0,0 +1,92 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Example script for exporting simple models to flatbuffer
+
+import argparse
+import logging
+
+import executorch.exir as exir
+from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
+    XnnpackFloatingPointPartitioner,
+    XnnpackQuantizedPartitioner2,
+)
+from executorch.exir.backend.backend_api import to_backend
+
+from ..models import MODEL_NAME_TO_MODEL, MODEL_NAME_TO_OPTIONS
+from ..quantization.utils import quantize
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "-m",
+        "--model_name",
+        required=True,
+        help=f"Provide model name. Valid ones: {list(MODEL_NAME_TO_OPTIONS.keys())}",
+    )
+    parser.add_argument(
+        "-q",
+        "--quantize",
+        action="store_true",
+        required=False,
+        default=False,
+        help="Flag for producing quantized or floating-point model",
+    )
+    parser.add_argument(
+        "-d",
+        "--delegate",
+        action="store_true",
+        required=False,
+        default=True,
+        help="Flag for producing XNNPACK delegated model",
+    )
+
+    args = parser.parse_args()
+
+    if not args.delegate:
+        raise NotImplementedError(
+            "T161880157: Quantization-only without delegation is not supported yet"
+        )
+
+    if args.model_name not in MODEL_NAME_TO_OPTIONS:
+        raise RuntimeError(
+            f"Model {args.model_name} is not a valid name. or not quantizable right now, "
+            "please contact executorch team if you want to learn why or how to support "
+            "quantization for the requested model"
+            f"Available models are {list(MODEL_NAME_TO_OPTIONS.keys())}."
+        )
+
+    model, example_inputs = MODEL_NAME_TO_MODEL[args.model_name]()
+    model = model.eval()
+
+    partitioner = XnnpackFloatingPointPartitioner
+    if args.quantize:
+        logging.info("Quantizing Model...")
+        model = quantize(model, example_inputs)
+        # TODO(T161849167): Partitioner will eventually be a single partitioner for both fp32 and quantized models
+        partitioner = XnnpackQuantizedPartitioner2
+
+    # TODO(T161852812): use export.utils.export_to_edge Delegate implementation is currently on an unlifted graph.
+    # It will eventually be changed to a lifted graph, in which _unlift=False,
+    edge = exir.capture(
+        model, example_inputs, exir.CaptureConfig(enable_aot=True, _unlift=True)
+    ).to_edge(exir.EdgeCompileConfig(_check_ir_validity=False))
+    logging.info(f"Exported graph:\n{edge.exported_program.graph}")
+
+    edge.exported_program = to_backend(edge.exported_program, partitioner)
+    logging.info(f"Lowered graph:\n{edge.exported_program.graph}")
+
+    exec_prog = edge.to_executorch()
+    buffer = exec_prog.buffer
+    quant_tag = "_quantize" if args.quantize else ""
+    filename = f"{args.model_name}_xnnpack_{quant_tag}.pte"
+    logging.info(f"Saving exported program to {filename}.")
+    with open(filename, "wb") as f:
+        f.write(buffer)

examples/models/__init__.py

Lines changed: 2 additions & 4 deletions

@@ -4,8 +4,6 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-from .models import MODEL_NAME_TO_MODEL
+from .models import MODEL_NAME_TO_MODEL, MODEL_NAME_TO_OPTIONS

-__all__ = [
-    MODEL_NAME_TO_MODEL,
-]
+__all__ = [MODEL_NAME_TO_MODEL, MODEL_NAME_TO_OPTIONS]

examples/models/models.py

Lines changed: 16 additions & 0 deletions

@@ -7,6 +7,8 @@
 # @file models.py
 # Simple models for demonstration purposes.

+from dataclasses import dataclass
+
 from typing import Any, Tuple

 import torch
@@ -140,3 +142,17 @@ def gen_resnet50_model_and_inputs() -> Tuple[torch.nn.Module, Any]:
     "resnet18": gen_resnet18_model_and_inputs,
     "resnet50": gen_resnet50_model_and_inputs,
 }
+
+
+@dataclass
+class OptimizationOptions(object):
+    quantization: bool
+    xnnpack_delegation: bool
+
+
+MODEL_NAME_TO_OPTIONS = {
+    "linear": OptimizationOptions(True, True),
+    "add": OptimizationOptions(True, True),
+    "add_mul": OptimizationOptions(True, True),
+    "mv2": OptimizationOptions(True, True),
+}
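The new mapping is meant to be queried by callers such as the CI script and the backend example above. A minimal sketch of that lookup (the model name is chosen for illustration):

```python
from examples.models import MODEL_NAME_TO_OPTIONS

name = "mv2"
# A model is treated as quantizable / delegatable only if it has an entry and the flag is set.
can_quantize = name in MODEL_NAME_TO_OPTIONS and MODEL_NAME_TO_OPTIONS[name].quantization
can_delegate = name in MODEL_NAME_TO_OPTIONS and MODEL_NAME_TO_OPTIONS[name].xnnpack_delegation
```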

examples/quantization/TARGETS

Lines changed: 14 additions & 0 deletions

@@ -5,8 +5,22 @@ runtime.python_binary(
     main_src = "example.py",
     preload_deps = ["//executorch/kernels/quantized:aot_lib"],
     deps = [
+        ":quant_utils",
         "//caffe2:torch",
         "//executorch/examples/export:export_example",
         "//executorch/examples/models:models",
     ],
 )
+
+runtime.python_library(
+    name = "quant_utils",
+    srcs = [
+        "utils.py",
+    ],
+    visibility = [
+        "//executorch/examples/...",
+    ],
+    deps = [
+        "//caffe2:torch",
+    ],
+)

examples/quantization/example.py

Lines changed: 7 additions & 27 deletions

@@ -27,30 +27,9 @@

 from ..export.export_example import export_to_pte

-from ..models import MODEL_NAME_TO_MODEL
+from ..models import MODEL_NAME_TO_MODEL, MODEL_NAME_TO_OPTIONS

-# Note: for mv3, the mul op is not supported in XNNPACKQuantizer, that could be supported soon
-QUANT_MODEL_NAME_TO_MODEL = {
-    name: MODEL_NAME_TO_MODEL[name] for name in ["linear", "add", "add_mul", "mv2"]
-}
-
-
-def quantize(model_name, model, example_inputs):
-    """This is the official recommended flow for quantization in pytorch 2.0 export"""
-    m = model.eval()
-    m = export.capture_pre_autograd_graph(m, copy.deepcopy(example_inputs))
-    print("original model:", m)
-    quantizer = XNNPACKQuantizer()
-    # if we set is_per_channel to True, we also need to add out_variant of quantize_per_channel/dequantize_per_channel
-    operator_config = get_symmetric_quantization_config(is_per_channel=False)
-    quantizer.set_global(operator_config)
-    m = prepare_pt2e(m, quantizer)
-    # calibration
-    m(*example_inputs)
-    m = convert_pt2e(m)
-    print("quantized model:", m)
-    # make sure we can export to flat buffer
-    export_to_pte(model_name, m, copy.deepcopy(example_inputs))
+from .utils import quantize


 def verify_xnnpack_quantizer_matching_fx_quant_model(model_name, model, example_inputs):
@@ -102,7 +81,7 @@ def verify_xnnpack_quantizer_matching_fx_quant_model(model_name, model, example_
         "-m",
         "--model_name",
         required=True,
-        help=f"Provide model name. Valid ones: {list(QUANT_MODEL_NAME_TO_MODEL.keys())}",
+        help=f"Provide model name. Valid ones: {list(MODEL_NAME_TO_OPTIONS.keys())}",
     )
     parser.add_argument(
         "-ve",
@@ -122,12 +101,12 @@ def verify_xnnpack_quantizer_matching_fx_quant_model(model_name, model, example_
     args = parser.parse_args()
     if args.so_library:
         torch.ops.load_library(args.so_library)
-    if not args.verify and args.model_name not in QUANT_MODEL_NAME_TO_MODEL:
+    if not args.verify and args.model_name not in MODEL_NAME_TO_OPTIONS:
         raise RuntimeError(
             f"Model {args.model_name} is not a valid name. or not quantizable right now, "
            "please contact executorch team if you want to learn why or how to support "
            "quantization for the requested model"
-            f"Available models are {list(QUANT_MODEL_NAME_TO_MODEL.keys())}."
+            f"Available models are {list(MODEL_NAME_TO_OPTIONS.keys())}."
         )

     model, example_inputs = MODEL_NAME_TO_MODEL[args.model_name]()
@@ -137,5 +116,6 @@ def verify_xnnpack_quantizer_matching_fx_quant_model(model_name, model, example_
         args.model_name, model, example_inputs
     )

-    quantize(args.model_name, model, example_inputs)
+    quantized_model = quantize(model, example_inputs)
+    export_to_pte(args.model_name, quantized_model, copy.deepcopy(example_inputs))
     print("finished")

examples/quantization/utils.py

Lines changed: 32 additions & 0 deletions

@@ -0,0 +1,32 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import copy
+
+import torch._export as export
+from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
+from torch.ao.quantization.quantizer.xnnpack_quantizer import (
+    get_symmetric_quantization_config,
+    XNNPACKQuantizer,
+)
+
+
+def quantize(model, example_inputs):
+    """This is the official recommended flow for quantization in pytorch 2.0 export"""
+    m = model.eval()
+    m = export.capture_pre_autograd_graph(m, copy.deepcopy(example_inputs))
+    print("original model:", m)
+    quantizer = XNNPACKQuantizer()
+    # if we set is_per_channel to True, we also need to add out_variant of quantize_per_channel/dequantize_per_channel
+    operator_config = get_symmetric_quantization_config(is_per_channel=False)
+    quantizer.set_global(operator_config)
+    m = prepare_pt2e(m, quantizer)
+    # calibration
+    m(*example_inputs)
+    m = convert_pt2e(m)
+    print("quantized model:", m)
+    # make sure we can export to flat buffer
+    return m
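With quantize() factored out into this shared utility, callers pair it with the existing export helper, as example.py now does. A minimal usage sketch follows; the absolute import paths are assumptions, since the examples themselves use relative imports:

```python
import copy

# Assumed absolute import paths; inside the repo the examples use relative imports.
from executorch.examples.models.models import MODEL_NAME_TO_MODEL
from executorch.examples.quantization.utils import quantize
from executorch.examples.export.export_example import export_to_pte

model, example_inputs = MODEL_NAME_TO_MODEL["mv2"]()
quantized_model = quantize(model, example_inputs)  # PT2E prepare -> calibrate -> convert
export_to_pte("mv2", quantized_model, copy.deepcopy(example_inputs))  # writes a .pte file
```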
