Skip to content

Commit e9a0c4f

Browse files
mcremon-metafacebook-github-bot
authored andcommitted
Update OSS repo (#2033)
Summary: Update the OSS Xtensa repo with more up to date compiler and quantizer things. Introduce a test folder and a conv1d test. Reviewed By: cccclai Differential Revision: D54034581
1 parent 252508b commit e9a0c4f

16 files changed

+1196
-138
lines changed

docs/source/build-run-xtensa.md

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -68,13 +68,14 @@ examples/xtensa/
6868
├── aot
6969
├── kernels
7070
├── ops
71+
├── tests
7172
├── third-party
7273
└── utils
7374
```
7475

7576
***AoT (Ahead-of-Time) Components***:
7677

77-
The AoT folder contains all of the python scripts and functions needed to export the model to an ExecuTorch `.pte` file. In our case, [export_example.py](https://github.com/pytorch/executorch/blob/main/examples/xtensa/aot/export_example.py) defines a model and some example inputs (set to a vector of ones), and runs it through the quantizer (from [quantizer.py](https://github.com/pytorch/executorch/blob/main/examples/xtensa/aot/quantizer.py)). Then a few compiler passes, also defined in [quantizer.py](https://github.com/pytorch/executorch/blob/main/examples/xtensa/aot/quantizer.py), will replace operators with custom ones that are supported and optimized on the chip. Any operator needed to compute things should be defined in [meta_registrations.py](https://github.com/pytorch/executorch/blob/main/examples/xtensa/aot/meta_registrations.py) and have corresponding implemetations in the other folders.
78+
The AoT folder contains all of the python scripts and functions needed to export the model to an ExecuTorch `.pte` file. In our case, [export_example.py](https://github.com/pytorch/executorch/blob/main/examples/xtensa/aot/export_example.py) is an API that takes a model (nn.Module) and representative inputs and runs it through the quantizer (from [quantizer.py](https://github.com/pytorch/executorch/blob/main/examples/xtensa/aot/quantizer.py)). Then a few compiler passes, also defined in [quantizer.py](https://github.com/pytorch/executorch/blob/main/examples/xtensa/aot/quantizer.py), will replace operators with custom ones that are supported and optimized on the chip. Any operator needed to compute things should be defined in [meta_registrations.py](https://github.com/pytorch/executorch/blob/main/examples/xtensa/aot/meta_registrations.py) and have corresponding implemetations in the other folders.
7879

7980
***Operators***:
8081

@@ -99,13 +100,15 @@ python3 -m examples.portable.scripts.export --model_name="add"
99100

100101
***Quantized Linear***:
101102

102-
The second, more complex model is a quantized [linear](https://pytorch.org/docs/stable/generated/torch.nn.Linear.html) operation. The model is defined [here](https://github.com/pytorch/executorch/blob/main/examples/xtensa/aot/export_example.py#L88). Linear is the backbone of most Automatic Speech Recognition (ASR) models.
103+
The other, more complex model are custom operators, including:
104+
- a quantized [linear](https://pytorch.org/docs/stable/generated/torch.nn.Linear.html) operation. The model is defined [here](https://github.com/pytorch/executorch/blob/main/examples/xtensa/tests/quantized_linear_example.py#L28). Linear is the backbone of most Automatic Speech Recognition (ASR) models.
105+
- a quantized [conv1d](https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html) operation. The model is defined [here](https://github.com/pytorch/executorch/blob/main/examples/xtensa/tests/quantized_conv1d_example.py#L36). Convolutions are important in wake word and many denoising models.
103106

104-
The generated file is called `XtensaDemoModel.pte`.
107+
In both cases the generated file is called `XtensaDemoModel.pte`.
105108

106109
```bash
107110
cd executorch
108-
python3 -m examples.xtensa.aot.export_example
111+
python3 -m examples.xtensa.tests.quantized_<linear,conv1d>_example
109112
```
110113

111114
### Runtime
@@ -189,6 +192,6 @@ First 20 elements of output 0
189192

190193
In this tutorial, you have learned how to export a quantized operation, build the ExecuTorch runtime and run this model on the Xtensa HiFi4 DSP chip.
191194

192-
The model in this tutorial is a typical operation appearing in ASR models, and can be extended to a complete ASR model by creating the model in [export_example.py](https://github.com/pytorch/executorch/blob/main/examples/xtensa/aot/export_example.py) and adding the needed operators/kernels to [operators](https://github.com/pytorch/executorch/blob/main/examples/xtensa/ops) and [kernels](https://github.com/pytorch/executorch/blob/main/examples/xtensa/kernels).
195+
The (quantized linear) model in this tutorial is a typical operation appearing in ASR models, and can be extended to a complete ASR model by creating the model as a new test and adding the needed operators/kernels to [operators](https://github.com/pytorch/executorch/blob/main/examples/xtensa/ops) and [kernels](https://github.com/pytorch/executorch/blob/main/examples/xtensa/kernels).
193196

194197
Other models can be created following the same structure, always assuming that operators and kernels are available.

examples/xtensa/aot/compiler.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import logging
8+
from typing import Any, Callable
9+
10+
import torch
11+
12+
from executorch.exir import EdgeCompileConfig, EdgeProgramManager, to_edge
13+
14+
from torch.export import export
15+
from torch.export.exported_program import ExportedProgram
16+
17+
18+
def export_program(
19+
model: Callable,
20+
inputs: Any,
21+
pt2_quant: bool = False,
22+
) -> ExportedProgram:
23+
# we don't support training mode. Make it eval
24+
if hasattr(model, "eval"):
25+
if pt2_quant:
26+
# pyre-fixme[6]: Incompatible parameter type.
27+
torch.ao.quantization.move_exported_model_to_eval(model)
28+
else:
29+
# pyre-fixme[16]: Anonymous callable has no attribute `eval`.
30+
model.eval()
31+
32+
# if it's already an ExportedProgram, just return it
33+
if isinstance(model, ExportedProgram):
34+
return model
35+
36+
assert isinstance(model, torch.nn.Module), "model should be an nn.Module"
37+
38+
# Prevent mkldnn decompositions
39+
torch._C._set_mkldnn_enabled(False)
40+
41+
# else: capture the model and return it.
42+
return export(model, inputs)
43+
44+
45+
# Export the model and lower it it edge IR.
46+
def export_to_edge(
47+
model: Callable,
48+
inputs: Any,
49+
pt2_quant: bool = False,
50+
dump_graphs: bool = False,
51+
) -> EdgeProgramManager:
52+
# Export the model into an ExportedProgram.
53+
expo_program = export_program(model, inputs, pt2_quant)
54+
55+
if dump_graphs:
56+
logging.info(
57+
f"Exported graph:\n{expo_program.graph_module.graph.print_tabular()}"
58+
)
59+
60+
# Call to_edge to convert the graph to edge IR.
61+
edge_prog_manager = to_edge(
62+
expo_program, compile_config=EdgeCompileConfig(_check_ir_validity=False)
63+
)
64+
65+
if dump_graphs:
66+
logging.info(
67+
f"Edge graph:\n{edge_prog_manager.exported_program().graph_module.graph.print_tabular()}"
68+
)
69+
70+
return edge_prog_manager

examples/xtensa/aot/export_example.py

Lines changed: 12 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -10,47 +10,28 @@
1010

1111
from .meta_registrations import * # noqa
1212

13-
import torch
14-
from executorch.exir import EdgeCompileConfig, ExecutorchBackendConfig
13+
from executorch.exir import ExecutorchBackendConfig
1514
from torch._export import capture_pre_autograd_graph
1615
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
1716

18-
from ...portable.utils import export_to_edge, save_pte_program
17+
from ...portable.utils import save_pte_program
1918

19+
from .compiler import export_to_edge
2020
from .quantizer import (
2121
QuantFusion,
2222
ReplacePT2DequantWithXtensaDequant,
2323
ReplacePT2QuantWithXtensaQuant,
24-
XtensaQuantizer,
24+
XtensaBaseQuantizer,
2525
)
2626

2727

2828
FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
2929
logging.basicConfig(level=logging.INFO, format=FORMAT)
3030

3131

32-
if __name__ == "__main__":
33-
in_features = 32
34-
out_features = 16
35-
bias = True
36-
shape = [64, in_features]
37-
38-
class QuantizedLinear(torch.nn.Module):
39-
def __init__(self, in_features: int, out_features: int, bias: bool):
40-
super().__init__()
41-
self.output_linear = torch.nn.Linear(in_features, out_features, bias=bias)
42-
43-
def forward(self, x: torch.Tensor):
44-
output_linear_out = self.output_linear(x)
45-
return output_linear_out
46-
47-
model = QuantizedLinear(in_features, out_features, bias)
48-
model.eval()
49-
50-
example_inputs = (torch.ones(shape),)
51-
32+
def export_xtensa_model(model, example_inputs):
5233
# Quantizer
53-
quantizer = XtensaQuantizer()
34+
quantizer = XtensaBaseQuantizer()
5435

5536
# Export
5637
model_exp = capture_pre_autograd_graph(model, example_inputs)
@@ -66,26 +47,15 @@ def forward(self, x: torch.Tensor):
6647
patterns = [q.pattern for q in quantizer.quantizers]
6748
QuantFusion(patterns)(converted_model)
6849

69-
# pre-autograd export. eventually this will become torch.export
70-
converted_model_exp = capture_pre_autograd_graph(converted_model, example_inputs)
50+
# Get edge program (note: the name will change to export_to_xtensa in future PRs)
51+
edge_prog_manager = export_to_edge(converted_model, example_inputs, pt2_quant=True)
7152

72-
converted_model_exp = torch.ao.quantization.move_exported_model_to_eval(
73-
converted_model_exp
53+
# Run a couple required passes for quant/dequant ops
54+
xtensa_prog_manager = edge_prog_manager.transform(
55+
[ReplacePT2QuantWithXtensaQuant(), ReplacePT2DequantWithXtensaDequant()]
7456
)
7557

76-
exec_prog = (
77-
export_to_edge(
78-
converted_model_exp,
79-
example_inputs,
80-
EdgeCompileConfig(
81-
_check_ir_validity=False,
82-
),
83-
)
84-
.transform(
85-
[ReplacePT2QuantWithXtensaQuant(), ReplacePT2DequantWithXtensaDequant()]
86-
)
87-
.to_executorch(config=ExecutorchBackendConfig(extract_constant_segment=False))
88-
)
58+
exec_prog = xtensa_prog_manager.to_executorch(config=ExecutorchBackendConfig())
8959

9060
logging.info(f"Final exported graph:\n{exec_prog.exported_program().graph}")
9161

examples/xtensa/aot/meta_registrations.py

Lines changed: 51 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,14 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
from typing import Optional, Tuple
8+
79
import torch
810
from executorch.exir.scalar_type import ScalarType
911
from torch.library import impl, Library
1012

13+
from .utils import get_conv1d_output_size
14+
1115
lib = Library("xtensa", "DEF")
1216

1317
lib.define(
@@ -25,10 +29,17 @@
2529
)
2630

2731
lib.define(
28-
"quantized_linear_pt2(Tensor src, Tensor weight, Tensor bias, float src_scale, int src_zero_point, float weight_scale, int weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point) -> (Tensor Z)"
32+
"quantized_linear(Tensor src, Tensor weight, Tensor bias, int src_zero_point, Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset) -> (Tensor Z)"
33+
)
34+
lib.define(
35+
"quantized_linear.out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)"
36+
)
37+
38+
lib.define(
39+
"quantized_conv(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift, bool channel_last=False) -> (Tensor Z)"
2940
)
3041
lib.define(
31-
"quantized_linear_pt2.out(Tensor src, Tensor weight, Tensor bias, float src_scale, int src_zero_point, float weight_scale, int weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!)"
42+
"quantized_conv.out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!)"
3243
)
3344

3445
m = Library("xtensa", "IMPL", "Meta")
@@ -58,18 +69,17 @@ def dequantize_per_tensor_meta(
5869
return input.new_empty(input.size(), dtype=torch.float)
5970

6071

61-
@impl(m, "quantized_linear_pt2")
62-
def quantized_linear_pt2_meta(
72+
@impl(m, "quantized_linear")
73+
def quantized_linear_meta(
6374
src: torch.Tensor,
6475
weight: torch.Tensor,
6576
bias: torch.Tensor,
66-
in_scale: float,
6777
in_zero_point: int,
68-
weight_scale: float,
69-
weight_zero_point: int,
70-
out_multiplier: int,
71-
out_shift: int,
78+
weight_zero_point: torch.Tensor,
79+
out_multiplier: torch.Tensor,
80+
out_shift: torch.Tensor,
7281
out_zero_point: int,
82+
offset: Optional[torch.Tensor]
7383
):
7484
# src comes in shape [leading_dims, in_dim]
7585
# weight comes in shape [out_dim, in_dim]
@@ -79,3 +89,35 @@ def quantized_linear_pt2_meta(
7989
assert len(weight_size) == 2
8090
out_size[-1] = weight_size[0]
8191
return src.new_empty(out_size, dtype=torch.uint8)
92+
93+
94+
@impl(m, "quantized_conv")
95+
def quantized_conv_meta(
96+
input: torch.Tensor,
97+
weight: torch.Tensor,
98+
bias: torch.Tensor,
99+
stride: Tuple[int],
100+
padding: Tuple[int],
101+
dilation: Tuple[int],
102+
groups: int,
103+
in_zero_point: int,
104+
weight_zero_point: torch.Tensor,
105+
bias_scale: torch.Tensor,
106+
output_scale: float,
107+
output_zero_point: int,
108+
out_multiplier: torch.Tensor,
109+
out_shift: torch.Tensor,
110+
channel_last: bool = False,
111+
):
112+
out_channels, _in_channels, *kernel_size = weight.shape
113+
in_size = input.shape
114+
# Assert that the input tensor has at least 3 dimensions, and at most 6
115+
assert len(in_size) > 2
116+
assert len(in_size) < 6
117+
118+
# Compute the output tensor size
119+
output_size = get_conv1d_output_size(
120+
in_size, out_channels, stride[0], padding[0], dilation[0], kernel_size[0]
121+
)
122+
123+
return input.new_empty(output_size, dtype=input.dtype)

0 commit comments

Comments
 (0)