Skip to content

Commit 5c8b115

Browse files
shoumikhinfacebook-github-bot
authored andcommitted
Add CoreML tests. (#6203)
Summary: Pull Request resolved: #6203 . Reviewed By: metascroy Differential Revision: D64359459 fbshipit-source-id: acfa3990b1b90fd300ead0f47e71ebe82d70e7f9
1 parent 3a7056e commit 5c8b115

File tree

3 files changed

+111
-2
lines changed

3 files changed

+111
-2
lines changed

extension/benchmark/apple/Benchmark/Benchmark.xcodeproj/project.pbxproj

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
03DD00B22C8FE44600FE4619 /* backend_mps.xcframework in Frameworks */ = {isa = PBXBuildFile; fileRef = 03DD00A22C8FE44600FE4619 /* backend_mps.xcframework */; };
2929
03DD00B32C8FE44600FE4619 /* executorch.xcframework in Frameworks */ = {isa = PBXBuildFile; fileRef = 03DD00A32C8FE44600FE4619 /* executorch.xcframework */; settings = {ATTRIBUTES = (Required, ); }; };
3030
03DD00B52C8FE44600FE4619 /* kernels_quantized.xcframework in Frameworks */ = {isa = PBXBuildFile; fileRef = 03DD00A52C8FE44600FE4619 /* kernels_quantized.xcframework */; };
31+
03E7E6792CBDCAE900205E71 /* CoreMLTests.mm in Sources */ = {isa = PBXBuildFile; fileRef = 03E7E6782CBDC1C900205E71 /* CoreMLTests.mm */; };
3132
03ED6D0F2C8AAFE900F2D6EE /* libsqlite3.0.tbd in Frameworks */ = {isa = PBXBuildFile; fileRef = 03ED6D0E2C8AAFE900F2D6EE /* libsqlite3.0.tbd */; };
3233
03ED6D112C8AAFF200F2D6EE /* MetalPerformanceShadersGraph.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 03ED6D102C8AAFF200F2D6EE /* MetalPerformanceShadersGraph.framework */; };
3334
03ED6D132C8AAFF700F2D6EE /* MetalPerformanceShaders.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 03ED6D122C8AAFF700F2D6EE /* MetalPerformanceShaders.framework */; };
@@ -90,6 +91,7 @@
9091
03DD00A22C8FE44600FE4619 /* backend_mps.xcframework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.xcframework; name = backend_mps.xcframework; path = Frameworks/backend_mps.xcframework; sourceTree = "<group>"; };
9192
03DD00A32C8FE44600FE4619 /* executorch.xcframework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.xcframework; name = executorch.xcframework; path = Frameworks/executorch.xcframework; sourceTree = "<group>"; };
9293
03DD00A52C8FE44600FE4619 /* kernels_quantized.xcframework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.xcframework; name = kernels_quantized.xcframework; path = Frameworks/kernels_quantized.xcframework; sourceTree = "<group>"; };
94+
03E7E6782CBDC1C900205E71 /* CoreMLTests.mm */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.objcpp; path = CoreMLTests.mm; sourceTree = "<group>"; };
9395
03ED6D0E2C8AAFE900F2D6EE /* libsqlite3.0.tbd */ = {isa = PBXFileReference; lastKnownFileType = "sourcecode.text-based-dylib-definition"; name = libsqlite3.0.tbd; path = Platforms/iPhoneOS.platform/Developer/SDKs/iPhoneOS17.5.sdk/usr/lib/libsqlite3.0.tbd; sourceTree = DEVELOPER_DIR; };
9496
03ED6D102C8AAFF200F2D6EE /* MetalPerformanceShadersGraph.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = MetalPerformanceShadersGraph.framework; path = Platforms/iPhoneOS.platform/Developer/SDKs/iPhoneOS17.5.sdk/System/Library/Frameworks/MetalPerformanceShadersGraph.framework; sourceTree = DEVELOPER_DIR; };
9597
03ED6D122C8AAFF700F2D6EE /* MetalPerformanceShaders.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = MetalPerformanceShaders.framework; path = Platforms/iPhoneOS.platform/Developer/SDKs/iPhoneOS17.5.sdk/System/Library/Frameworks/MetalPerformanceShaders.framework; sourceTree = DEVELOPER_DIR; };
@@ -232,6 +234,7 @@
232234
isa = PBXGroup;
233235
children = (
234236
032A73C92CAFBA8600932D36 /* LLaMA */,
237+
03E7E6782CBDC1C900205E71 /* CoreMLTests.mm */,
235238
03B2D3792C8A515C0046936E /* GenericTests.mm */,
236239
03B019502C8A80D30044D558 /* Tests.xcconfig */,
237240
037C96A02C8A570B00B3DF38 /* Tests.xctestplan */,
@@ -388,6 +391,7 @@
388391
032A741E2CAFBB7800932D36 /* tiktoken.cpp in Sources */,
389392
032A741F2CAFBB7800932D36 /* sampler.cpp in Sources */,
390393
03B011912CAD114E00054791 /* ResourceTestCase.m in Sources */,
394+
03E7E6792CBDCAE900205E71 /* CoreMLTests.mm in Sources */,
391395
032A74232CAFC1B300932D36 /* runner.cpp in Sources */,
392396
03B2D37A2C8A515C0046936E /* GenericTests.mm in Sources */,
393397
032A73CA2CAFBA8600932D36 /* LLaMATests.mm in Sources */,

