File tree Expand file tree Collapse file tree 6 files changed +17
-11
lines changed
backends/apple/mps/runtime Expand file tree Collapse file tree 6 files changed +17
-11
lines changed Original file line number Diff line number Diff line change @@ -32,6 +32,7 @@ enum class MacOSVersion : uint32_t {
32
32
MACOS_VER_13_2_PLUS,
33
33
MACOS_VER_13_3_PLUS,
34
34
MACOS_VER_14_0_PLUS,
35
+ MACOS_VER_15_0_PLUS,
35
36
};
36
37
37
38
enum class LibraryType : uint32_t {
@@ -82,7 +83,8 @@ class MPSDevice {
82
83
MPSDevice ();
83
84
};
84
85
85
- bool isMacOS13OrNewer (MacOSVersion version = MacOSVersion::MACOS_VER_13_0_PLUS);
86
+ bool is_macos_13_or_newer (
87
+ MacOSVersion version = MacOSVersion::MACOS_VER_13_0_PLUS);
86
88
87
89
} // namespace delegate
88
90
} // namespace mps
Original file line number Diff line number Diff line change @@ -76,7 +76,7 @@ static inline MTLLanguageVersion getMetalLanguageVersion(const id<MTLDevice>& de
76
76
static bool _macos_13_3_plus = [compileOptions respondsToSelector: @selector (maxTotalThreadsPerThreadgroup )] == YES ;
77
77
78
78
static bool _macos_14_0_plus = [mpsCD instancesRespondToSelector: @selector (conjugateWithTensor:name: )] == YES ;
79
-
79
+ static bool _macos_15_0_plus = [mpsCD instancesRespondToSelector: @selector ( scaledDotProductAttentionWithQueryTensor:keyTensor:valueTensor:maskTensor:scale:name: )] == YES ;
80
80
switch (version) {
81
81
case MacOSVersion::MACOS_VER_13_0_PLUS:
82
82
return _macos_13_0_plus;
@@ -88,6 +88,8 @@ static inline MTLLanguageVersion getMetalLanguageVersion(const id<MTLDevice>& de
88
88
return _macos_13_3_plus;
89
89
case MacOSVersion::MACOS_VER_14_0_PLUS:
90
90
return _macos_14_0_plus;
91
+ case MacOSVersion::MACOS_VER_15_0_PLUS:
92
+ return _macos_15_0_plus;
91
93
default :
92
94
return false ;
93
95
}
@@ -144,7 +146,7 @@ static inline MTLLanguageVersion getMetalLanguageVersion(const id<MTLDevice>& de
144
146
return err;
145
147
}
146
148
147
- bool isMacOS13OrNewer (MacOSVersion version) {
149
+ bool is_macos_13_or_newer (MacOSVersion version) {
148
150
return MPSDevice::getInstance ()->isMacOS13Plus (version);
149
151
}
150
152
Original file line number Diff line number Diff line change @@ -29,7 +29,7 @@ @interface MPSGraphExecutable()
29
29
#if TARGET_OS_SIMULATOR or defined(__x86_64__)
30
30
_use_shared_mem = false ;
31
31
#endif
32
- if (!isMacOS13OrNewer (MacOSVersion::MACOS_VER_14_0_PLUS)) {
32
+ if (!is_macos_13_or_newer (MacOSVersion::MACOS_VER_14_0_PLUS)) {
33
33
_use_shared_mem = false ;
34
34
}
35
35
Original file line number Diff line number Diff line change 119
119
graphNode->output_id () \
120
120
); \
121
121
ET_CHECK_OR_RETURN_ERROR ( \
122
- isMacOS13OrNewer (), NotSupported, \
122
+ is_macos_13_or_newer (), NotSupported, \
123
123
" %s supported by MPS on MacOS13.0+/iOS16.1+" , #aot_name); \
124
124
\
125
125
_idToMPSGraphTensor[graphNode->output_id ()] = binaryOpTensor ( \
176
176
return inputTensor;
177
177
}
178
178
179
- if (!isMacOS13OrNewer (MacOSVersion::MACOS_VER_13_0_PLUS)) {
179
+ if (!is_macos_13_or_newer (MacOSVersion::MACOS_VER_13_0_PLUS)) {
180
180
MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar: 0.0 dataType: inputTensor.dataType];
181
181
MPSGraphTensor* predicateTensor = [mpsGraph lessThanWithPrimaryTensor: inputTensor
182
182
secondaryTensor: zeroTensor
Original file line number Diff line number Diff line change 22
22
graphNode->output_id ()
23
23
);
24
24
25
+ ET_CHECK_OR_RETURN_ERROR (
26
+ is_macos_13_or_newer (MacOSVersion::MACOS_VER_15_0_PLUS),
27
+ NotImplemented,
28
+ " [ERROR] Operation %s is supported starting with macOS 15.0+ | iOS 18.0 + | iPadOS 18+ | tvOS 18+ | visionOS 2.0+ !" ,
29
+ mpsgraph::EnumNameMPSNodeUnion (nodePtr->mpsnode_union_type ()));
30
+
25
31
MPSGraphTensor* inputTensor = getMPSGraphTensor (graphNode->input1_id ());
26
32
MPSGraphTensor* scalesTensor = getMPSGraphTensor (graphNode->scales_id ());
27
33
28
34
MPSGraphTensor *zpTensor = [_mpsGraph constantWithScalar: 0
29
35
dataType: MPSDataTypeInt4];
30
36
31
- MPSGraphTensor *minTensor = [_mpsGraph constantWithScalar: 0
32
- dataType: MPSDataTypeFloat16];
33
-
34
37
MPSGraphTensor *wDqTensor = [_mpsGraph dequantizeTensor: inputTensor
35
38
scaleTensor: scalesTensor
36
39
zeroPointTensor: zpTensor
37
- minTensor: minTensor
38
40
dataType: MPSDataTypeFloat16
39
41
name: nil ];
40
42
Original file line number Diff line number Diff line change 33
33
_idToMPSGraphTensor[graphNode->output_id ()] = [_mpsGraph notWithTensor: inputTensor name: nil ];
34
34
} else {
35
35
ET_CHECK_OR_RETURN_ERROR (
36
- isMacOS13OrNewer (), NotSupported,
36
+ is_macos_13_or_newer (), NotSupported,
37
37
" mpsBitwiseNotOp supported by MPS on MacOS13.0+/iOS16.1+" );
38
38
_idToMPSGraphTensor[graphNode->output_id ()] = [_mpsGraph bitwiseNOTWithTensor: inputTensor name: nil ];
39
39
}
You can’t perform that action at this time.
0 commit comments