
Commit 5f39bfe

mcremon-meta authored and facebook-github-bot committed
Update OSS repo (#2033)
Summary: Update the OSS Xtensa repo with more up-to-date compiler and quantizer code. Introduce a test folder and a conv1d test.

Reviewed By: cccclai

Differential Revision: D54034581
1 parent 20714e7 commit 5f39bfe

File tree

9 files changed: +655 additions, −102 deletions

examples/xtensa/aot/export_example.py

Lines changed: 3 additions & 22 deletions
@@ -29,26 +29,7 @@
 logging.basicConfig(level=logging.INFO, format=FORMAT)


-if __name__ == "__main__":
-    in_features = 32
-    out_features = 16
-    bias = True
-    shape = [64, in_features]
-
-    class QuantizedLinear(torch.nn.Module):
-        def __init__(self, in_features: int, out_features: int, bias: bool):
-            super().__init__()
-            self.output_linear = torch.nn.Linear(in_features, out_features, bias=bias)
-
-        def forward(self, x: torch.Tensor):
-            output_linear_out = self.output_linear(x)
-            return output_linear_out
-
-    model = QuantizedLinear(in_features, out_features, bias)
-    model.eval()
-
-    example_inputs = (torch.ones(shape),)
-
+def export_xtensa_model(model, example_inputs):
     # Quantizer
     quantizer = XtensaQuantizer()

@@ -77,14 +58,14 @@ def forward(self, x: torch.Tensor):
         export_to_edge(
             converted_model_exp,
             example_inputs,
-            EdgeCompileConfig(
+            edge_compile_config=EdgeCompileConfig(
                 _check_ir_validity=False,
             ),
         )
         .transform(
             [ReplacePT2QuantWithXtensaQuant(), ReplacePT2DequantWithXtensaDequant()]
         )
-        .to_executorch(config=ExecutorchBackendConfig(extract_constant_segment=False))
+        .to_executorch(config=ExecutorchBackendConfig())
     )

     logging.info(f"Final exported graph:\n{exec_prog.exported_program().graph}")
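
With this change the example no longer hard-codes a model in a __main__ block; callers build their own model and hand it to export_xtensa_model(). A minimal sketch of the new call pattern (the import path, model, and shapes below are illustrative assumptions, not part of this commit):

    import torch

    # Illustrative import path; adjust to wherever export_example.py
    # lives in your checkout.
    from executorch.examples.xtensa.aot.export_example import export_xtensa_model


    class SmallLinear(torch.nn.Module):
        """Toy float model to quantize and export, mirroring the deleted example."""

        def __init__(self, in_features: int, out_features: int):
            super().__init__()
            self.linear = torch.nn.Linear(in_features, out_features)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.linear(x)


    model = SmallLinear(32, 16)
    model.eval()  # the quantize/export flow expects eval mode
    example_inputs = (torch.ones(64, 32),)

    export_xtensa_model(model, example_inputs)

Factoring the model out of the script makes the quantize-and-export flow callable from tests, which fits the test folder and conv1d test this commit introduces.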

examples/xtensa/aot/meta_registrations.py

Lines changed: 50 additions & 9 deletions
@@ -4,10 +4,14 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+from typing import Tuple
+
 import torch
 from executorch.exir.scalar_type import ScalarType
 from torch.library import impl, Library

+from .utils import get_conv1d_output_size
+
 lib = Library("xtensa", "DEF")

 lib.define(
@@ -25,10 +29,17 @@
 )

 lib.define(
-    "quantized_linear_pt2(Tensor src, Tensor weight, Tensor bias, float src_scale, int src_zero_point, float weight_scale, int weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point) -> (Tensor Z)"
+    "quantized_linear(Tensor src, Tensor weight, Tensor bias, int src_zero_point, Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point) -> (Tensor Z)"
+)
+lib.define(
+    "quantized_linear.out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!)"
+)
+
+lib.define(
+    "quantized_conv(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift, bool channel_last=False) -> (Tensor Z)"
 )
 lib.define(
-    "quantized_linear_pt2.out(Tensor src, Tensor weight, Tensor bias, float src_scale, int src_zero_point, float weight_scale, int weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, *, Tensor(a!) out) -> Tensor(a!)"
+    "quantized_conv.out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!)"
 )

 m = Library("xtensa", "IMPL", "Meta")
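
Because these schemas are registered through torch.library and given Meta implementations below, output shapes can be checked without real data by calling the op on meta-device tensors. A small sketch, assuming meta_registrations has been imported so the "xtensa" library exists (the import path and tensor shapes are illustrative):

    import torch

    # Importing the module registers the "xtensa" ops and their Meta kernels;
    # the path is illustrative.
    import executorch.examples.xtensa.aot.meta_registrations  # noqa: F401

    # Meta tensors carry shape/dtype only; no data is materialized.
    src = torch.empty(64, 32, dtype=torch.uint8, device="meta")
    weight = torch.empty(16, 32, dtype=torch.uint8, device="meta")
    bias = torch.empty(16, dtype=torch.int32, device="meta")
    weight_zero_point = torch.empty(1, dtype=torch.int32, device="meta")
    out_multiplier = torch.empty(1, dtype=torch.int32, device="meta")
    out_shift = torch.empty(1, dtype=torch.int32, device="meta")

    out = torch.ops.xtensa.quantized_linear(
        src, weight, bias, 0, weight_zero_point, out_multiplier, out_shift, 0
    )
    print(out.shape)  # torch.Size([64, 16]); last dim comes from weight.shape[0]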
@@ -58,17 +69,15 @@ def dequantize_per_tensor_meta(
     return input.new_empty(input.size(), dtype=torch.float)


-@impl(m, "quantized_linear_pt2")
-def quantized_linear_pt2_meta(
+@impl(m, "quantized_linear")
+def quantized_linear_meta(
     src: torch.Tensor,
     weight: torch.Tensor,
     bias: torch.Tensor,
-    in_scale: float,
     in_zero_point: int,
-    weight_scale: float,
-    weight_zero_point: int,
-    out_multiplier: int,
-    out_shift: int,
+    weight_zero_point: torch.Tensor,
+    out_multiplier: torch.Tensor,
+    out_shift: torch.Tensor,
     out_zero_point: int,
 ):
     # src comes in shape [leading_dims, in_dim]
@@ -79,3 +88,35 @@ def quantized_linear_pt2_meta(
     assert len(weight_size) == 2
     out_size[-1] = weight_size[0]
     return src.new_empty(out_size, dtype=torch.uint8)
+
+
+@impl(m, "quantized_conv")
+def quantized_conv_meta(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    bias: torch.Tensor,
+    stride: Tuple[int],
+    padding: Tuple[int],
+    dilation: Tuple[int],
+    groups: int,
+    in_zero_point: int,
+    weight_zero_point: torch.Tensor,
+    bias_scale: torch.Tensor,
+    output_scale: float,
+    output_zero_point: int,
+    out_multiplier: torch.Tensor,
+    out_shift: torch.Tensor,
+    channel_last: bool = False,
+):
+    out_channels, _in_channels, *kernel_size = weight.shape
+    in_size = input.shape
+    # Assert that the input tensor has at least 3 dimensions, and at most 6
+    assert len(in_size) > 2
+    assert len(in_size) < 6
+
+    # Compute the output tensor size
+    output_size = get_conv1d_output_size(
+        in_size, out_channels, stride[0], padding[0], dilation[0], kernel_size[0]
+    )
+
+    return input.new_empty(output_size, dtype=input.dtype)
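
The quantized_conv meta kernel delegates its shape math to get_conv1d_output_size, imported from .utils but not shown in this diff. A hypothetical sketch of that helper, assuming an (N, C, L) input layout and the standard convolution output-length formula (the signature matches the call site above; the body is an assumption, not the commit's actual code):

    import torch


    def get_conv1d_output_size(
        in_size: torch.Size,
        out_channels: int,
        stride: int,
        padding: int,
        dilation: int,
        kernel_size: int,
    ) -> torch.Size:
        # Assumed (batch, channels, length) layout.
        N, _C, L = in_size
        # Standard conv1d output length: floor((L + 2p - d*(k - 1) - 1) / s) + 1
        L_out = (L + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1
        return torch.Size((N, out_channels, L_out))

For example, a (1, 16, 100) input with kernel_size=3, stride=1, padding=0, and dilation=1 yields an output of size (1, out_channels, 98).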
