Skip to content

Commit 3eda42a

Browse files
committed
Add OS checks
1 parent e6f3d7b commit 3eda42a

File tree

6 files changed

+17
-11
lines changed

6 files changed

+17
-11
lines changed

backends/apple/mps/runtime/MPSDevice.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ enum class MacOSVersion : uint32_t {
3232
MACOS_VER_13_2_PLUS,
3333
MACOS_VER_13_3_PLUS,
3434
MACOS_VER_14_0_PLUS,
35+
MACOS_VER_15_0_PLUS,
3536
};
3637

3738
enum class LibraryType : uint32_t {
@@ -82,7 +83,8 @@ class MPSDevice {
8283
MPSDevice();
8384
};
8485

85-
bool isMacOS13OrNewer(MacOSVersion version = MacOSVersion::MACOS_VER_13_0_PLUS);
86+
bool is_macos_13_or_newer(
87+
MacOSVersion version = MacOSVersion::MACOS_VER_13_0_PLUS);
8688

8789
} // namespace delegate
8890
} // namespace mps

backends/apple/mps/runtime/MPSDevice.mm

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ static inline MTLLanguageVersion getMetalLanguageVersion(const id<MTLDevice>& de
7676
static bool _macos_13_3_plus = [compileOptions respondsToSelector:@selector(maxTotalThreadsPerThreadgroup)] == YES;
7777

7878
static bool _macos_14_0_plus = [mpsCD instancesRespondToSelector:@selector(conjugateWithTensor:name:)] == YES;
79-
79+
static bool _macos_15_0_plus = [mpsCD instancesRespondToSelector:@selector(scaledDotProductAttentionWithQueryTensor:keyTensor:valueTensor:maskTensor:scale:name:)] == YES;
8080
switch (version) {
8181
case MacOSVersion::MACOS_VER_13_0_PLUS:
8282
return _macos_13_0_plus;
@@ -88,6 +88,8 @@ static inline MTLLanguageVersion getMetalLanguageVersion(const id<MTLDevice>& de
8888
return _macos_13_3_plus;
8989
case MacOSVersion::MACOS_VER_14_0_PLUS:
9090
return _macos_14_0_plus;
91+
case MacOSVersion::MACOS_VER_15_0_PLUS:
92+
return _macos_15_0_plus;
9193
default:
9294
return false;
9395
}
@@ -144,7 +146,7 @@ static inline MTLLanguageVersion getMetalLanguageVersion(const id<MTLDevice>& de
144146
return err;
145147
}
146148

147-
bool isMacOS13OrNewer(MacOSVersion version) {
149+
bool is_macos_13_or_newer(MacOSVersion version) {
148150
return MPSDevice::getInstance()->isMacOS13Plus(version);
149151
}
150152

backends/apple/mps/runtime/MPSExecutor.mm

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ @interface MPSGraphExecutable()
2929
#if TARGET_OS_SIMULATOR or defined(__x86_64__)
3030
_use_shared_mem = false;
3131
#endif
32-
if (!isMacOS13OrNewer(MacOSVersion::MACOS_VER_14_0_PLUS)) {
32+
if (!is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS)) {
3333
_use_shared_mem = false;
3434
}
3535

backends/apple/mps/runtime/operations/BinaryOps.mm

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@
119119
graphNode->output_id() \
120120
); \
121121
ET_CHECK_OR_RETURN_ERROR( \
122-
isMacOS13OrNewer(), NotSupported, \
122+
is_macos_13_or_newer(), NotSupported, \
123123
"%s supported by MPS on MacOS13.0+/iOS16.1+", #aot_name); \
124124
\
125125
_idToMPSGraphTensor[graphNode->output_id()] = binaryOpTensor( \
@@ -176,7 +176,7 @@
176176
return inputTensor;
177177
}
178178

179-
if (!isMacOS13OrNewer(MacOSVersion::MACOS_VER_13_0_PLUS)) {
179+
if (!is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_0_PLUS)) {
180180
MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0 dataType:inputTensor.dataType];
181181
MPSGraphTensor* predicateTensor = [mpsGraph lessThanWithPrimaryTensor:inputTensor
182182
secondaryTensor:zeroTensor

backends/apple/mps/runtime/operations/QuantDequant.mm

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,19 +22,21 @@
2222
graphNode->output_id()
2323
);
2424

25+
ET_CHECK_OR_RETURN_ERROR(
26+
is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS),
27+
NotImplemented,
28+
"[ERROR] Operation %s is supported starting with macOS 15.0+ | iOS 18.0 + | iPadOS 18+ | tvOS 18+ | visionOS 2.0+ !",
29+
mpsgraph::EnumNameMPSNodeUnion(nodePtr->mpsnode_union_type()));
30+
2531
MPSGraphTensor* inputTensor = getMPSGraphTensor(graphNode->input1_id());
2632
MPSGraphTensor* scalesTensor = getMPSGraphTensor(graphNode->scales_id());
2733

2834
MPSGraphTensor *zpTensor = [_mpsGraph constantWithScalar:0
2935
dataType:MPSDataTypeInt4];
3036

31-
MPSGraphTensor *minTensor = [_mpsGraph constantWithScalar:0
32-
dataType:MPSDataTypeFloat16];
33-
3437
MPSGraphTensor *wDqTensor = [_mpsGraph dequantizeTensor:inputTensor
3538
scaleTensor:scalesTensor
3639
zeroPointTensor:zpTensor
37-
minTensor:minTensor
3840
dataType:MPSDataTypeFloat16
3941
name:nil];
4042

backends/apple/mps/runtime/operations/UnaryOps.mm

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
_idToMPSGraphTensor[graphNode->output_id()] = [_mpsGraph notWithTensor:inputTensor name:nil];
3434
} else {
3535
ET_CHECK_OR_RETURN_ERROR(
36-
isMacOS13OrNewer(), NotSupported,
36+
is_macos_13_or_newer(), NotSupported,
3737
"mpsBitwiseNotOp supported by MPS on MacOS13.0+/iOS16.1+");
3838
_idToMPSGraphTensor[graphNode->output_id()] = [_mpsGraph bitwiseNOTWithTensor:inputTensor name:nil];
3939
}

0 commit comments

Comments
 (0)