
Commit 07838a4

Merge branch 'main' into jz/re-land-ao-api
2 parents: f4795df + 12f4431


60 files changed: +793 −269 lines

.ci/scripts/gather_test_models.py

Lines changed: 1 addition & 1 deletion

@@ -33,7 +33,7 @@
         "dl3": "linux.4xlarge.memory",
         "emformer_join": "linux.4xlarge.memory",
         "emformer_predict": "linux.4xlarge.memory",
-        "phi-4-mini": "linux.4xlarge.memory",
+        "phi_4_mini": "linux.4xlarge.memory",
     }
 }

.ci/scripts/test_llama_torchao_lowbit.sh

Lines changed: 0 additions & 1 deletion

@@ -78,7 +78,6 @@ ${PYTHON_EXECUTABLE} -m examples.models.llama.export_llama \
     -qmode "torchao:8da${QLINEAR_BITWIDTH}w" \
     --group_size ${QLINEAR_GROUP_SIZE} \
     -E "torchao:${QEMBEDDING_BITWIDTH},${QEMBEDDING_GROUP_SIZE}" \
-    --disable_dynamic_shape \
     -d fp32
 
 # Test run

.ci/scripts/test_model.sh

Lines changed: 2 additions & 2 deletions

@@ -100,11 +100,11 @@ test_model() {
         rm "./${MODEL_NAME}.pte"
         return # Skip running with portable executor runnner since portable doesn't support Qwen's biased linears.
     fi
-    if [[ "${MODEL_NAME}" == "phi-4-mini" ]]; then
+    if [[ "${MODEL_NAME}" == "phi_4_mini" ]]; then
         # Install requirements for export_llama
         bash examples/models/llama/install_requirements.sh
         # Test export_llama script: python3 -m examples.models.llama.export_llama.
-        "${PYTHON_EXECUTABLE}" -m examples.models.llama.export_llama --model "${MODEL_NAME}" -c examples/models/llama/params/demo_rand_params.pth -p examples/models/phi-4-mini/config.json
+        "${PYTHON_EXECUTABLE}" -m examples.models.llama.export_llama --model "${MODEL_NAME}" -c examples/models/llama/params/demo_rand_params.pth -p examples/models/phi_4_mini/config.json
         run_portable_executor_runner
         rm "./${MODEL_NAME}.pte"
         return

.ci/scripts/unittest-macos-buck2.sh

File mode changed: 100644 → 100755

.github/workflows/pull.yml

Lines changed: 1 addition & 1 deletion

@@ -106,7 +106,7 @@ jobs:
           - model: emformer_join
             backend: xnnpack-quantization-delegation
             runner: linux.4xlarge.memory
-          - model: phi-4-mini
+          - model: phi_4_mini
             backend: portable
             runner: linux.4xlarge.memory
           - model: llama3_2_vision_encoder

.github/workflows/trunk.yml

Lines changed: 1 addition & 1 deletion

@@ -72,7 +72,7 @@ jobs:
             backend: portable
           - model: softmax
             backend: portable
-          - model: phi-4-mini
+          - model: phi_4_mini
             backend: portable
           - model: qwen2_5
             backend: portable

backends/apple/coreml/CMakeLists.txt

Lines changed: 39 additions & 14 deletions

@@ -1,4 +1,9 @@
 # Copyright © 2023 Apple Inc. All rights reserved.
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
 
 cmake_minimum_required(VERSION 3.19)
 

@@ -111,32 +116,48 @@ set(PROTOBUF_SOURCES
   runtime/sdk/format/WordTagger.pb.cc
 )
 
+find_library(FOUNDATION_FRAMEWORK Foundation)
+
+# CoreML util
+add_library(coreml_util ${UTIL_SOURCES})
+target_include_directories(coreml_util PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/runtime/util)
+target_link_libraries(coreml_util PRIVATE ${FOUNDATION_FRAMEWORK})
+
+install(
+  TARGETS coreml_util
+  DESTINATION lib
+  INCLUDES
+  DESTINATION ${_common_include_directories}
+)
+
+# CoreML inmemoryfs
+add_library(coreml_inmemoryfs ${INMEMORYFS_SOURCES})
+target_include_directories(coreml_inmemoryfs PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/runtime/inmemoryfs)
+target_link_libraries(coreml_inmemoryfs PRIVATE coreml_util ${FOUNDATION_FRAMEWORK})
+
+install(
+  TARGETS coreml_inmemoryfs
+  DESTINATION lib
+  INCLUDES
+  DESTINATION ${_common_include_directories}
+)
+
 # Define the delegate library
 add_library(coremldelegate)
-target_sources(
-  coremldelegate PRIVATE ${INMEMORYFS_SOURCES} ${KVSTORE_SOURCES}
-                         ${DELEGATE_SOURCES} ${UTIL_SOURCES}
-)
+target_sources(coremldelegate PRIVATE ${KVSTORE_SOURCES} ${DELEGATE_SOURCES})
 
 target_include_directories(
   coremldelegate PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/runtime/include
 )
 target_include_directories(
   coremldelegate PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/runtime/kvstore
 )
-target_include_directories(
-  coremldelegate PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/runtime/inmemoryfs
-)
 target_include_directories(
   coremldelegate PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/runtime/delegate
 )
-target_include_directories(
-  coremldelegate PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/runtime/util
-)
 target_include_directories(coremldelegate PRIVATE ${EXECUTORCH_ROOT}/..)
 target_include_directories(coremldelegate PRIVATE ${EXECUTORCH_ROOT}/runtime/core/portable_type/c10)
 target_compile_definitions(coremldelegate PRIVATE C10_USING_CUSTOM_GENERATED_MACROS)
-target_link_libraries(coremldelegate PRIVATE executorch_core)
 
 if(EXECUTORCH_BUILD_DEVTOOLS)
   target_sources(coremldelegate PRIVATE ${SDK_SOURCES} ${PROTOBUF_SOURCES})

@@ -156,13 +177,17 @@ endif()
 
 find_library(ACCELERATE_FRAMEWORK Accelerate)
 find_library(COREML_FRAMEWORK CoreML)
-find_library(FOUNDATION_FRAMEWORK Foundation)
 find_library(SQLITE_LIBRARY sqlite3)
 
 target_link_libraries(
   coremldelegate
-  PRIVATE executorch_core ${ACCELERATE_FRAMEWORK} ${COREML_FRAMEWORK}
-          ${FOUNDATION_FRAMEWORK} ${SQLITE_LIBRARY}
+  PUBLIC coreml_util
+         coreml_inmemoryfs
+  PRIVATE executorch_core
+          ${ACCELERATE_FRAMEWORK}
+          ${COREML_FRAMEWORK}
+          ${FOUNDATION_FRAMEWORK}
+          ${SQLITE_LIBRARY}
 )
 
 target_link_options_shared_lib(coremldelegate)

backends/arm/_passes/arm_pass_manager.py

Lines changed: 13 additions & 7 deletions

@@ -55,7 +55,10 @@
     RetraceFoldedDtypesPass,
 )
 from executorch.backends.arm._passes.fuse_batchnorm2d_pass import FuseBatchnorm2DPass
-from executorch.backends.arm._passes.fuse_constant_ops_pass import FuseConstantOpsPass
+from executorch.backends.arm._passes.fuse_constant_ops_pass import (
+    ComputeConstantOpsAOT,
+    FuseConstantArgsPass,
+)
 from executorch.backends.arm._passes.fuse_quantized_activation_pass import ( # type: ignore[import-not-found]
     FuseQuantizedActivationPass,
 )

@@ -121,21 +124,23 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
         self.add_pass(QuantizeOperatorArguments())
         self.add_pass(FoldAndAnnotateQParamsPass()) # type: ignore[call-arg]
         self.add_pass(RetraceFoldedDtypesPass())
+        self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
+        self.add_pass(MatchArgRanksPass(exported_program))
+        self.add_pass(ComputeConstantOpsAOT(exported_program))
 
         self.add_pass(RemoveClonePass())
         self.add_pass(SizeAdjustConv2DPass())
         self.add_pass(ConvertExpandCopyToRepeatPass())
         self.add_pass(UnsqueezeBeforeRepeatPass())
-        self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
         self.add_pass(CastInt64ToInt32Pass(exported_program))
-        self.add_pass(MatchArgRanksPass(exported_program))
         self.add_pass(KeepDimsFalseToSqueezePass())
         self.add_pass(Conv1dUnsqueezePass(exported_program))
         self.add_pass(DecomposeSelectPass())
         self.add_pass(ConvertSqueezesToViewPass())
 
         self.add_pass(FuseViewCopyTransform())
-        self.add_pass(FuseConstantOpsPass(exported_program))
+        self.add_pass(FuseConstantArgsPass(exported_program))
+
         self.add_pass(InsertTableOpsPass(exported_program))
         self.add_pass(AnnotateChannelsLastDimOrder())
         self.add_pass(InsertRescalePass())

@@ -166,21 +171,22 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
         self.add_pass(QuantizeOperatorArguments())
         self.add_pass(FoldAndAnnotateQParamsPass()) # type: ignore[call-arg]
         self.add_pass(RetraceFoldedDtypesPass())
+        self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
+        self.add_pass(MatchArgRanksPass(exported_program))
+        self.add_pass(ComputeConstantOpsAOT(exported_program))
 
         self.add_pass(RemoveClonePass())
         self.add_pass(SizeAdjustConv2DPass())
         self.add_pass(ConvertExpandCopyToRepeatPass())
         self.add_pass(UnsqueezeBeforeRepeatPass())
-        self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
         self.add_pass(CastInt64ToInt32Pass(exported_program))
-        self.add_pass(MatchArgRanksPass(exported_program))
         self.add_pass(KeepDimsFalseToSqueezePass())
         self.add_pass(Conv1dUnsqueezePass(exported_program))
         self.add_pass(DecomposeSelectPass())
         self.add_pass(ConvertSqueezesToViewPass())
 
         self.add_pass(FuseViewCopyTransform())
-        self.add_pass(FuseConstantOpsPass(exported_program))
+        self.add_pass(FuseConstantArgsPass(exported_program))
         self.add_pass(InsertTableOpsPass(exported_program))
         self.add_pass(AnnotateChannelsLastDimOrder())
         self.add_pass(InsertRescalePass())

backends/arm/_passes/cast_int64_pass.py

Lines changed: 28 additions & 27 deletions

@@ -8,7 +8,6 @@
 import logging
 
 import torch
-from executorch.backends.arm._passes.arm_pass_utils import is_param_node
 from executorch.exir.pass_base import ExportPass, PassResult
 from torch._export.utils import is_buffer
 

@@ -25,35 +24,37 @@ def __init__(self, exported_program: torch.export.ExportedProgram):
         super(CastInt64ToInt32Pass, self).__init__()
         self.exported_program = exported_program
 
+    def _assert_within_int32(self, tensor: torch.Tensor, node: torch.fx.Node):
+        if torch.min(tensor) < torch.iinfo(torch.int32).min:
+            raise RuntimeError(
+                f"Node {node.name} has value < {torch.iinfo(torch.int32).min}"
+            )
+        if torch.max(tensor) > torch.iinfo(torch.int32).max:
+            raise RuntimeError(
+                f"Node {node.name} has value > {torch.iinfo(torch.int32).max}"
+            )
+
     def _to_int32(self, graph_module: torch.fx.GraphModule):
         for node in graph_module.graph.nodes:
             fake_tensor = node.meta["val"]
-            if isinstance(fake_tensor, torch._subclasses.fake_tensor.FakeTensor):
-                if node.meta["val"].dtype == torch.int64 and is_param_node(
-                    self.exported_program, node
-                ):
-                    if is_buffer(self.exported_program, node):
-                        node.meta["val"] = node.meta["val"].to(torch.int32)
-                        buffer_name = (
-                            self.exported_program.graph_signature.inputs_to_buffers[
-                                node.name
-                            ]
-                        )
-                        buffer = self.exported_program.state_dict[node.name]
-                        logger.warning(
-                            f"Casting buffer {node.name} from torch.int64 to torch.int32"
-                            f" defined in {node.meta['stack_trace']}"
-                        )
-                        if torch.min(buffer) < torch.iinfo(torch.int32).min:
-                            raise RuntimeError(
-                                f"Buffer {node.name} has value < {torch.iinfo(torch.int32).min}"
-                            )
-                        if torch.max(buffer) > torch.iinfo(torch.int32).max:
-                            raise RuntimeError(
-                                f"Buffer {node.name} has value > {torch.iinfo(torch.int32).max}"
-                            )
-                        buffer_int32 = buffer.to(torch.int32)
-                        self.exported_program.state_dict[buffer_name] = buffer_int32
+            if not isinstance(fake_tensor, torch._subclasses.fake_tensor.FakeTensor):
+                continue
+            if fake_tensor.dtype != torch.int64:
+                continue
+            if is_buffer(self.exported_program, node):
+                node.meta["val"] = fake_tensor.to(torch.int32)
+                buffer_name = self.exported_program.graph_signature.inputs_to_buffers[
+                    node.name
+                ]
+                buffer = self.exported_program.state_dict[node.name]
+                self._assert_within_int32(buffer, node)
+                logger.warning(
+                    f"Casting buffer {node.name} from torch.int64 to torch.int32"
+                    f" defined in {node.meta.get('stack_trace','[no stack trace found]')}"
+                )
+                buffer_int32 = buffer.to(torch.int32)
+                self.exported_program.state_dict[buffer_name] = buffer_int32
+                continue
 
     def call(self, graph_module: torch.fx.GraphModule):
         self._to_int32(graph_module)
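The rewrite above flattens the nested buffer-handling logic into early `continue` guards and factors the int64 range check into a `_assert_within_int32` helper. As a rough standalone sketch of that check, using only plain PyTorch (the free function and example buffer below are hypothetical, not part of the pass):

import torch

def assert_within_int32(tensor: torch.Tensor, name: str) -> None:
    # Refuse to downcast if any element would overflow int32.
    info = torch.iinfo(torch.int32)
    if torch.min(tensor) < info.min:
        raise RuntimeError(f"Node {name} has value < {info.min}")
    if torch.max(tensor) > info.max:
        raise RuntimeError(f"Node {name} has value > {info.max}")

# Hypothetical usage: validate an int64 buffer before casting it to int32.
buffer = torch.tensor([0, 7, 2**31 - 1], dtype=torch.int64)
assert_within_int32(buffer, "example_buffer")
buffer_int32 = buffer.to(torch.int32)  # safe only after the check passes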

backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py

Lines changed: 2 additions & 15 deletions

@@ -174,11 +174,8 @@ def call(self, graph_module: GraphModule) -> PassResult:
 
 class QuantizeOperatorArguments(ExportPass):
     """
-    This pass makes sure that the arguments to full.default and clamp.default are quantized correctly.
+    This pass makes sure that the arguments to clamp.default are quantized correctly.
     More specifically, this pass:
-    - Makes sure the fill_value for full.default is quantized. This pass needs to be run before
-      the folding pass above to make sure that the retraced output of the full.default op is
-      the right dtype.
     - Makes sure the min and max values to clamp.default are quantized, if it's a quantized operator.
     """
 

@@ -189,7 +186,6 @@ def call(self, graph_module: GraphModule) -> PassResult:
             n = cast(Node, n)
             if n.target not in {
                 exir_ops.edge.aten.clamp.default,
-                exir_ops.edge.aten.full.default,
             }:
                 continue
 

@@ -200,16 +196,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
 
             qargs = QuantArgs.from_operator(user.target, user.args)
 
-            if n.target == exir_ops.edge.aten.full.default:
-                if "dtype" not in n.kwargs.keys() or n.kwargs["dtype"] != qargs.dtype:
-                    # replace the node arg with a quantized dito and also set dtype
-                    # to get the right output according to the Edge IR specification:
-                    # exir/dialects/edge/edge.yaml:3596
-                    quantized_full_value = qargs.quantize_value(n.args[1]).item()
-                    n.update_arg(1, quantized_full_value)
-                    n.update_kwarg("dtype", qargs.dtype)
-                    modified = True
-            elif n.target == exir_ops.edge.aten.clamp.default:
+            if n.target == exir_ops.edge.aten.clamp.default:
                 # Quantize the min and max arguments of clamp, if they are not None
                 min_val = n.args[1]
                 max_val = None if len(n.args) <= 2 else n.args[2]