Skip to content

Commit 52ea3c6

Browse files
DenisVieriu97facebook-github-bot
authored andcommitted
Build MPS delegate with Werror=1 (#1736)
Summary: Summary of changes: - Build MPS Delegate with Werror=1 by default (tested on macOS12.0, 13.0 and 14.0) - Cast the placeholders to FP16 only if initial placeholder is FP32. Other data types remain unchanged. cc shoumikhin, cccclai, larryliu0820 Pull Request resolved: #1736 Reviewed By: cccclai Differential Revision: D53150094 Pulled By: shoumikhin fbshipit-source-id: f8f3b4d74ffbf06cfc97a20a115eefd0a44b7ef3
1 parent 359faa6 commit 52ea3c6

File tree

14 files changed

+232
-73
lines changed

14 files changed

+232
-73
lines changed

backends/apple/mps/operators/node_visitor.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -374,7 +374,10 @@ def process_placeholder_nodes(
374374
input_id = placeholder_visitor.define_tensor(node, mps_graph)
375375
mps_graph.input_ids.append(input_id)
376376

377-
if placeholder_visitor.convert_model_to_fp16:
377+
if (
378+
placeholder_visitor.convert_model_to_fp16
379+
and node.meta["val"].dtype == torch.float32
380+
):
378381
mps_node = MPSNode(
379382
mpsnode_union=MPSCast(
380383
input1_id=input_id,
@@ -393,7 +396,10 @@ def process_output_node(
393396
output_id = output_visitor.define_tensor(output_node, mps_graph)
394397
mps_graph.output_ids.append(output_id)
395398

396-
if output_visitor.convert_model_to_fp16:
399+
if (
400+
output_visitor.convert_model_to_fp16
401+
and output_node.meta["val"].dtype == torch.float32
402+
):
397403
mps_node = MPSNode(
398404
mpsnode_union=MPSCast(
399405
input1_id=output_id,

backends/apple/mps/runtime/MPSGraphBuilder.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,11 @@
1616
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
1717

1818
// MPS headers
19+
#include <executorch/backends/apple/mps/runtime/operations/MPSGraphVenturaOps.h>
1920
#include <executorch/backends/apple/mps/runtime/operations/OperationUtils.h>
2021
#include <executorch/backends/apple/mps/schema_generated.h>
2122

23+
#include <unordered_map>
2224
#include <vector>
2325

2426
namespace torch {

backends/apple/mps/runtime/MPSStream.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -91,9 +91,6 @@ class MPSStream {
9191
MPSCommandBuffer* _commandBuffer = nil;
9292
MPSCommandBuffer* _prevCommandBuffer = nil;
9393
id<MTLComputeCommandEncoder> _commandEncoder = nil;
94-
MPSGraphExecutionDescriptor* _executionDescriptor = nil;
95-
MPSGraphExecutableExecutionDescriptor* _executableExecutionDescriptor = nil;
96-
MPSGraphCompilationDescriptor* _compilationDescriptor = nil;
9794
dispatch_queue_t _serialQueue = nullptr;
9895
// CommitAndContinue is disabled by default
9996
bool _enableCommitAndContinue = false;

backends/apple/mps/runtime/MPSStream.mm

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -16,40 +16,18 @@ @interface MPSGraphExecutionDescriptor ()
1616
namespace mps {
1717
namespace delegate {
1818

19-
// threshold to perform adaptive commit if the accumulated size
20-
// of resources encoded on the command buffer exceeds that.
21-
static const size_t kCmdBufAdaptiveCommitThreshold = MB(64);
22-
2319
//-----------------------------------------------------------------
2420
// MPSStream
2521
//-----------------------------------------------------------------
2622

2723
MPSStream::MPSStream() {
2824
_commandQueue = [MPSDevice::getInstance()->device() newCommandQueue];
2925
_serialQueue = dispatch_queue_create("metal gpu stream", nullptr);
30-
_executionDescriptor = [MPSGraphExecutionDescriptor new];
31-
_executableExecutionDescriptor = [MPSGraphExecutableExecutionDescriptor new];
32-
_compilationDescriptor = [MPSGraphCompilationDescriptor new];
33-
34-
// internal CommitAndContinue heuristic of MPSGraph is disabled, and we
35-
// control it via Adaptive Commit in Executorch-side
36-
_executionDescriptor.enableCommitAndContinue = false;
37-
38-
// Choose level which optimizes for GPU
39-
_compilationDescriptor.optimizationLevel = MPSGraphOptimizationLevel0;
40-
_executionDescriptor.compilationDescriptor = _compilationDescriptor;
4126
}
4227

4328
MPSStream::~MPSStream() {
4429
[_commandQueue release];
4530
_commandQueue = nil;
46-
[_executionDescriptor release];
47-
[_compilationDescriptor release];
48-
[_executableExecutionDescriptor release];
49-
50-
_executionDescriptor = nil;
51-
_compilationDescriptor = nil;
52-
_executableExecutionDescriptor = nil;
5331

5432
assert(_commandBuffer == nil);
5533
}

backends/apple/mps/runtime/operations/BinaryOps.mm

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,9 @@
118118
graphNode->input2_id(), \
119119
graphNode->output_id() \
120120
); \
121+
ET_CHECK_OR_RETURN_ERROR( \
122+
isMacOS13OrNewer(), NotSupported, \
123+
"%s supported by MPS on MacOS13.0+/iOS16.1+", #aot_name); \
121124
\
122125
_idToMPSGraphTensor[graphNode->output_id()] = binaryOpTensor( \
123126
getMPSGraphTensor(graphNode->input1_id()), \
@@ -196,10 +199,7 @@
196199
MPSGraph* mpsGraph,
197200
const std::string& op_name) {
198201
MPSDataType mpsInputDataType = [primaryTensor dataType];
199-
MPSDataType mpsOtherDataType = [secondaryTensor dataType];
200-
201202
ScalarType inputDataType = getScalarType(mpsInputDataType);
202-
ScalarType otherDataType = getScalarType(mpsOtherDataType);
203203

204204
if(rounding_mode.has_value() && *rounding_mode == "trunc"){
205205
ET_CHECK_MSG(inputDataType != ScalarType::Half,

backends/apple/mps/runtime/operations/IndexingOps.mm

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,14 @@
2222
if(castIndexTensor.dataType != MPSDataTypeInt32) {
2323
castIndexTensor = [mpsGraph castTensor:indexTensor
2424
toType:MPSDataTypeInt32
25-
name:nil];
25+
name:@"castTensor"];
2626
}
2727

2828
return [mpsGraph gatherWithUpdatesTensor:inputTensor
2929
indicesTensor:castIndexTensor
3030
axis:dim
3131
batchDimensions:0
32-
name:nil];
32+
name:@"indexSelect"];
3333
}
3434

3535
Error
@@ -48,7 +48,7 @@
4848
if(castIndexTensor.dataType != MPSDataTypeInt32) {
4949
castIndexTensor = [_mpsGraph castTensor:indexTensor
5050
toType:MPSDataTypeInt32
51-
name:nil];
51+
name:@"castTensor"];
5252
}
5353

5454
_idToMPSGraphTensor[graphNode->output_id()] =
Lines changed: 207 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,207 @@
1+
2+
//
3+
// Copyright (c) 2023 Apple Inc. All rights reserved.
4+
// Provided subject to the LICENSE file in the top level directory.
5+
//
6+
7+
#pragma once
8+
9+
#include <MetalPerformanceShadersGraph/MetalPerformanceShadersGraph.h>
10+
11+
@interface MPSGraph (VenturaOps)
12+
13+
#if !defined(__MAC_13_0) && (!defined(MAC_OS_X_VERSION_13_0) || (MAC_OS_X_VERSION_MIN_REQUIRED < MAC_OS_X_VERSION_13_0))
14+
15+
typedef NS_ENUM(NSUInteger, MPSGraphResizeNearestRoundingMode) {
16+
MPSGraphResizeNearestRoundingModeRoundPreferCeil = 0L,
17+
MPSGraphResizeNearestRoundingModeRoundPreferFloor = 1L,
18+
MPSGraphResizeNearestRoundingModeCeil = 2L,
19+
MPSGraphResizeNearestRoundingModeFloor = 3L,
20+
MPSGraphResizeNearestRoundingModeRoundToEven = 4L,
21+
MPSGraphResizeNearestRoundingModeRoundToOdd = 5L,
22+
};
23+
24+
// Define complex enums for MacOS 12
25+
#define MPSDataTypeComplexBit 0x01000000
26+
#define MPSDataTypeComplexFloat32 ((MPSDataType)(MPSDataTypeFloatBit | MPSDataTypeComplexBit | 64))
27+
#define MPSDataTypeComplexFloat16 ((MPSDataType)(MPSDataTypeFloatBit | MPSDataTypeComplexBit | 32))
28+
#endif
29+
30+
- (MPSGraphTensor *_Nonnull)cumulativeSumWithTensor:(MPSGraphTensor *_Nonnull)tensor
31+
axis:(NSInteger)axis
32+
name:(NSString *_Nullable)name;
33+
34+
- (MPSGraphTensor *_Nonnull)sortWithTensor:(MPSGraphTensor *_Nonnull)tensor
35+
axis:(NSInteger)axis
36+
name:(NSString *_Nullable)name;
37+
38+
- (MPSGraphTensor *_Nonnull)sortWithTensor:(MPSGraphTensor *_Nonnull)tensor
39+
axis:(NSInteger)axis
40+
descending:(BOOL)descending
41+
name:(NSString *_Nullable)name;
42+
43+
- (MPSGraphTensor *_Nonnull)sortWithTensor:(MPSGraphTensor *_Nonnull)tensor
44+
axisTensor:(MPSGraphTensor *_Nonnull)axisTensor
45+
descending:(BOOL)descending
46+
name:(NSString *_Nullable)name;
47+
48+
- (MPSGraphTensor *_Nonnull)sortWithTensor:(MPSGraphTensor *_Nonnull)tensor
49+
axisTensor:(MPSGraphTensor *_Nonnull)axisTensor
50+
name:(NSString *_Nullable)name;
51+
52+
- (MPSGraphTensor *_Nonnull)argSortWithTensor:(MPSGraphTensor *_Nonnull)tensor
53+
axis:(NSInteger)axis
54+
name:(NSString *_Nullable)name;
55+
56+
- (MPSGraphTensor *_Nonnull)argSortWithTensor:(MPSGraphTensor *_Nonnull)tensor
57+
axis:(NSInteger)axis
58+
descending:(BOOL)descending
59+
name:(NSString *_Nullable)name;
60+
61+
- (MPSGraphTensor *_Nonnull)argSortWithTensor:(MPSGraphTensor *_Nonnull)tensor
62+
axisTensor:(MPSGraphTensor *_Nonnull)axisTensor
63+
descending:(BOOL)descending
64+
name:(NSString *_Nullable)name;
65+
66+
- (MPSGraphTensor *_Nonnull)argSortWithTensor:(MPSGraphTensor *_Nonnull)tensor
67+
axisTensor:(MPSGraphTensor *_Nonnull)axisTensor
68+
name:(NSString *_Nullable)name;
69+
70+
- (MPSGraphTensor *_Nonnull)inverseOfTensor:(MPSGraphTensor *_Nonnull)inputTensor name:(NSString *_Nullable)name;
71+
72+
- (MPSGraphTensor *_Nonnull)resizeNearestWithTensor:(MPSGraphTensor *_Nonnull)imagesTensor
73+
sizeTensor:(MPSGraphTensor *_Nonnull)size
74+
nearestRoundingMode:(MPSGraphResizeNearestRoundingMode)nearestRoundingMode
75+
centerResult:(BOOL)centerResult
76+
alignCorners:(BOOL)alignCorners
77+
layout:(MPSGraphTensorNamedDataLayout)layout
78+
name:(NSString *_Nullable)name;
79+
80+
- (MPSGraphTensor *_Nonnull)resizeNearestWithTensor:(MPSGraphTensor *_Nonnull)imagesTensor
81+
sizeTensor:(MPSGraphTensor *_Nonnull)size
82+
scaleOffsetTensor:(MPSGraphTensor *_Nonnull)scaleOffset
83+
nearestRoundingMode:(MPSGraphResizeNearestRoundingMode)nearestRoundingMode
84+
layout:(MPSGraphTensorNamedDataLayout)layout
85+
name:(NSString *_Nullable)name;
86+
87+
- (MPSGraphTensor *_Nonnull)resizeBilinearWithTensor:(MPSGraphTensor *_Nonnull)imagesTensor
88+
sizeTensor:(MPSGraphTensor *_Nonnull)size
89+
centerResult:(BOOL)centerResult
90+
alignCorners:(BOOL)alignCorners
91+
layout:(MPSGraphTensorNamedDataLayout)layout
92+
name:(NSString *_Nullable)name;
93+
94+
- (MPSGraphTensor *_Nonnull)resizeBilinearWithTensor:(MPSGraphTensor *_Nonnull)imagesTensor
95+
sizeTensor:(MPSGraphTensor *_Nonnull)size
96+
scaleOffsetTensor:(MPSGraphTensor *_Nonnull)scaleOffset
97+
layout:(MPSGraphTensorNamedDataLayout)layout
98+
name:(NSString *_Nullable)name;
99+
100+
- (MPSGraphTensor *_Nonnull)resizeNearestWithGradientTensor:(MPSGraphTensor *_Nonnull)gradient
101+
input:(MPSGraphTensor *_Nonnull)input
102+
nearestRoundingMode:(MPSGraphResizeNearestRoundingMode)nearestRoundingMode
103+
centerResult:(BOOL)centerResult
104+
alignCorners:(BOOL)alignCorners
105+
layout:(MPSGraphTensorNamedDataLayout)layout
106+
name:(NSString *_Nullable)name;
107+
108+
- (MPSGraphTensor *_Nonnull)resizeNearestWithGradientTensor:(MPSGraphTensor *_Nonnull)gradient
109+
input:(MPSGraphTensor *_Nonnull)input
110+
scaleOffsetTensor:(MPSGraphTensor *_Nonnull)scaleOffset
111+
nearestRoundingMode:(MPSGraphResizeNearestRoundingMode)nearestRoundingMode
112+
layout:(MPSGraphTensorNamedDataLayout)layout
113+
name:(NSString *_Nullable)name;
114+
115+
- (MPSGraphTensor *_Nonnull)resizeBilinearWithGradientTensor:(MPSGraphTensor *_Nonnull)gradient
116+
input:(MPSGraphTensor *_Nonnull)input
117+
centerResult:(BOOL)centerResult
118+
alignCorners:(BOOL)alignCorners
119+
layout:(MPSGraphTensorNamedDataLayout)layout
120+
name:(NSString *_Nullable)name;
121+
122+
- (MPSGraphTensor *_Nonnull)resizeBilinearWithGradientTensor:(MPSGraphTensor *_Nonnull)gradient
123+
input:(MPSGraphTensor *_Nonnull)input
124+
scaleOffsetTensor:(MPSGraphTensor *_Nonnull)scaleOffset
125+
layout:(MPSGraphTensorNamedDataLayout)layout
126+
name:(NSString *_Nullable)name;
127+
128+
- (MPSGraphTensor *_Nonnull)sampleGridWithSourceTensor:(MPSGraphTensor *_Nonnull)source
129+
coordinateTensor:(MPSGraphTensor *_Nonnull)coordinates
130+
layout:(MPSGraphTensorNamedDataLayout)layout
131+
normalizeCoordinates:(BOOL)normalizeCoordinates
132+
relativeCoordinates:(BOOL)relativeCoordinates
133+
alignCorners:(BOOL)alignCorners
134+
paddingMode:(MPSGraphPaddingMode)paddingMode
135+
samplingMode:(MPSGraphResizeMode)samplingMode
136+
constantValue:(double)constantValue
137+
name:(NSString *_Nullable)name;
138+
139+
- (MPSGraphTensor *_Nonnull)sampleGridWithSourceTensor:(MPSGraphTensor *_Nonnull)source
140+
coordinateTensor:(MPSGraphTensor *_Nonnull)coordinates
141+
layout:(MPSGraphTensorNamedDataLayout)layout
142+
normalizeCoordinates:(BOOL)normalizeCoordinates
143+
relativeCoordinates:(BOOL)relativeCoordinates
144+
alignCorners:(BOOL)alignCorners
145+
paddingMode:(MPSGraphPaddingMode)paddingMode
146+
nearestRoundingMode:(MPSGraphResizeNearestRoundingMode)nearestRoundingMode
147+
constantValue:(double)constantValue
148+
name:(NSString *_Nullable)name;
149+
150+
- (MPSGraphTensor *_Nonnull)truncateWithTensor:(MPSGraphTensor *_Nonnull)tensor name:(NSString *_Nullable)name;
151+
152+
- (MPSGraphTensor *_Nonnull)transposeTensor:(MPSGraphTensor *_Nonnull)tensor
153+
permutation:(NSArray<NSNumber *> *_Nonnull)permutation
154+
name:(NSString *_Nullable)name;
155+
156+
- (MPSGraphTensor *_Nonnull)bitwiseANDWithPrimaryTensor:(MPSGraphTensor *_Nonnull)primaryTensor
157+
secondaryTensor:(MPSGraphTensor *_Nonnull)secondaryTensor
158+
name:(NSString *_Nullable)name;
159+
160+
- (MPSGraphTensor *_Nonnull)bitwiseORWithPrimaryTensor:(MPSGraphTensor *_Nonnull)primaryTensor
161+
secondaryTensor:(MPSGraphTensor *_Nonnull)secondaryTensor
162+
name:(NSString *_Nullable)name;
163+
164+
- (MPSGraphTensor *_Nonnull)bitwiseXORWithPrimaryTensor:(MPSGraphTensor *_Nonnull)primaryTensor
165+
secondaryTensor:(MPSGraphTensor *_Nonnull)secondaryTensor
166+
name:(NSString *_Nullable)name;
167+
168+
- (MPSGraphTensor *_Nonnull)bitwiseNOTWithTensor:(MPSGraphTensor *_Nonnull)tensor name:(NSString *_Nullable)name;
169+
170+
#if !defined(MAC_OS_X_VERSION_12_2) || (MAC_OS_X_VERSION_MIN_REQUIRED < MAC_OS_X_VERSION_12_2)
171+
- (MPSGraphTensor *_Nullable)expandDimsOfTensor:(MPSGraphTensor *_Nullable)tensor
172+
axis:(NSInteger)axis
173+
name:(NSString *_Nullable)name;
174+
175+
- (MPSGraphTensor *_Nullable)expandDimsOfTensor:(MPSGraphTensor *_Nullable)tensor
176+
axes:(NSArray<NSNumber *> *_Nullable)axes
177+
name:(NSString *_Nullable)name;
178+
179+
- (MPSGraphTensor *_Nullable)squeezeTensor:(MPSGraphTensor *_Nullable)tensor
180+
axes:(NSArray<NSNumber *> *_Nullable)axes
181+
name:(NSString *_Nullable)name;
182+
183+
- (MPSGraphTensor *_Nullable)squeezeTensor:(MPSGraphTensor *_Nullable)tensor
184+
axis:(NSInteger)axis
185+
name:(NSString *_Nullable)name;
186+
187+
- (NSArray<MPSGraphTensor *> *_Nullable)
188+
maxPooling2DReturnIndicesWithSourceTensor:(MPSGraphTensor *_Nullable)source
189+
descriptor:(MPSGraphPooling2DOpDescriptor *_Nullable)descriptor
190+
name:(NSString *_Nullable)name;
191+
192+
- (MPSGraphTensor *_Nullable)coordinateAlongAxis:(NSInteger)axis
193+
withShapeTensor:(MPSGraphTensor *_Nullable)shapeTensor
194+
name:(NSString *_Nullable)name;
195+
196+
- (NSArray<MPSGraphTensor *> *_Nullable)splitTensor:(MPSGraphTensor *_Nullable)tensor
197+
splitSizesTensor:(MPSGraphTensor *_Nullable)splitSizesTensor
198+
axis:(NSInteger)axis
199+
name:(NSString *_Nullable)name;
200+
201+
- (NSArray<MPSGraphTensor *> *_Nullable)splitTensor:(MPSGraphTensor *_Nullable)tensor
202+
splitSizes:(NSArray<NSNumber *> *_Nullable)splitSizes
203+
axis:(NSInteger)axis
204+
name:(NSString *_Nullable)name;
205+
#endif
206+
207+
@end

backends/apple/mps/runtime/operations/OperationUtils.mm

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,6 @@
227227

228228
MPSGraphTensor*
229229
MPSGraphBuilder::getMPSGraphTensor(int32_t id) {
230-
static int32_t cacheEntries = _idToMPSGraphTensor.size();
231230
return _idToMPSGraphTensor[id];
232231
}
233232

backends/apple/mps/runtime/operations/PadOps.mm

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
"invalid padding argument of size %d", padding_size);
2828

2929
auto input_sizes = getMPSShapeVec(input.shape);
30-
int64_t nbatch = 1;
3130
int64_t ndims = input_sizes.size();
3231

3332
ET_CHECK_MSG(
@@ -39,7 +38,6 @@
3938
int dim_w = padding_dim;
4039
int dim_h = padding_dim - 1;
4140
int dim_d = padding_dim - 2;
42-
int dim_slices = 0;
4341

4442
if (mode != MPSGraphPaddingModeConstant && ndims > padding_dim) {
4543
bool valid_dims = input_sizes[1] != 0 && input_sizes[padding_dim] != 0;
@@ -59,8 +57,6 @@
5957
dim_w += dim_diff;
6058
dim_h += dim_diff;
6159
dim_d += dim_diff;
62-
dim_slices++;
63-
nbatch = input_sizes[0];
6460
}
6561

6662
int64_t pad_l = padding[0];
@@ -70,13 +66,11 @@
7066
int64_t pad_front = padding_size > 4 ? padding[4] : 0;
7167
int64_t pad_back = padding_size > 4 ? padding[5] : 0;
7268

73-
int64_t nplane = input_sizes[dim_slices];
7469
int64_t input_w = input_sizes[dim_w];
7570
int64_t output_w = input_w + pad_l + pad_r;
7671
int64_t input_h = padding_dim > 1 ? input_sizes[dim_h] : 0;
7772
int64_t output_h = padding_dim > 1 ? input_h + pad_t + pad_b : 0;
7873
int64_t input_d = padding_dim > 2 ? input_sizes[dim_d] : 0;
79-
int64_t output_d = padding_dim > 2 ? input_d + pad_front + pad_back : 0;
8074

8175
ET_CHECK_MSG(
8276
output_w >= 1 || output_h >= padding_dim - 1,

0 commit comments

Comments
 (0)