Update OSS repo #2033

Closed
wants to merge 1 commit into from
2 changes: 1 addition & 1 deletion backends/vulkan/test/op_tests/utils/codegen.py
@@ -165,7 +165,7 @@ def prepack_ref(self, ref: ValueRef) -> bool:
        else:
            return ref.supports_prepack and self.should_prepack

    def create_value_for(self, ref: ValueRefList) -> str:
    def create_value_for(self, ref: ValueRefList) -> str: # noqa: C901
        if isinstance(ref, list):
            ret_str = ""
            for r in ref:
33 changes: 25 additions & 8 deletions docs/source/build-run-xtensa.md
@@ -68,13 +68,14 @@ examples/xtensa/
├── aot
├── kernels
├── ops
├── tests
├── third-party
└── utils
```

***AoT (Ahead-of-Time) Components***:

The AoT folder contains all of the Python scripts and functions needed to export the model to an ExecuTorch `.pte` file. In our case, [export_example.py](https://github.com/pytorch/executorch/blob/main/examples/xtensa/aot/export_example.py) defines a model and some example inputs (set to a vector of ones), and runs it through the quantizer (from [quantizer.py](https://github.com/pytorch/executorch/blob/main/examples/xtensa/aot/quantizer.py)). Then a few compiler passes, also defined in [quantizer.py](https://github.com/pytorch/executorch/blob/main/examples/xtensa/aot/quantizer.py), replace operators with custom ones that are supported and optimized on the chip. Any operator required by the model should be defined in [meta_registrations.py](https://github.com/pytorch/executorch/blob/main/examples/xtensa/aot/meta_registrations.py) and have corresponding implementations in the other folders.
The AoT folder contains all of the Python scripts and functions needed to export the model to an ExecuTorch `.pte` file. In our case, [export_example.py](https://github.com/pytorch/executorch/blob/main/examples/xtensa/aot/export_example.py) is an API that takes a model (nn.Module) and representative inputs and runs it through the quantizer (from [quantizer.py](https://github.com/pytorch/executorch/blob/main/examples/xtensa/aot/quantizer.py)). Then a few compiler passes, also defined in [quantizer.py](https://github.com/pytorch/executorch/blob/main/examples/xtensa/aot/quantizer.py), replace operators with custom ones that are supported and optimized on the chip. Any operator required by the model should be defined in [meta_registrations.py](https://github.com/pytorch/executorch/blob/main/examples/xtensa/aot/meta_registrations.py) and have corresponding implementations in the other folders.
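
As a minimal, purely illustrative sketch of driving that API from Python (the module, shapes, and import path are placeholders; the call assumes `export_xtensa_model` is exposed by `export_example.py` and is run from the executorch repo root):

```python
# Illustrative sketch only: the module and shapes below are placeholders.
import torch

# Assumed import path; adjust to your checkout/installation.
from examples.xtensa.aot.export_example import export_xtensa_model


class SmallLinear(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(32, 16)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


model = SmallLinear().eval()
example_inputs = (torch.ones(64, 32),)  # representative inputs used by the quantizer
export_xtensa_model(model, example_inputs)  # runs the full flow and saves the .pte file
```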

***Operators***:

@@ -97,17 +98,31 @@ cd executorch
python3 -m examples.portable.scripts.export --model_name="add"
```

***Quantized Linear***:
***Quantized Operators***:

The second, more complex model is a quantized [linear](https://pytorch.org/docs/stable/generated/torch.nn.Linear.html) operation. The model is defined [here](https://github.com/pytorch/executorch/blob/main/examples/xtensa/aot/export_example.py#L88). Linear is the backbone of most Automatic Speech Recognition (ASR) models.
The other, more complex models exercise custom operators, including:
- a quantized [linear](https://pytorch.org/docs/stable/generated/torch.nn.Linear.html) operation. The model is defined [here](https://github.com/pytorch/executorch/blob/main/examples/xtensa/tests/quantized_linear_example.py#L28). Linear is the backbone of most Automatic Speech Recognition (ASR) models.
- a quantized [conv1d](https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html) operation. The model is defined [here](https://github.com/pytorch/executorch/blob/main/examples/xtensa/tests/quantized_conv1d_example.py#L36). Convolutions are important in wake-word and many denoising models.

The generated file is called `XtensaDemoModel.pte`.
In both cases the generated file is called `XtensaDemoModel.pte`.

```bash
cd executorch
python3 -m examples.xtensa.tests.quantized_<linear,conv1d>_example
```

***Small Model: RNNT predictor***:

The torchaudio [RNNT-emformer](https://pytorch.org/audio/stable/tutorials/online_asr_tutorial.html) model is an Automatic Speech Recognition (ASR) model composed of three submodels: an encoder, a predictor, and a joiner.
The predictor is a sequence of basic ops (embedding, ReLU, linear, layer norm) and can be exported using:

```bash
cd executorch
python3 -m examples.xtensa.aot.export_example
python3 -m examples.xtensa.tests.rnnt_predictor_quantized_example
```

The generated file is called `XtensaDemoModel.pte`.
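
For orientation, here is a rough, purely illustrative sketch of a predictor-style stack of these ops; the toy sizes and layout are placeholders and do not match the actual torchaudio predictor:

```python
# Illustrative only: toy sizes, not the real torchaudio RNNT predictor.
import torch


class ToyPredictor(torch.nn.Module):
    def __init__(self, num_symbols: int = 128, hidden: int = 64):
        super().__init__()
        self.embedding = torch.nn.Embedding(num_symbols, hidden)
        self.linear = torch.nn.Linear(hidden, hidden)
        self.relu = torch.nn.ReLU()
        self.norm = torch.nn.LayerNorm(hidden)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        x = self.embedding(tokens)     # embedding lookup
        x = self.relu(self.linear(x))  # linear + ReLU
        return self.norm(x)            # layer norm


tokens = torch.randint(0, 128, (1, 10))  # example token ids
out = ToyPredictor().eval()(tokens)
```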

### Runtime

**Building the DSP firmware image**
@@ -139,12 +154,14 @@ cmake -DBUCK2=buck2 \
-DCMAKE_TOOLCHAIN_FILE=<path_to_executorch>/examples/xtensa/xtensa.cmake \
-DCMAKE_INSTALL_PREFIX=cmake-out \
-DCMAKE_BUILD_TYPE=Debug \
-DPYTHON_EXECUTABLE=python3 \
-DEXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL=ON \
-DEXECUTORCH_BUILD_HOST_TARGETS=ON \
-DEXECUTORCH_BUILD_EXECUTOR_RUNNER=OFF \
-DEXECUTORCH_BUILD_PTHREADPOOL=OFF \
-DEXECUTORCH_BUILD_CPUINFO=OFF \
-DEXECUTORCH_BUILD_FLATC=OFF \
-DFLATC_EXECUTABLE="$(which flatc)" \
-DEXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL=ON \
-DPYTHON_EXECUTABLE=python3 \
-Bcmake-out .

cmake --build cmake-out -j8 --target install --config Debug
@@ -196,6 +213,6 @@ First 20 elements of output 0

In this tutorial, you have learned how to export a quantized operation, build the ExecuTorch runtime, and run this model on the Xtensa HiFi4 DSP chip.

The model in this tutorial is a typical operation appearing in ASR models, and can be extended to a complete ASR model by creating the model in [export_example.py](https://github.com/pytorch/executorch/blob/main/examples/xtensa/aot/export_example.py) and adding the needed operators/kernels to [operators](https://github.com/pytorch/executorch/blob/main/examples/xtensa/ops) and [kernels](https://github.com/pytorch/executorch/blob/main/examples/xtensa/kernels).
The quantized linear model in this tutorial implements a typical operation appearing in ASR models, and can be extended to a complete ASR model by creating the model as a new test and adding the needed operators/kernels to [operators](https://github.com/pytorch/executorch/blob/main/examples/xtensa/ops) and [kernels](https://github.com/pytorch/executorch/blob/main/examples/xtensa/kernels).

Other models can be created following the same structure, always assuming that operators and kernels are available.
68 changes: 68 additions & 0 deletions examples/xtensa/aot/compiler.py
@@ -0,0 +1,68 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
from typing import Any, Callable

import torch

from executorch.exir import EdgeCompileConfig, EdgeProgramManager, to_edge

from torch.export import export
from torch.export.exported_program import ExportedProgram


def export_program(
model: Callable,
inputs: Any,
pt2_quant: bool = False,
) -> ExportedProgram:
    # We don't support training mode, so switch the model to eval mode.
if hasattr(model, "eval"):
if pt2_quant:
# pyre-fixme[6]: Incompatible parameter type.
torch.ao.quantization.move_exported_model_to_eval(model)
else:
# pyre-fixme[16]: Anonymous callable has no attribute `eval`.
model.eval()

# if it's already an ExportedProgram, just return it
if isinstance(model, ExportedProgram):
return model

assert isinstance(model, torch.nn.Module), "model should be an nn.Module"

# Prevent mkldnn decompositions
torch._C._set_mkldnn_enabled(False)

# else: capture the model and return it.
return export(model, inputs)


# Export the model and lower it to edge IR.
def export_to_edge(
model: Callable,
inputs: Any,
pt2_quant: bool = False,
dump_graphs: bool = False,
) -> EdgeProgramManager:
# Export the model into an ExportedProgram.
expo_program = export_program(model, inputs, pt2_quant)

if dump_graphs:
logging.info(f"Exported graph:\n{expo_program.graph_module.graph}")

# Call to_edge to convert the graph to edge IR.
edge_prog_manager = to_edge(
expo_program, compile_config=EdgeCompileConfig(_check_ir_validity=False)
)

if dump_graphs:
logging.info(
f"Edge graph:\n{edge_prog_manager.exported_program().graph_module.graph}"
)

return edge_prog_manager
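
A minimal usage sketch for these helpers (placeholder module and inputs; the import path assumes the executorch repo root is on `PYTHONPATH`):

```python
# Illustrative only: placeholder module and inputs.
import torch

from examples.xtensa.aot.compiler import export_to_edge  # assumed import path


class TinyRelu(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(x)


# Export to an ExportedProgram, then lower it to edge IR and dump the graphs.
edge_manager = export_to_edge(TinyRelu().eval(), (torch.randn(4, 8),), dump_graphs=True)
print(edge_manager.exported_program().graph_module.graph)
```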
59 changes: 15 additions & 44 deletions examples/xtensa/aot/export_example.py
@@ -10,47 +10,27 @@

from .meta_registrations import * # noqa

import torch
from executorch.exir import EdgeCompileConfig
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

from ...portable.utils import export_to_edge, save_pte_program
from ...portable.utils import save_pte_program

from .compiler import export_to_edge
from .quantizer import (
QuantFusion,
ReplacePT2DequantWithXtensaDequant,
ReplacePT2QuantWithXtensaQuant,
XtensaQuantizer,
XtensaBaseQuantizer,
)


FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
logging.basicConfig(level=logging.INFO, format=FORMAT)


if __name__ == "__main__":
in_features = 32
out_features = 16
bias = True
shape = [64, in_features]

class QuantizedLinear(torch.nn.Module):
def __init__(self, in_features: int, out_features: int, bias: bool):
super().__init__()
self.output_linear = torch.nn.Linear(in_features, out_features, bias=bias)

def forward(self, x: torch.Tensor):
output_linear_out = self.output_linear(x)
return output_linear_out

model = QuantizedLinear(in_features, out_features, bias)
model.eval()

example_inputs = (torch.ones(shape),)

def export_xtensa_model(model, example_inputs):
# Quantizer
quantizer = XtensaQuantizer()
quantizer = XtensaBaseQuantizer()

# Export
model_exp = capture_pre_autograd_graph(model, example_inputs)
@@ -66,29 +46,20 @@ def forward(self, x: torch.Tensor):
patterns = [q.pattern for q in quantizer.quantizers]
QuantFusion(patterns)(converted_model)

# pre-autograd export. eventually this will become torch.export
converted_model_exp = capture_pre_autograd_graph(converted_model, example_inputs)
# Get edge program (note: the name will change to export_to_xtensa in future PRs)
edge_prog_manager = export_to_edge(converted_model, example_inputs, pt2_quant=True)

converted_model_exp = torch.ao.quantization.move_exported_model_to_eval(
converted_model_exp
# Run a couple required passes for quant/dequant ops
xtensa_prog_manager = edge_prog_manager.transform(
[ReplacePT2QuantWithXtensaQuant(), ReplacePT2DequantWithXtensaDequant()],
check_ir_validity=False,
)

exec_prog = (
export_to_edge(
converted_model_exp,
example_inputs,
edge_compile_config=EdgeCompileConfig(
_check_ir_validity=False,
),
)
.transform(
[ReplacePT2QuantWithXtensaQuant(), ReplacePT2DequantWithXtensaDequant()],
check_ir_validity=False,
)
.to_executorch()
)
exec_prog = xtensa_prog_manager.to_executorch()

logging.info(f"Final exported graph:\n{exec_prog.exported_program().graph}")
logging.info(
f"Final exported graph module:\n{exec_prog.exported_program().graph_module}"
)

# Save the program as XtensaDemoModel.pte
save_pte_program(exec_prog, "XtensaDemoModel")
97 changes: 88 additions & 9 deletions examples/xtensa/aot/meta_registrations.py
@@ -4,10 +4,14 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional, Tuple

import torch
from executorch.exir.scalar_type import ScalarType
from torch.library import impl, Library

from .utils import get_conv1d_output_size

lib = Library("xtensa", "DEF")

lib.define(
@@ -25,10 +29,31 @@
)

lib.define(
"quantized_linear_pt2(Tensor src, Tensor weight, Tensor bias, float src_scale, int src_zero_point, float weight_scale, int weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point) -> (Tensor Z)"
"quantized_layer_norm(Tensor X, Tensor X_scale, Tensor X_zero_point, int[] normalized_shape, Tensor weight, Tensor bias, float eps, float output_scale, int output_zero_point) -> (Tensor Y)"
)

lib.define(
"quantized_layer_norm.out(Tensor X, Tensor X_scale, Tensor X_zero_point, int[] normalized_shape, Tensor weight, Tensor bias, float eps, float output_scale, int output_zero_point, *, Tensor(a!) out) -> Tensor (a!)"
)

lib.define(
"quantized_linear(Tensor src, Tensor weight, Tensor bias, int src_zero_point, Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset) -> (Tensor Z)"
)
lib.define(
"quantized_linear.out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)"
)

lib.define("quantized_relu(Tensor X, Tensor X_zero_point) -> (Tensor Y)")

lib.define(
"quantized_relu.out(Tensor X, Tensor X_zero_point, *, Tensor(a!) out) -> Tensor (a!)"
)

lib.define(
"quantized_conv(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift, bool channel_last=False) -> (Tensor Z)"
)
lib.define(
"quantized_linear_pt2.out(Tensor src, Tensor weight, Tensor bias, float src_scale, int src_zero_point, float weight_scale, int weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!)"
"quantized_conv.out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!)"
)

m = Library("xtensa", "IMPL", "Meta")
Expand Down Expand Up @@ -58,18 +83,17 @@ def dequantize_per_tensor_meta(
return input.new_empty(input.size(), dtype=torch.float)


@impl(m, "quantized_linear_pt2")
def quantized_linear_pt2_meta(
@impl(m, "quantized_linear")
def quantized_linear_meta(
src: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor,
in_scale: float,
in_zero_point: int,
weight_scale: float,
weight_zero_point: int,
out_multiplier: int,
out_shift: int,
weight_zero_point: torch.Tensor,
out_multiplier: torch.Tensor,
out_shift: torch.Tensor,
out_zero_point: int,
offset: Optional[torch.Tensor],
):
# src comes in shape [leading_dims, in_dim]
# weight comes in shape [out_dim, in_dim]
@@ -79,3 +103,58 @@ def quantized_linear_pt2_meta(
assert len(weight_size) == 2
out_size[-1] = weight_size[0]
return src.new_empty(out_size, dtype=torch.uint8)


@impl(m, "quantized_conv")
def quantized_conv_meta(
input: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor,
stride: Tuple[int],
padding: Tuple[int],
dilation: Tuple[int],
groups: int,
in_zero_point: int,
weight_zero_point: torch.Tensor,
bias_scale: torch.Tensor,
output_scale: float,
output_zero_point: int,
out_multiplier: torch.Tensor,
out_shift: torch.Tensor,
channel_last: bool = False,
):
out_channels, _in_channels, *kernel_size = weight.shape
in_size = input.shape
    # Assert that the input tensor has at least 3 and fewer than 6 dimensions
assert len(in_size) > 2
assert len(in_size) < 6

# Compute the output tensor size
output_size = get_conv1d_output_size(
in_size, out_channels, stride[0], padding[0], dilation[0], kernel_size[0]
)

return input.new_empty(output_size, dtype=input.dtype)


@impl(m, "quantized_layer_norm")
def quantized_layer_norm_meta(
input: torch.Tensor,
X_scale: torch.Tensor,
X_zero_point: torch.Tensor,
    normalized_shape: Tuple[int],
weight: torch.Tensor,
bias: torch.Tensor,
eps: float,
output_scale: float,
output_zero_point: int,
):
return input.new_empty(input.size(), dtype=torch.uint8)


@impl(m, "quantized_relu")
def quantized_relu_meta(
X: torch.Tensor,
X_zero_point: torch.Tensor,
):
return X.new_empty(X.size(), dtype=torch.uint8)
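
The `get_conv1d_output_size` helper used by `quantized_conv_meta` lives in `examples/xtensa/aot/utils.py`, which is not shown in this diff. A rough sketch of what such a helper could look like, assuming a 3-D `(N, C, L)` input and the standard conv1d output-length formula:

```python
# Sketch only: assumed (N, C, L) shape convention; the real helper may differ.
import torch


def get_conv1d_output_size(
    in_size: torch.Size,
    out_channels: int,
    stride: int,
    padding: int,
    dilation: int,
    kernel_size: int,
) -> torch.Size:
    N, _, L_in = in_size  # batch, input channels, input length
    # L_out = floor((L_in + 2*padding - dilation*(kernel_size - 1) - 1) / stride) + 1
    L_out = (L_in + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1
    return torch.Size((N, out_channels, L_out))
```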