[cortex-m] initial commit #10265

Merged 4 commits on Apr 23, 2025
3 changes: 3 additions & 0 deletions backends/cortex_m/README.md
@@ -0,0 +1,3 @@
# Cortex-M Backend

WIP. This is a temporary placeholder backend for Cortex-M CPUs. It is a proof of concept, not intended for production use. Things will change without notice.
21 changes: 21 additions & 0 deletions backends/cortex_m/ops/TARGETS
@@ -0,0 +1,21 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
load("@fbcode_macros//build_defs:export_files.bzl", "export_file")
load("@fbsource//xplat/executorch/codegen:codegen.bzl", "executorch_generated_lib")

oncall("executorch")

python_library(
    name = "ops",
    srcs = [
        "operators.py",
    ],
    deps = [
        "fbcode//caffe2:torch",
    ],
)
98 changes: 98 additions & 0 deletions backends/cortex_m/ops/operators.py
@@ -0,0 +1,98 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch

# The edge dialect ops provide the implementations of the operators below.
from executorch.exir.dialects._ops import ops as exir_ops
from torch.library import impl, Library, register_fake

# New operator library with a custom namespace to allow fusion etc.
lib = Library("cortex_m", "DEF")

###
# quantize_per_tensor
###

lib.define(
    "quantize_per_tensor(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype) -> (Tensor Z)"
)

# An out variant is defined as well, since ExecuTorch kernels operate on
# out-variant ops at runtime.
lib.define(
    "quantize_per_tensor.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)"
)


@register_fake("cortex_m::quantize_per_tensor")
def quantize_per_tensor_meta(
    input: torch.Tensor,
    scale: float,
    zero_point: int,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    return torch.empty_like(input, dtype=dtype)


@impl(lib, "quantize_per_tensor", "CompositeExplicitAutograd")
def quantize_per_tensor_impl(
    input: torch.Tensor,
    scale: float,
    zero_point: int,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    """
    Identical to quantized_decomposed.quantize_per_tensor in the edge
    dialect, to which this implementation delegates.
    """
    return exir_ops.edge.quantized_decomposed.quantize_per_tensor.default(
        input, scale, zero_point, quant_min, quant_max, dtype
    )


###
# dequantize_per_tensor
###

lib.define(
    "dequantize_per_tensor(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype) -> (Tensor Z)"
)

lib.define(
    "dequantize_per_tensor.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)"
)


@register_fake("cortex_m::dequantize_per_tensor")
def dequantize_per_tensor_meta(
    input: torch.Tensor,
    scale: float,
    zero_point: int,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    return torch.empty_like(input, dtype=torch.float)


@impl(lib, "dequantize_per_tensor", "CompositeExplicitAutograd")
def dequantize_per_tensor_impl(
    input: torch.Tensor,
    scale: float,
    zero_point: int,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    """
    Identical to quantized_decomposed.dequantize_per_tensor in the edge
    dialect, to which this implementation delegates.
    """
    return exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default(
        input, scale, zero_point, quant_min, quant_max, dtype
    )
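
Once the module above has been imported, both operators are reachable under the torch.ops.cortex_m namespace. Below is a minimal round-trip sketch; the tensor values and quantization parameters are made up for illustration, and it assumes the CompositeExplicitAutograd implementations (and the edge-dialect ops they delegate to) are callable eagerly:

import torch

import executorch.backends.cortex_m.ops.operators  # noqa -- registers the cortex_m ops

x = torch.randn(4)
scale, zero_point = 0.05, 0  # hypothetical quantization parameters
qmin, qmax = -128, 127

# Quantize float32 -> int8, then dequantize back to float32.
xq = torch.ops.cortex_m.quantize_per_tensor(x, scale, zero_point, qmin, qmax, torch.int8)
xdq = torch.ops.cortex_m.dequantize_per_tensor(xq, scale, zero_point, qmin, qmax, torch.int8)

assert xq.dtype == torch.int8 and xdq.dtype == torch.float32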
21 changes: 21 additions & 0 deletions backends/cortex_m/passes/TARGETS
@@ -0,0 +1,21 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

load("@fbcode_macros//build_defs:python_library.bzl", "python_library")

oncall("executorch")

python_library(
    name = "replace_quant_nodes_pass",
    srcs = ["replace_quant_nodes_pass.py"],
    deps = [
        "//caffe2:torch",
        "//executorch/exir:lib",
        "//executorch/exir:pass_base",
        "//executorch/exir/dialects:lib",
        "//executorch/backends/cortex_m/ops:ops",
    ],
)
62 changes: 62 additions & 0 deletions backends/cortex_m/passes/replace_quant_nodes_pass.py
@@ -0,0 +1,62 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Callable, Dict, Tuple

import executorch.backends.cortex_m.ops.operators # noqa
import torch

from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.dialects.edge._ops import EdgeOpOverload
from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue


class ReplaceQuantNodesPass(ExportPass):
    """
    Replace quantize and dequantize nodes with the corresponding
    cortex_m.quantize_per_tensor and cortex_m.dequantize_per_tensor nodes.
    """

    @staticmethod
    def _is_qualified_int8_node(args) -> bool:
        return (
            args[3] >= torch.iinfo(torch.int8).min  # qmin
            and args[4] <= torch.iinfo(torch.int8).max  # qmax
            and args[5] == torch.int8  # dtype
        )

    def __init__(self):
        super().__init__()
        self.op_replacements = {
            exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: {
                "new_target": exir_ops.edge.cortex_m.quantize_per_tensor.default,
                "qualifier": self._is_qualified_int8_node,
            },
            exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: {
                "new_target": exir_ops.edge.cortex_m.dequantize_per_tensor.default,
                "qualifier": self._is_qualified_int8_node,
            },
        }

    def call_operator(
        self,
        op: Callable[..., object],
        args: Tuple[object, ...],
        kwargs: Dict[str, object],
        meta: NodeMetadata,
    ) -> ProxyValue:
        assert isinstance(
            op, EdgeOpOverload
        ), "Op must be an EdgeOpOverload. Run this pass after to_edge()."

        if op in self.op_replacements and self.op_replacements[op]["qualifier"](args):
            return super().call_operator(
                self.op_replacements[op]["new_target"],
                args,
                kwargs,
                meta,
            )
        return super().call_operator(op, args, kwargs, meta)
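
For context, here is a sketch of how this pass might slot into an export flow. The toy module and inputs are hypothetical; in a real flow the module would first be quantized (e.g. with torchao's prepare_pt2e/convert_pt2e) so that quantized_decomposed ops actually appear in the graph, and EdgeProgramManager.transform is assumed to accept a list of passes:

import torch

from executorch.backends.cortex_m.passes.replace_quant_nodes_pass import (
    ReplaceQuantNodesPass,
)
from executorch.exir import to_edge
from torch.export import export


class AddOne(torch.nn.Module):  # hypothetical toy module
    def forward(self, x):
        return x + 1.0


exported = export(AddOne(), (torch.randn(4),))
edge = to_edge(exported)

# Run the pass only after to_edge(): its assert requires EdgeOpOverload targets.
edge = edge.transform([ReplaceQuantNodesPass()])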
18 changes: 18 additions & 0 deletions backends/cortex_m/test/TARGETS
@@ -0,0 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

load("@fbcode_macros//build_defs:python_unittest.bzl", "python_unittest")

python_unittest(
    name = "test_replace_quant_nodes",
    srcs = ["test_replace_quant_nodes.py"],
    deps = [
        "//pytorch/ao:torchao",  # @manual
        "//caffe2:torch",
        "//executorch/backends/cortex_m/passes:replace_quant_nodes_pass",
        "//executorch/backends/cortex_m/ops:ops",
    ],
)
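
The test file itself is not part of this diff. A hypothetical minimal check in the same spirit might build a graph with explicit int8 quantize/dequantize calls and assert the pass rewrites them to the cortex_m targets (the quantization parameters are arbitrary):

import unittest

import torch
import torch.ao.quantization.fx._decomposed  # noqa -- registers quantized_decomposed ops

from executorch.backends.cortex_m.passes.replace_quant_nodes_pass import (
    ReplaceQuantNodesPass,
)
from executorch.exir import to_edge
from executorch.exir.dialects._ops import ops as exir_ops
from torch.export import export


class QDQModule(torch.nn.Module):  # hypothetical module with explicit Q/DQ ops
    def forward(self, x):
        q = torch.ops.quantized_decomposed.quantize_per_tensor(
            x, 0.05, 0, -128, 127, torch.int8
        )
        return torch.ops.quantized_decomposed.dequantize_per_tensor(
            q, 0.05, 0, -128, 127, torch.int8
        )


class TestReplaceQuantNodes(unittest.TestCase):
    def test_int8_qdq_is_replaced(self):
        edge = to_edge(export(QDQModule(), (torch.randn(4),)))
        edge = edge.transform([ReplaceQuantNodesPass()])
        targets = [
            n.target
            for n in edge.exported_program().graph.nodes
            if n.op == "call_function"
        ]
        self.assertIn(exir_ops.edge.cortex_m.quantize_per_tensor.default, targets)
        self.assertIn(exir_ops.edge.cortex_m.dequantize_per_tensor.default, targets)
        self.assertNotIn(
            exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, targets
        )


if __name__ == "__main__":
    unittest.main()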