
Commit bc2162b

Merge branch 'main' into constant-ops-aot
2 parents 9467d4f + 043c7a0


123 files changed, +2346 -1575 lines changed

.ci/scripts/build_android_instrumentation.sh

Lines changed: 6 additions & 5 deletions
@@ -13,9 +13,11 @@ fi
 which "${PYTHON_EXECUTABLE}"
 
 build_android_test() {
-  pushd extension/android_test
-  ANDROID_HOME="${ANDROID_SDK:-/opt/android/sdk}" ./gradlew testDebugUnitTest
-  ANDROID_HOME="${ANDROID_SDK:-/opt/android/sdk}" ./gradlew build assembleAndroidTest
+  mkdir -p extension/android/executorch_android/src/androidTest/resources
+  cp extension/module/test/resources/add.pte extension/android/executorch_android/src/androidTest/resources
+  pushd extension/android
+  ANDROID_HOME="${ANDROID_SDK:-/opt/android/sdk}" ./gradlew :executorch_android:testDebugUnitTest
+  ANDROID_HOME="${ANDROID_SDK:-/opt/android/sdk}" ./gradlew :executorch_android:assembleAndroidTest
   popd
 }
 
@@ -24,8 +26,7 @@ collect_artifacts_to_be_uploaded() {
   # Collect Java library test
   JAVA_LIBRARY_TEST_DIR="${ARTIFACTS_DIR_NAME}/library_test_dir"
   mkdir -p "${JAVA_LIBRARY_TEST_DIR}"
-  cp extension/android_test/build/outputs/apk/debug/*.apk "${JAVA_LIBRARY_TEST_DIR}"
-  cp extension/android_test/build/outputs/apk/androidTest/debug/*.apk "${JAVA_LIBRARY_TEST_DIR}"
+  cp extension/android/executorch_android/build/outputs/apk/androidTest/debug/*.apk "${JAVA_LIBRARY_TEST_DIR}"
 }
 
 main() {

.github/workflows/_android.yml

Lines changed: 5 additions & 4 deletions
@@ -28,14 +28,16 @@ jobs:
 PYTHON_EXECUTABLE=python bash .ci/scripts/setup-linux.sh --build-tool buck2
 export ARTIFACTS_DIR_NAME=artifacts-to-be-uploaded
 
+mkdir -p ${ARTIFACTS_DIR_NAME}/fp32-xnnpack-custom
+bash examples/models/llama/install_requirements.sh
+bash ".ci/scripts/test_llama.sh" -model stories110M -build_tool cmake -dtype fp16 -mode portable -upload ${ARTIFACTS_DIR_NAME}/fp32-xnnpack-custom
+
 # Build LLM Demo for Android
 export BUILD_AAR_DIR=aar-out
+mkdir -p $BUILD_AAR_DIR
 bash build/build_android_library.sh ${ARTIFACTS_DIR_NAME}
 bash .ci/scripts/build_android_instrumentation.sh ${ARTIFACTS_DIR_NAME}
 
-mkdir -p ${ARTIFACTS_DIR_NAME}/fp32-xnnpack-custom
-bash ".ci/scripts/test_llama.sh" -model stories110M -build_tool cmake -dtype fp16 -mode portable -upload ${ARTIFACTS_DIR_NAME}/fp32-xnnpack-custom
-
 mkdir -p examples/demo-apps/android/LlamaDemo/app/libs
 cp aar-out/executorch.aar examples/demo-apps/android/LlamaDemo/app/libs
 pushd examples/demo-apps/android/LlamaDemo
 
@@ -94,7 +96,6 @@ jobs:
 curl -O https://gha-artifacts.s3.amazonaws.com/${{ github.repository }}/${{ github.run_id }}/artifacts/llm_demo/app-debug.apk
 curl -O https://gha-artifacts.s3.amazonaws.com/${{ github.repository }}/${{ github.run_id }}/artifacts/llm_demo/app-debug-androidTest.apk
 curl -O https://gha-artifacts.s3.amazonaws.com/${{ github.repository }}/${{ github.run_id }}/artifacts/fp32-xnnpack-custom/model.zip
-curl -o android-test-debug.apk https://gha-artifacts.s3.amazonaws.com/${{ github.repository }}/${{ github.run_id }}/artifacts/library_test_dir/executorch-debug.apk
 curl -o android-test-debug-androidTest.apk https://gha-artifacts.s3.amazonaws.com/${{ github.repository }}/${{ github.run_id }}/artifacts/library_test_dir/executorch-debug-androidTest.apk
 unzip model.zip
 mv *.pte model.pte

.github/workflows/doc-build.yml

Lines changed: 3 additions & 3 deletions
@@ -26,7 +26,7 @@ jobs:
 with:
   job-name: Build doc
   runner: linux.2xlarge
-  docker-image: executorch-ubuntu-22.04-clang12
+  docker-image: executorch-ubuntu-22.04-clang12-android
   submodules: 'true'
   repository: pytorch/executorch
   upload-artifact: docs
 
@@ -70,8 +70,8 @@ jobs:
 
 # Build javadoc:
 cd extension/android
-./gradlew javadoc
-cp -rf build/docs/javadoc "${RUNNER_DOCS_DIR}"
+ANDROID_HOME="${ANDROID_SDK:-/opt/android/sdk}" ./gradlew :executorch_android:javaDocReleaseGeneration
+cp -rf executorch_android/build/intermediates/java_doc_dir/release/javaDocReleaseGeneration "${RUNNER_DOCS_DIR}/javadoc"
 cd ../..
 
 # If it's main branch, add noindex tag to all .html files to exclude from Google Search indexing.

.github/workflows/lint.yml

Lines changed: 2 additions & 2 deletions
@@ -76,8 +76,8 @@ jobs:
 ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
 timeout: 90
 script: |
-  FILES_NEEDS_FORMAT=$(/opt/google-java-format -n extension/android/src/main/java/org/pytorch/executorch/*.java \
-    examples/demo-apps/android/ExecuTorchDemo/app/src/main/java/com/example/executorchdemo/*.java \
+  FILES_NEEDS_FORMAT=$(/opt/google-java-format -n \
+    extension/android/executorch_android/src/main/java/org/pytorch/executorch/*.java \
     examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/*.java \
     extension/benchmark/android/benchmark/app/src/main/java/org/pytorch/minibench/*.java)
  if [ -n "$FILES_NEEDS_FORMAT" ]; then

.github/workflows/pull.yml

Lines changed: 1 addition & 2 deletions
@@ -60,7 +60,7 @@ jobs:
 - runner: linux.arm64.2xlarge
   docker-image: executorch-ubuntu-22.04-clang12
 # TODO: Need to figure out why buck2 doesnt work on Graviton instances.
-- runner: linux.arm64.2xlarge
+- runner: linux.arm64.2xlarge
   build-tool: buck2
 fail-fast: false
 with:
 
@@ -420,7 +420,6 @@ jobs:
 permissions:
   id-token: write
   contents: read
-needs: test-llama-runner-linux
 
 unittest:
   uses: ./.github/workflows/_unittest.yml

backends/arm/_passes/insert_rescales_pass.py

Lines changed: 6 additions & 6 deletions
@@ -38,17 +38,17 @@ def rescale_fake(
     """Casts the input tensor to dtype `dtype` to produce the correct tensor meta for a _rescale op.
     Additionally validates TOSA constraints of a RESCALE op.
     """
-    if not (dtype == torch.int32 or dtype == torch.int8):
+    if dtype not in (torch.int32, torch.int8, torch.int16):
         raise NotImplementedError(
-            "tosa::rescale currently only supports int32 and int8."
+            f"tosa::rescale currently only supports int32, int16 and int8, not {dtype}"
         )
-    if dtype == torch.int32 and out_zp != 0:
+    if dtype in (torch.int32, torch.int16) and out_zp != 0:
         raise ValueError(
-            "TOSA requires output_zp to be zero when the output dtype is int32."
+            f"TOSA requires output_zp to be zero when the output dtype is {dtype}."
         )
-    if x.dtype == torch.int32 and in_zp != 0:
+    if x.dtype in (torch.int32, torch.int16) and in_zp != 0:
         raise ValueError(
-            "TOSA requires input_zp to be zero when the input dtype is int32."
+            f"TOSA requires input_zp to be zero when the input dtype is {dtype}"
         )
     if x.dtype == torch.int8 and not -128 <= in_zp <= 127:
         raise ValueError(f"{in_zp=} outside valid range (-128,127) for int8.")

backends/arm/_passes/insert_table_ops.py

Lines changed: 106 additions & 15 deletions
@@ -1,5 +1,4 @@
 # Copyright 2024-2025 Arm Limited and/or its affiliates.
-# All rights reserved.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

@@ -18,6 +17,7 @@
 
 from executorch.exir.pass_base import ExportPass, PassResult
 from torch.fx import GraphModule
+
 from torch.library import impl, Library
 
 lib = Library("tosa", "DEF")

@@ -26,7 +26,10 @@
 
 @impl(lib, "_table")
 def _table_impl(*args, **kwargs):  # pyre-ignore
-    return args[0]
+    in_dtype = args[0].dtype
+    if in_dtype == torch.int8:
+        return args[0]
+    return args[0].to(dtype=torch.int32)
 
 
 class InsertTableOpsPass(ExportPass):

@@ -59,29 +62,105 @@ def register_buffer(self, buffer_name: str, buffer: torch.Tensor) -> None:
         """
         self.exported_program.state_dict[buffer_name] = buffer
 
-    def generate_table_values(
+    def generate_8bit_table_values(
         self,
         torch_op: Callable[[torch.Tensor], torch.Tensor],
         in_quantargs: QuantArgs,
         out_quantargs: QuantArgs,
-    ) -> torch.Tensor:
+    ) -> tuple[torch.Tensor, int]:
+        """Compute LUT values for an INT8 TOSA.TABLE. Also returns 0 since no shifting is required after the 8-bit table.
+        The INT8 table is a simple 256-value 1-1 LUT.
+        """
+
         def f(x: torch.Tensor) -> torch.Tensor:
             x = in_quantargs.dequantize_value(x)
             x = torch_op(x)
             return out_quantargs.quantize_value(x)
 
-        input_dtype = in_quantargs.dtype
-        steps = in_quantargs.qmax - in_quantargs.qmin + 1
-        return f(
+        return (
+            f(
+                torch.linspace(
+                    start=in_quantargs.qmin,
+                    end=in_quantargs.qmax,
+                    steps=256,
+                    # use torch.int64 to avoid overflow when dequantizing (subtracting zp).
+                    # e.g. torch.tensor(-50, dtype=torch.int8) - 100 == torch.tensor(106, dtype=torch.int8)
+                    dtype=torch.int64,
+                )
+            ).to(dtype=torch.int8),
+            0,
+        )
+
+    def generate_16_bit_table_values(
+        self,
+        torch_op: Callable[[torch.Tensor], torch.Tensor],
+        in_quantargs: QuantArgs,
+        out_quantargs: QuantArgs,
+    ) -> tuple[torch.Tensor, int]:
+        """Compute LUT values for an INT16 TOSA.TABLE with 32 bit output.
+        In practice the output is 23 bits that should be interpreted as 16 'whole' bits and 7 fractional bits, see
+        the specification: https://www.mlplatform.org/tosa/tosa_spec.html#_table. This means that the output
+        will be interpreted as 2**7=128 times too large unless accounted for by rescaling down the table output.
+
+        Quantization can be either int16 or int32, which means that the op output could be larger than the 23 bits from
+        the TOSA.TABLE output. In that case, we need to rescale up the output.
+
+        To handle this we need to:
+        1) Make sure that our table values fit within 16 bits.
+        2) Insert a rescale after the table to handle the x128 from the fractional bits and match the quantization.
+
+        The function returns rescale_lshift, which says how much to rescale after the table. This value can be negative.
+        """
+
+        def f(x: torch.Tensor) -> torch.Tensor:
+            # Don't use the 7 LSBs.
+            x = in_quantargs.dequantize_value((x & ~0x7F))
+            x = torch_op(x)
+            return out_quantargs.quantize_value(x)
+
+        lut_values = f(
             torch.linspace(
                 start=in_quantargs.qmin,
-                end=in_quantargs.qmax,
-                steps=steps,
+                end=in_quantargs.qmax + 1,
+                steps=513,
                 # use torch.int64 to avoid overflow when dequantizing (subtracting zp).
                 # e.g. torch.tensor(-50, dtype=torch.int8) - 100 == torch.tensor(106, dtype=torch.int8)
                 dtype=torch.int64,
             )
-        ).to(dtype=input_dtype)
+        )
+        # Calculate how much we need to shift table values to fit in 16 signed bits:
+        # ceil(log2(max absolute table value)) + 1 bit for signedness - 16
+        # Example:
+        # Max value in the table is 70 000. We want to fit it in 16 signed bits.
+        # 70 000 = 0b10001000101110000 (17 digits) has ceil(log2(70 000)) = ceil(16.095) = 17 bits.
+        # If we shift it 17-16=1 bit, we do get 16 bits (0b1000100010111000),
+        # but due to signedness this is a negative number! So we need to shift it one more bit.
+        # Note: for out_quantargs.dtype=torch.int16, rshift == 0 and rescale_lshift = -7.
+        rshift = int(torch.ceil(torch.log2(lut_values.abs().max()))) + 1 - 16
+        # The 7 fractional bits are equivalent to a lshift of 7, so subtract 7 from the lshift we do.
+        rescale_lshift = rshift - 7
+        lut_values = lut_values >> rshift
+        return lut_values.to(dtype=torch.int16), rescale_lshift
+
+    def generate_table_values(
+        self,
+        torch_op: Callable[[torch.Tensor], torch.Tensor],
+        in_quantargs: QuantArgs,
+        out_quantargs: QuantArgs,
+    ) -> tuple[torch.Tensor, int]:
+        match out_quantargs.dtype:
+            case torch.int8:
+                return self.generate_8bit_table_values(
+                    torch_op, in_quantargs, out_quantargs
+                )
+            case torch.int16 | torch.int32:
+                return self.generate_16_bit_table_values(
+                    torch_op, in_quantargs, out_quantargs
+                )
+            case _:
+                raise ValueError(
+                    f"Unsupported output dtype for table: {out_quantargs.dtype}"
+                )
 
     def call(self, graph_module: GraphModule) -> PassResult:
         modified = False

@@ -100,10 +179,12 @@ def call(self, graph_module: GraphModule) -> PassResult:
                     op_target=torch.ops.tosa._table.default,
                     args=(node.args[0],),
                 )
+                output_node = table_node
                 assert len(input_qparams) == 1
                 assert len(output_qparams) == 1
-                # Generate table buffer
-                buffer = self.generate_table_values(
+
+                # Generate table buffer and how much to lshift the table output.
+                buffer, lshift = self.generate_table_values(
                     torch_op=self.table_ops[node.target],
                     in_quantargs=input_qparams[0],
                     out_quantargs=output_qparams[0],

@@ -114,10 +195,20 @@ def call(self, graph_module: GraphModule) -> PassResult:
                 self.register_buffer(
                     buffer_name=table_node.name.replace("_default", ""), buffer=buffer
                 )
-                node.replace_all_uses_with(table_node)
+
+                if lshift != 0:
+                    scale = 2.0**lshift
+                    rescale_node = create_node(
+                        graph=graph_module.graph,
+                        op_target=torch.ops.tosa._rescale.default,
+                        args=(table_node, output_qparams[0].dtype, scale, 0, 0),
+                    )
+                    output_node = rescale_node
+
+                node.replace_all_uses_with(output_node)
                 graph_module.graph.erase_node(node)
-                table_node.meta["input_qparams"] = input_qparams
-                table_node.meta["output_qparams"] = output_qparams
+                output_node.meta["input_qparams"] = input_qparams
+                output_node.meta["output_qparams"] = output_qparams
                 modified = True
 
         if modified:
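
To make the shift bookkeeping in generate_16_bit_table_values concrete, here is a self-contained sketch of the same arithmetic on a toy set of LUT values (fit_lut_into_int16 is a made-up helper name for illustration; the 70 000 example comes from the comment in the diff above):

import torch

def fit_lut_into_int16(lut_values: torch.Tensor) -> tuple[torch.Tensor, int]:
    # Bits needed for the largest magnitude, plus one sign bit, minus the 16 bits
    # available in a signed 16-bit lane (assumes the values need more than 16 bits).
    rshift = int(torch.ceil(torch.log2(lut_values.abs().max()))) + 1 - 16
    # TOSA.TABLE already adds 7 fractional bits (a factor of 2**7), so the rescale
    # that follows the table shifts by rshift - 7 in total; this can be negative.
    rescale_lshift = rshift - 7
    return (lut_values >> rshift).to(torch.int16), rescale_lshift

values = torch.tensor([-70_000, 0, 70_000], dtype=torch.int64)
table, lshift = fit_lut_into_int16(values)
print(table)   # [-17500, 0, 17500]: shifted right by 2 bits, the values now fit in int16
print(lshift)  # -5: the rescale after the table scales the result by 2**-5 overall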

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 20 additions & 5 deletions
@@ -212,26 +212,41 @@ def is_node_supported(
 class EthosU55NotSupported(OperatorSupportBase):
     """
     Certain operators are not supported on U55. These are listed in `unsupported_ops`.
+    Where it is not obvious, a comment notes the unsupported TOSA operator that the aten operator maps to.
+    For unimplemented operators, this is the anticipated mapping, and it might be incorrect.
     """
 
     unsupported_ops = [
-        exir_ops.edge.aten.any.default,
-        exir_ops.edge.aten.any.dim,
-        exir_ops.edge.aten.any.dims,
+        exir_ops.edge.aten.any.default,  # REDUCE_ANY
+        exir_ops.edge.aten.any.dim,  # REDUCE_ANY
+        exir_ops.edge.aten.any.dims,  # REDUCE_ANY
         exir_ops.edge.aten.bitwise_and.Tensor,
         exir_ops.edge.aten.bitwise_or.Tensor,
         exir_ops.edge.aten.bitwise_xor.Tensor,
+        exir_ops.edge.aten.bitwise_not,
         exir_ops.edge.aten.logical_and.default,
         exir_ops.edge.aten.logical_or.default,
         exir_ops.edge.aten.logical_xor.default,
         exir_ops.edge.aten.logical_not.default,
-        exir_ops.edge.aten.amax.default,
-        exir_ops.edge.aten.amin.default,
+        exir_ops.edge.aten.amax.default,  # REDUCE_MAX
+        exir_ops.edge.aten.amin.default,  # REDUCE_MIN
         exir_ops.edge.aten.eq.Tensor,
         exir_ops.edge.aten.ge.Tensor,
         exir_ops.edge.aten.gt.Tensor,
         exir_ops.edge.aten.le.Tensor,
         exir_ops.edge.aten.lt.Tensor,
+        exir_ops.edge.aten.flip.default,  # REVERSE
+        exir_ops.edge.aten.grid_sampler_2d,  # GATHER
+        exir_ops.edge.aten.scatter.src,
+        exir_ops.edge.aten.scatter.value,
+        exir_ops.edge.aten.select_scatter.default,
+        exir_ops.edge.aten.scatter_reduce.two,
+        exir_ops.edge.aten.scatter_add.default,
+        exir_ops.edge.aten.upsample_nearest2d.vec,  # RESIZE
+        exir_ops.edge.aten.upsample_bilinear2d.vec,  # RESIZE
+        exir_ops.edge.aten.reflection_pad1d.default,  # REVERSE
+        exir_ops.edge.aten.reflection_pad2d.default,  # REVERSE
+        exir_ops.edge.aten.reflection_pad3d.default,  # REVERSE
     ]
 
     def __init__(self, reporter: WhyNoPartitionReporter):

backends/arm/operators/node_visitor.py

Lines changed: 1 addition & 1 deletion
@@ -30,7 +30,7 @@ class NodeVisitor:
     ]
 
     def __init__(self, exported_program: ExportedProgram, tosa_spec: TosaSpecification):
-        self._exported_program = exported_program or None
+        self._exported_program = exported_program
         self.tosa_spec = tosa_spec
 
     def define_node(

backends/arm/operators/op_rescale.py

Lines changed: 5 additions & 3 deletions
@@ -38,7 +38,6 @@ def define_node(
         input_zp = cast(int, node.args[3])
         output_zp = cast(int, node.args[4])
 
-        # Skip int16 cases for now.
         if input_dtype != map_dtype(torch.int8) and input_zp != 0:
             raise ValueError(
                 f"If input dtype is not int8, input_zp must be 0. Got input_dtype{ts.DTypeNames[input_dtype]}, {input_zp=}"

@@ -48,7 +47,10 @@
                 f"If output dtype is not int8, output_zp must be 0. Got {output_dtype=}, {output_zp=}"
             )
 
-        scale_width = 32 if output_dtype == torch.int32 else 16
+        # scale32 gives higher accuracy but at a higher HW cost.
+        # For now, always go for scale32.
+        scale_32 = True
+        scale_width = 32 if scale_32 else 16
         multiplier, shift = tosa_quant_utils.compute_multiplier_and_shift(
             [scale], scale_width
         )

@@ -58,7 +60,7 @@
             output_zp=output_zp,
             multiplier=multiplier,
             shift=shift,
-            scale32=output_dtype == torch.int32,
+            scale32=scale_32,
             double_round=False,
             per_channel=False,
             input_unsigned=False,
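
For context, the scale32 path encodes a floating-point rescale factor as an integer multiplier plus a right shift. The sketch below shows one common way to do that decomposition; it is an illustrative assumption, not the backend's compute_multiplier_and_shift, and the rounding details may differ:

import math

def decompose_scale32(scale: float) -> tuple[int, int]:
    # Express scale as mantissa * 2**exponent with mantissa in [0.5, 1).
    mantissa, exponent = math.frexp(scale)
    # Encode the mantissa as a signed 32-bit fixed-point multiplier (Q31).
    multiplier = round(mantissa * (1 << 31))
    shift = 31 - exponent
    # Rounding can push the mantissa up to exactly 2**31; renormalize if so.
    if multiplier == (1 << 31):
        multiplier //= 2
        shift -= 1
    return multiplier, shift

# value * scale is then approximated as (value * multiplier) >> shift.
multiplier, shift = decompose_scale32(2.0**-5)  # e.g. a power-of-two correction like the table rescale above
print(multiplier, shift)  # 1073741824 35, i.e. (v << 30) >> 35 == v >> 5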