extension/benchmark/apple/Benchmark/Benchmark.xcodeproj/xcshareddata/xcschemes/Benchmark.xcscheme

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@
2525
</BuildAction>
2626
<TestAction
2727
buildConfiguration = "Release"
28-
selectedDebuggerIdentifier = ""
29-
selectedLauncherIdentifier = "Xcode.IDEFoundation.Launcher.PosixSpawn"
28+
selectedDebuggerIdentifier = "Xcode.DebuggerFoundation.Debugger.LLDB"
29+
selectedLauncherIdentifier = "Xcode.DebuggerFoundation.Launcher.LLDB"
3030
shouldUseLaunchSchemeArgsEnv = "YES">
3131
<TestPlans>
3232
<TestPlanReference
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#import "ResourceTestCase.h"
10+
11+
#import <CoreML/CoreML.h>
12+
13+
static MLMultiArray *DummyMultiArrayForFeature(MLFeatureDescription *feature, NSError **error) {
14+
MLMultiArray *array = [[MLMultiArray alloc] initWithShape:feature.multiArrayConstraint.shape
15+
dataType:feature.multiArrayConstraint.dataType == MLMultiArrayDataTypeInt32 ? MLMultiArrayDataTypeInt32 : MLMultiArrayDataTypeDouble
16+
error:error];
17+
for (auto index = 0; index < array.count; ++index) {
18+
array[index] = feature.multiArrayConstraint.dataType == MLMultiArrayDataTypeInt32 ? @1 : @1.0;
19+
}
20+
return array;
21+
}
22+
23+
static NSMutableDictionary *DummyInputsForModel(MLModel *model, NSError **error) {
24+
NSMutableDictionary *inputs = [NSMutableDictionary dictionary];
25+
NSDictionary<NSString *, MLFeatureDescription *> *inputDescriptions = model.modelDescription.inputDescriptionsByName;
26+
27+
for (NSString *inputName in inputDescriptions) {
28+
MLFeatureDescription *feature = inputDescriptions[inputName];
29+
30+
switch (feature.type) {
31+
case MLFeatureTypeMultiArray: {
32+
MLMultiArray *array = DummyMultiArrayForFeature(feature, error);
33+
inputs[inputName] = [MLFeatureValue featureValueWithMultiArray:array];
34+
break;
35+
}
36+
case MLFeatureTypeInt64:
37+
inputs[inputName] = [MLFeatureValue featureValueWithInt64:1];
38+
break;
39+
case MLFeatureTypeDouble:
40+
inputs[inputName] = [MLFeatureValue featureValueWithDouble:1.0];
41+
break;
42+
case MLFeatureTypeString:
43+
inputs[inputName] = [MLFeatureValue featureValueWithString:@"1"];
44+
break;
45+
default:
46+
break;
47+
}
48+
}
49+
return inputs;
50+
}
51+
52+
@interface CoreMLTests : ResourceTestCase
53+
@end
54+
55+
@implementation CoreMLTests
56+
57+
+ (NSArray<NSString *> *)directories {
58+
return @[@"Resources"];
59+
}
60+
61+
+ (NSDictionary<NSString *, BOOL (^)(NSString *)> *)predicates {
62+
return @{ @"model" : ^BOOL(NSString *filename) {
63+
return [filename hasSuffix:@".mlpackage"];
64+
}};
65+
}
66+
67+
+ (NSDictionary<NSString *, void (^)(XCTestCase *)> *)dynamicTestsForResources:(NSDictionary<NSString *, NSString *> *)resources {
68+
NSString *modelPath = resources[@"model"];
69+
70+
return @{
71+
@"prediction" : ^(XCTestCase *testCase) {
72+
NSError *error = nil;
73+
NSURL *compiledModelURL = [MLModel compileModelAtURL:[NSURL fileURLWithPath:modelPath] error:&error];
74+
if (error || !compiledModelURL) {
75+
XCTFail(@"Failed to compile model: %@", error.localizedDescription);
76+
return;
77+
}
78+
MLModel *model = [MLModel modelWithContentsOfURL:compiledModelURL error:&error];
79+
if (error || !model) {
80+
XCTFail(@"Failed to load model: %@", error.localizedDescription);
81+
return;
82+
}
83+
NSMutableDictionary *inputs = DummyInputsForModel(model, &error);
84+
if (error || !inputs) {
85+
XCTFail(@"Failed to prepare inputs: %@", error.localizedDescription);
86+
return;
87+
}
88+
MLDictionaryFeatureProvider *featureProvider = [[MLDictionaryFeatureProvider alloc] initWithDictionary:inputs error:&error];
89+
if (error || !featureProvider) {
90+
XCTFail(@"Failed to create input provider: %@", error.localizedDescription);
91+
return;
92+
}
93+
[testCase measureWithMetrics:@[[XCTClockMetric new], [XCTMemoryMetric new]]
94+
block:^{
95+
NSError *error = nil;
96+
id<MLFeatureProvider> prediction = [model predictionFromFeatures:featureProvider error:&error];
97+
if (error || !prediction) {
98+
XCTFail(@"Prediction failed: %@", error.localizedDescription);
99+
}
100+
}];
101+
}
102+
};
103+
}
104+
105+
@end

0 commit comments

Comments
 (0)