Skip to content

Commit e170aa5

Browse files
committed
Update base for Update on "Add a simple sdpa"
Add a simple sdpa so it's decomposed to simpler ops instead of the decompose F.scaled_dot_product_attention, which includes 29 ops including `torch.where` ``` def forward(self, q, k, v): aten_mul_scalar = executorch_exir_dialects_edge__ops_aten_mul_Scalar(q, 0.5946035575013605); q = None aten_full_default = executorch_exir_dialects_edge__ops_aten_full_default([8, 8], True, dtype = torch.bool, layout = torch.strided, device = device(type='cpu'), pin_memory = False) aten_arange_start_step = executorch_exir_dialects_edge__ops_aten_arange_start_step(0, 8, layout = torch.strided, device = device(type='cpu'), pin_memory = False) aten_unsqueeze_copy_default = executorch_exir_dialects_edge__ops_aten_unsqueeze_copy_default(aten_arange_start_step, -2); aten_arange_start_step = None aten_arange_start_step_1 = executorch_exir_dialects_edge__ops_aten_arange_start_step(0, 8, layout = torch.strided, device = device(type='cpu'), pin_memory = False) aten_unsqueeze_copy_default_1 = executorch_exir_dialects_edge__ops_aten_unsqueeze_copy_default(aten_arange_start_step_1, -1); aten_arange_start_step_1 = None aten_sub_tensor = executorch_exir_dialects_edge__ops_aten_sub_Tensor(aten_unsqueeze_copy_default, aten_unsqueeze_copy_default_1); aten_unsqueeze_copy_default = aten_unsqueeze_copy_default_1 = None aten_le_scalar = executorch_exir_dialects_edge__ops_aten_le_Scalar(aten_sub_tensor, 0); aten_sub_tensor = None aten_logical_and_default = executorch_exir_dialects_edge__ops_aten_logical_and_default(aten_le_scalar, aten_full_default); aten_le_scalar = aten_full_default = None aten_full_like_default = executorch_exir_dialects_edge__ops_aten_full_like_default(aten_logical_and_default, 0, dtype = torch.float32, pin_memory = False, memory_format = torch.preserve_format) aten_logical_not_default = executorch_exir_dialects_edge__ops_aten_logical_not_default(aten_logical_and_default); aten_logical_and_default = None aten_scalar_tensor_default = 
executorch_exir_dialects_edge__ops_aten_scalar_tensor_default(-inf, dtype = torch.float32, layout = torch.strided, device = device(type='cpu')) aten_where_self = executorch_exir_dialects_edge__ops_aten_where_self(aten_logical_not_default, aten_scalar_tensor_default, aten_full_like_default); aten_logical_not_default = aten_scalar_tensor_default = aten_full_like_default = None aten_permute_copy_default = executorch_exir_dialects_edge__ops_aten_permute_copy_default(k, [0, 1, 3, 2]); k = None aten_mul_scalar_1 = executorch_exir_dialects_edge__ops_aten_mul_Scalar(aten_permute_copy_default, 0.5946035575013605); aten_permute_copy_default = None aten_expand_copy_default = executorch_exir_dialects_edge__ops_aten_expand_copy_default(aten_mul_scalar, [1, 1, 8, 8]); aten_mul_scalar = None aten_view_copy_default = executorch_exir_dialects_edge__ops_aten_view_copy_default(aten_expand_copy_default, [1, 8, 8]); aten_expand_copy_default = None aten_expand_copy_default_1 = executorch_exir_dialects_edge__ops_aten_expand_copy_default(aten_mul_scalar_1, [1, 1, 8, 8]); aten_mul_scalar_1 = None aten_view_copy_default_1 = executorch_exir_dialects_edge__ops_aten_view_copy_default(aten_expand_copy_default_1, [1, 8, 8]); aten_expand_copy_default_1 = None aten_bmm_default = executorch_exir_dialects_edge__ops_aten_bmm_default(aten_view_copy_default, aten_view_copy_default_1); aten_view_copy_default = aten_view_copy_default_1 = None aten_view_copy_default_2 = executorch_exir_dialects_edge__ops_aten_view_copy_default(aten_bmm_default, [1, 1, 8, 8]); aten_bmm_default = None aten_add_tensor = executorch_exir_dialects_edge__ops_aten_add_Tensor(aten_view_copy_default_2, aten_where_self); aten_view_copy_default_2 = aten_where_self = None aten__softmax_default = executorch_exir_dialects_edge__ops_aten__softmax_default(aten_add_tensor, -1, False); aten_add_tensor = None aten_expand_copy_default_2 = executorch_exir_dialects_edge__ops_aten_expand_copy_default(aten__softmax_default, [1, 1, 8, 8]); 
aten__softmax_default = None aten_view_copy_default_3 = executorch_exir_dialects_edge__ops_aten_view_copy_default(aten_expand_copy_default_2, [1, 8, 8]); aten_expand_copy_default_2 = None aten_expand_copy_default_3 = executorch_exir_dialects_edge__ops_aten_expand_copy_default(v, [1, 1, 8, 8]); v = None aten_view_copy_default_4 = executorch_exir_dialects_edge__ops_aten_view_copy_default(aten_expand_copy_default_3, [1, 8, 8]); aten_expand_copy_default_3 = None aten_bmm_default_1 = executorch_exir_dialects_edge__ops_aten_bmm_default(aten_view_copy_default_3, aten_view_copy_default_4); aten_view_copy_default_3 = aten_view_copy_default_4 = None aten_view_copy_default_5 = executorch_exir_dialects_edge__ops_aten_view_copy_default(aten_bmm_default_1, [1, 1, 8, 8]); aten_bmm_default_1 = None return (aten_view_copy_default_5,) ``` Differential Revision: [D56119737](https://our.internmc.facebook.com/intern/diff/D56119737/) [ghstack-poisoned]
2 parents 0ff6838 + 4c552d4 commit e170aa5

File tree

169 files changed

+6213
-1318
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

169 files changed

+6213
-1318
lines changed

.ci/docker/ci_commit_pins/pytorch.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
868e5ced5df34f1aef3703654f76e03f5126b534
1+
19f50333e91e9e8b20a78517becd74bca70c7d46

.ci/scripts/test_llama.sh

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ source "$(dirname "${BASH_SOURCE[0]}")/utils.sh"
1212
MODEL_NAME=$1 # stories110M.pt
1313
BUILD_TOOL=$2 # buck2 or cmake
1414
DTYPE=$3 # fp16 or fp32
15-
MODE=${4:-"xnnpack"} # portable or xnnpack
15+
MODE=${4:-"xnnpack+custom"} # portable or xnnpack+custom or xnnpack+custom+qe
1616
if [[ $# -lt 4 ]]; then # Assuming 4 mandatory args
1717
echo "Expecting atleast 4 positional arguments"
1818
echo "Usage: [...]"
@@ -37,7 +37,7 @@ if [[ -z "${MODE:-}" ]]; then
3737
exit 1
3838
fi
3939

40-
if [[ "${MODE}" =~ xnnpack.* ]]; then
40+
if [[ "${MODE}" =~ .*xnnpack.* ]]; then
4141
XNNPACK=ON
4242
else
4343
XNNPACK=OFF
@@ -49,6 +49,12 @@ else
4949
CUSTOM=OFF
5050
fi
5151

52+
if [[ "${MODE}" =~ .*qe.* ]]; then
53+
QE=ON
54+
else
55+
QE=OFF
56+
fi
57+
5258
if [[ -z "${BUCK:-}" ]]; then
5359
BUCK=buck2
5460
fi
@@ -69,6 +75,7 @@ cmake_install_executorch_libraries() {
6975
-DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \
7076
-DEXECUTORCH_BUILD_CUSTOM="$CUSTOM" \
7177
-DEXECUTORCH_BUILD_OPTIMIZED=ON \
78+
-DEXECUTORCH_BUILD_QUANTIZED=ON \
7279
-DEXECUTORCH_BUILD_XNNPACK="$XNNPACK" \
7380
-DPYTHON_EXECUTABLE="$PYTHON_EXECUTABLE" \
7481
-Bcmake-out .
@@ -84,7 +91,6 @@ cmake_build_llama_runner() {
8491
-DEXECUTORCH_BUILD_CUSTOM="$CUSTOM" \
8592
-DEXECUTORCH_BUILD_OPTIMIZED=ON \
8693
-DEXECUTORCH_BUILD_XNNPACK="$XNNPACK" \
87-
-DEXECUTORCH_BUILD_OPTIMIZED=ON \
8894
-DPYTHON_EXECUTABLE="$PYTHON_EXECUTABLE" \
8995
-Bcmake-out/${dir} \
9096
${dir}
@@ -126,9 +132,15 @@ fi
126132
# Export model.
127133
EXPORTED_MODEL_NAME="${EXPORTED_MODEL_NAME}.pte"
128134
echo "Exporting ${EXPORTED_MODEL_NAME}"
129-
EXPORT_ARGS="-c stories110M.pt -p ${PARAMS} -d ${DTYPE} -n ${EXPORTED_MODEL_NAME}"
130-
if [[ "${MODE}" == "xnnpack+kv+custom" ]]; then
131-
EXPORT_ARGS="${EXPORT_ARGS} -kv --use_sdpa_with_kv_cache -X -qmode 8da4w -G 128"
135+
EXPORT_ARGS="-c stories110M.pt -p ${PARAMS} -d ${DTYPE} -n ${EXPORTED_MODEL_NAME} -kv"
136+
if [[ "${XNNPACK}" == "ON" ]]; then
137+
EXPORT_ARGS="${EXPORT_ARGS} -X -qmode 8da4w -G 128"
138+
fi
139+
if [[ "${CUSTOM}" == "ON" ]]; then
140+
EXPORT_ARGS="${EXPORT_ARGS} --use_sdpa_with_kv_cache"
141+
fi
142+
if [[ "${QE}" == "ON" ]]; then
143+
EXPORT_ARGS="${EXPORT_ARGS} --embedding-quantize 8,1024"
132144
fi
133145
# Add dynamically linked library location
134146
$PYTHON_EXECUTABLE -m examples.models.llama2.export_llama ${EXPORT_ARGS}

.ci/scripts/test_quantized_aot_lib.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ build_cmake_quantized_aot_lib() {
2424
&& retry cmake -DBUCK2=buck2 \
2525
-DCMAKE_BUILD_TYPE=Release \
2626
-DCMAKE_PREFIX_PATH="$CMAKE_PREFIX_PATH" \
27-
-DEXECUTORCH_BUILD_QUANTIZED=ON \
27+
-DEXECUTORCH_BUILD_QUANTIZED_OPS_AOT=ON \
2828
-DPYTHON_EXECUTABLE="$PYTHON_EXECUTABLE" ..)
2929

3030
cmake --build ${CMAKE_OUTPUT_DIR} -j4

.github/workflows/android.yml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@ on:
1010
- .ci/docker/**
1111
- .github/workflows/android.yml
1212
- install_requirements.sh
13-
- examples/demo-apps/**
13+
- examples/demo-apps/android/**
14+
- extension/android/**
1415
- extension/module/**
1516
workflow_dispatch:
1617

@@ -101,7 +102,7 @@ jobs:
101102
android-app-archive: https://gha-artifacts.s3.amazonaws.com/${{ github.repository }}/${{ github.run_id }}/artifact/app-debug.apk
102103
android-test-archive: https://gha-artifacts.s3.amazonaws.com/${{ github.repository }}/${{ github.run_id }}/artifact/app-debug-androidTest.apk
103104
# The test spec can be downloaded from https://ossci-assets.s3.amazonaws.com/android-llama2-device-farm-test-spec.yml
104-
test-spec: arn:aws:devicefarm:us-west-2:308535385114:upload:02a2cf0f-6d9b-45ee-ba1a-a086587469e6/414cb54d-4d83-4576-8317-93244e4dc50e
105+
test-spec: arn:aws:devicefarm:us-west-2:308535385114:upload:02a2cf0f-6d9b-45ee-ba1a-a086587469e6/abd86868-fa63-467e-a5c7-218194665a77
105106
# The exported llama2 model and its tokenizer, can be downloaded from https://ossci-assets.s3.amazonaws.com/executorch-android-llama2-7b.zip.
106107
# Among the input, this is the biggest file and uploading it to AWS beforehand makes the test run much faster
107108
extra-data: arn:aws:devicefarm:us-west-2:308535385114:upload:02a2cf0f-6d9b-45ee-ba1a-a086587469e6/bd15825b-ddab-4e47-9fef-a9c8935778dd

.github/workflows/doc-build.yml

Lines changed: 18 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -68,23 +68,22 @@ jobs:
6868
make html
6969
cd ..
7070
71+
# If it's main branch, add noindex tag to all .html files to exclude from Google Search indexing.
72+
GITHUB_REF=${{ github.ref }}
73+
echo "GitHub Ref: ${GITHUB_REF}"
74+
if [[ "${{ github.ref }}" == 'refs/heads/main' ]]; then
75+
find docs/_build/html/ -name "*.html" -print0 | xargs -0 sed -i '/<head>/a \ \ <meta name="robots" content="noindex">';
76+
fi
77+
7178
cp -rf docs/_build/html/* "${RUNNER_DOCS_DIR}"
7279
7380
mv docs/_build/html "${RUNNER_ARTIFACT_DIR}"
7481
7582
ls -R "${RUNNER_ARTIFACT_DIR}"/*/*.html
7683
77-
# Enable preview later. Previews are available publicly
78-
#
79-
# upload-preview:
80-
# if: github.repository == 'pytorch/executorch' && github.event_name == 'push' &&
81-
# (github.ref_type == 'branch' && github.ref_name == 'main')
82-
# uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
83-
8484
upload-gh-pages:
8585
needs: build
86-
if: github.repository == 'pytorch/executorch' && github.event_name == 'push' &&
87-
((github.ref_type == 'branch' && github.ref_name == 'main') || github.ref_type == 'tag')
86+
if: github.repository == 'pytorch/executorch' && github.event_name == 'push' && (github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/heads/release/') || startsWith(github.ref, 'refs/tags/v'))
8887
permissions:
8988
contents: write
9089
uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
@@ -108,10 +107,16 @@ jobs:
108107
elif [[ "${REF_TYPE}" == tag ]]; then
109108
# Strip the leading "v" as well as the trailing patch version and "-rc" suffix.
110109
# For example: 'v0.1.2' -> '0.1' and 'v0.1.2-rc1' -> 0.1.
111-
TARGET_FOLDER=$(echo "${REF_NAME}" | sed 's/^v//i; s/-rc[0-9]*$//; s/\.[0-9]*$//')
112-
else
113-
echo "ERROR: Invalid REF_TYPE: ${REF_TYPE}. Expected 'branch' or 'tag'."
114-
exit 1
110+
case "${REF_NAME}" in
111+
*-rc*)
112+
echo "Aborting upload since this is an RC tag: ${REF_NAME}"
113+
# We don't generate -rc* documentation but for actual tag only.
114+
exit 0
115+
;;
116+
*)
117+
TARGET_FOLDER=$(echo "${REF_NAME}" | sed 's/v\([0-9]\+\)\.\([0-9]\+\)\.[0-9]\+/\1.\2/')
118+
;;
119+
esac
115120
fi
116121
echo "Target Folder: ${TARGET_FOLDER}"
117122
@@ -122,12 +127,6 @@ jobs:
122127
mv "${RUNNER_ARTIFACT_DIR}"/html/* "${TARGET_FOLDER}"
123128
git add "${TARGET_FOLDER}" || true
124129
125-
# If it's main branch, add noindex tag to all .html files to exclude from Google Search indexing.
126-
if [[ "${REF_NAME}" == 'main' ]]; then
127-
find "${TARGET_FOLDER}" -type f -name "*.html" -exec sed -i '/<head>/a <meta name="robots" content="noindex">' {} \;
128-
git add "${TARGET_FOLDER}"/**/*.html || true
129-
fi
130-
131130
git config user.name 'pytorchbot'
132131
git config user.email '[email protected]'
133132
git commit -m "Auto-generating sphinx docs" || true

.github/workflows/pull.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ jobs:
9090
matrix:
9191
dtype: [fp32]
9292
build-tool: [buck2, cmake]
93-
mode: [portable, xnnpack+kv+custom]
93+
mode: [portable, xnnpack+custom, xnnpack+custom+qe]
9494
fail-fast: false
9595
with:
9696
runner: linux.2xlarge

.gitmodules

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,3 +62,9 @@
6262
[submodule "examples/third-party/LLaVA"]
6363
path = examples/third-party/LLaVA
6464
url = https://github.com/haotian-liu/LLaVA.git
65+
[submodule "examples/models/llama2/third-party/re2"]
66+
path = examples/models/llama2/third-party/re2
67+
url = https://github.com/google/re2.git
68+
[submodule "examples/models/llama2/third-party/abseil-cpp"]
69+
path = examples/models/llama2/third-party/abseil-cpp
70+
url = https://github.com/abseil/abseil-cpp.git

backends/apple/coreml/CMakeLists.txt

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ if(NOT EXECUTORCH_ROOT)
1313
set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../..)
1414
endif()
1515

16+
option(COREML_BUILD_EXECUTOR_RUNNER "Build CoreML executor runner." OFF)
17+
1618
# inmemoryfs sources
1719
set(INMEMORYFS_SOURCES
1820
runtime/inmemoryfs/inmemory_filesystem.cpp
@@ -181,6 +183,14 @@ target_link_libraries(coremldelegate
181183
${SQLITE_LIBRARY}
182184
)
183185

186+
if(COREML_BUILD_EXECUTOR_RUNNER)
187+
target_link_libraries(coremldelegate
188+
PRIVATE
189+
portable_ops_lib
190+
portable_kernels
191+
)
192+
endif()
193+
184194
target_compile_options(coremldelegate PRIVATE "-fobjc-arc")
185195
target_compile_options(coremldelegate PRIVATE "-fno-exceptions")
186196

backends/apple/coreml/README.md

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,10 @@ For delegating the Program to the **Core ML** backend, the client must be respon
2828
import executorch.exir as exir
2929
import torch
3030

31+
from torch.export import export
32+
33+
from executorch.exir import to_edge
34+
3135
from executorch.exir.backend.backend_api import to_backend
3236

3337
from executorch.backends.apple.coreml.compiler import CoreMLBackend
@@ -42,7 +46,7 @@ class LowerableSubModel(torch.nn.Module):
4246
# Convert the lowerable module to Edge IR Representation
4347
to_be_lowered = LowerableSubModel()
4448
example_input = (torch.ones(1), )
45-
to_be_lowered_exir_submodule = exir.capture(to_be_lowered, example_input).to_edge()
49+
to_be_lowered_exir_submodule = to_edge(export(to_be_lowered, example_input))
4650

4751
# Lower to Core ML backend
4852
lowered_module = to_backend('CoreMLBackend', to_be_lowered_exir_submodule.exported_program, [])

backends/apple/coreml/runtime/delegate/ETCoreMLAssetManager.mm

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -630,7 +630,7 @@ - (NSUInteger)_compact:(NSUInteger)sizeInBytes error:(NSError * __autoreleasing
630630
}
631631

632632
if (_estimatedSizeInBytes <= sizeInBytes) {
633-
return YES;
633+
return _estimatedSizeInBytes;
634634
}
635635

636636
std::error_code ec;

backends/apple/coreml/runtime/delegate/ETCoreMLDefaultModelExecutor.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@ __attribute__((objc_subclassing_restricted))
2828
/// The model.
2929
@property (readonly, strong, nonatomic) ETCoreMLModel* model;
3030

31+
/// If set to `YES` then output backing are ignored.
32+
@property (readwrite, atomic) BOOL ignoreOutputBackings;
33+
3134
@end
3235

3336
NS_ASSUME_NONNULL_END

backends/apple/coreml/runtime/delegate/ETCoreMLDefaultModelExecutor.mm

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@ - (instancetype)initWithModel:(ETCoreMLModel *)model {
2626
loggingOptions:(const executorchcoreml::ModelLoggingOptions& __unused)loggingOptions
2727
eventLogger:(const executorchcoreml::ModelEventLogger* _Nullable __unused)eventLogger
2828
error:(NSError * __autoreleasing *)error {
29+
if (self.ignoreOutputBackings) {
30+
predictionOptions.outputBackings = @{};
31+
}
2932
id<MLFeatureProvider> outputs = [self.model.mlModel predictionFromFeatures:inputs
3033
options:predictionOptions
3134
error:error];

backends/apple/coreml/runtime/delegate/ETCoreMLLogging.h

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
#import <Foundation/Foundation.h>
99

10+
#import <executorch/runtime/platform/log.h>
1011
#import <os/log.h>
1112

1213
NS_ASSUME_NONNULL_BEGIN
@@ -48,7 +49,11 @@ typedef NS_ERROR_ENUM(ETCoreMLErrorDomain, ETCoreMLError) {
4849

4950
/// Record the error with `os_log_error` and fills `*errorOut` with `NSError`.
5051
#define ETCoreMLLogErrorAndSetNSError(errorOut, errorCode, formatString, ...) \
51-
os_log_error(ETCoreMLErrorUtils.loggingChannel, formatString, ##__VA_ARGS__); \
52+
if (ET_LOG_ENABLED) { \
53+
ET_LOG(Error, "%s", [NSString stringWithFormat:@formatString, ##__VA_ARGS__].UTF8String); \
54+
} else { \
55+
os_log_error(ETCoreMLErrorUtils.loggingChannel, formatString, ##__VA_ARGS__); \
56+
} \
5257
if (errorOut) { \
5358
*errorOut = \
5459
[NSError errorWithDomain:ETCoreMLErrorDomain \
@@ -58,24 +63,31 @@ typedef NS_ERROR_ENUM(ETCoreMLErrorDomain, ETCoreMLError) {
5863
}]; \
5964
}
6065

61-
/// Record the error and its underlying error with `os_log_error` and fills
62-
/// `*errorOut` with NSError.
66+
/// Record the error and its underlying error with `os_log_error` and fills `*errorOut` with `NSError`.
6367
#define ETCoreMLLogUnderlyingErrorAndSetNSError(errorOut, errorCode, underlyingNSError, formatString, ...) \
64-
os_log_error(ETCoreMLErrorUtils.loggingChannel, \
65-
formatString ", with underlying error= %@.", \
66-
##__VA_ARGS__, \
67-
(underlyingNSError).localizedDescription); \
68+
if (ET_LOG_ENABLED) { \
69+
ET_LOG(Error, "%s", [NSString stringWithFormat:@formatString, ##__VA_ARGS__].UTF8String); \
70+
} else { \
71+
os_log_error(ETCoreMLErrorUtils.loggingChannel, \
72+
formatString ", with underlying error= %@.", \
73+
##__VA_ARGS__, \
74+
(underlyingNSError).localizedDescription); \
75+
} \
6876
if (errorOut) { \
6977
*errorOut = [ETCoreMLErrorUtils errorWithCode:errorCode \
7078
underlyingError:underlyingNSError \
7179
format:@formatString, ##__VA_ARGS__]; \
7280
}
7381

74-
#define ETCoreMLLogError(error, formatString, ...) \
75-
os_log_error(ETCoreMLErrorUtils.loggingChannel, \
76-
formatString ", with error= %@.", \
77-
##__VA_ARGS__, \
78-
(error).localizedDescription);
82+
#define ETCoreMLLogError(error, formatString, ...) \
83+
if (ET_LOG_ENABLED) { \
84+
ET_LOG(Error, "%s", [NSString stringWithFormat:@formatString, ##__VA_ARGS__].UTF8String); \
85+
} else { \
86+
os_log_error(ETCoreMLErrorUtils.loggingChannel, \
87+
formatString ", with error= %@.", \
88+
##__VA_ARGS__, \
89+
(error).localizedDescription); \
90+
}
7991

8092

8193
#pragma clang diagnostic pop

backends/apple/coreml/runtime/delegate/ETCoreMLModel.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,18 @@
66
// Please refer to the license found in the LICENSE file in the root directory of the source tree.
77

88
#import <CoreML/CoreML.h>
9+
#import <vector>
910

1011
NS_ASSUME_NONNULL_BEGIN
1112

1213
@class ETCoreMLAsset;
1314

15+
namespace executorchcoreml {
16+
class MultiArray;
17+
}
18+
1419
/// Represents a ML model, the class is a thin wrapper over `MLModel` with additional properties.
20+
__attribute__((objc_subclassing_restricted))
1521
@interface ETCoreMLModel : NSObject
1622

1723
- (instancetype)init NS_UNAVAILABLE;
@@ -31,6 +37,12 @@ NS_ASSUME_NONNULL_BEGIN
3137
orderedOutputNames:(NSOrderedSet<NSString*>*)orderedOutputNames
3238
error:(NSError* __autoreleasing*)error NS_DESIGNATED_INITIALIZER;
3339

40+
- (nullable NSArray<MLMultiArray*>*)prepareInputs:(const std::vector<executorchcoreml::MultiArray>&)inputs
41+
error:(NSError* __autoreleasing*)error;
42+
43+
- (nullable NSArray<MLMultiArray*>*)prepareOutputBackings:(const std::vector<executorchcoreml::MultiArray>&)outputs
44+
error:(NSError* __autoreleasing*)error;
45+
3446
/// The underlying MLModel.
3547
@property (strong, readonly, nonatomic) MLModel* mlModel;
3648

0 commit comments

Comments
 (0)