1
+ // ===- VectorLegalization.cpp - Legalize vectors for lowering to ArmSME ---===//
2
+ //
3
+ // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
+ // See https://llvm.org/LICENSE.txt for license information.
5
+ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
+ //
7
+ // ===----------------------------------------------------------------------===//
8
+
1
9
#include " mlir/Dialect/ArmSME/IR/ArmSME.h"
2
10
#include " mlir/Dialect/ArmSME/Transforms/Passes.h"
3
11
#include " mlir/Dialect/ArmSME/Utils/Utils.h"
@@ -19,15 +27,24 @@ using namespace mlir::arm_sme;
19
27
20
28
namespace {
21
29
30
+ // Common match failure reasons.
31
+ static constexpr StringLiteral MATCH_FAILURE_NOT_SME_TILE_TYPE_MULTIPLE (
32
+ " op vector size is not multiple of SME tiles" );
33
+ static constexpr StringLiteral MATCH_FAILURE_UNSUPPORTED_MASK_OP (
34
+ " op mask is unsupported for legalization/decomposition" );
35
+ static constexpr StringLiteral
36
+ MATCH_FAILURE_NON_PERMUTATION_MAP (" op affine map is not a permutation" );
37
+
22
38
struct SMETile {
23
39
// Note: The units of (row, col) are vscale (as SME tiles are scalable).
24
40
int row{0 };
25
41
int col{0 };
26
42
VectorType type;
27
43
};
28
44
29
- // / Adds a constant scalable offset to `indices`. i.e. for 2D:
30
- // / { indices[0] + offset[0] * vscale, indices[1] + offset[1] * vscale }
45
+ // / Adds a constant scalable offset to `indices` (which are of equal length).
46
+ // / For example, in the 2D case this would return:
47
+ // { indices[0] + offset[0] * vscale, indices[1] + offset[1] * vscale }
31
48
SmallVector<Value, 2 > addConstantScalableOffset (OpBuilder &builder,
32
49
Location loc,
33
50
ValueRange indices,
@@ -42,8 +59,20 @@ SmallVector<Value, 2> addConstantScalableOffset(OpBuilder &builder,
42
59
});
43
60
}
44
61
45
- // / Remaps indices (e.g. from a load/store) for a larger vector type to indices
46
- // / for one of the SME tiles it will decompose into.
62
+ // / Remaps `indices` (e.g. from a load/store) for a larger vector type to
63
+ // / indices for one of the SME tiles it will decompose into.
64
+ // /
65
+ // / For example, if you were to decompose an 8x8 load into four 4x4 tiles, the
66
+ // / indices for each tile would need to be remapped as follows:
67
+ // /
68
+ // / initial indices = [a,b], inital size = 8x8, target size = 4x4
69
+ // / ┌─────────────┬─────────────┐
70
+ // / │[a,b] │[a,b+4] │
71
+ // / │ │ │
72
+ // / ├─────────────┼─────────────┤
73
+ // / │[a+4,b] │[a+4,b+4] │
74
+ // / │ │ │
75
+ // / └─────────────┴─────────────┘
47
76
SmallVector<Value, 2 > remapIndicesForSMETile (OpBuilder &builder, Location loc,
48
77
ValueRange indices,
49
78
SMETile tileTile) {
@@ -64,7 +93,7 @@ Value extractSMEMask(OpBuilder &builder, Location loc, Value mask,
64
93
if (!mask)
65
94
return Value{};
66
95
auto createMask = mask.getDefiningOp <vector::CreateMaskOp>();
67
- // The the operands of `vector.create_mask` (from a 2D perspective) are the
96
+ // The operands of `vector.create_mask` (from a 2D perspective) are the
68
97
// coordinates where the mask ends. So we subtract where this tile starts,
69
98
// from the mask operands to get the parameters for this tile tile.
70
99
auto tileMaskDims = addConstantScalableOffset (
@@ -75,7 +104,9 @@ Value extractSMEMask(OpBuilder &builder, Location loc, Value mask,
75
104
}
76
105
77
106
// / Constructs an iterator that returns each SME tile (with coordinates)
78
- // / contained within a VectorType.
107
+ // / contained within a VectorType. For example, if decomposing an [8]x[8] into
108
+ // / [4]x[4] tiles, the iterator would yield the tiles: (0, 0), (0, 4), (4, 0),
109
+ // / (4, 4).
79
110
auto decomposeToSMETiles (OpBuilder &builder, VectorType type,
80
111
VectorType smeTileType,
81
112
bool transposeIndices = false ) {
@@ -92,7 +123,8 @@ auto decomposeToSMETiles(OpBuilder &builder, VectorType type,
92
123
});
93
124
}
94
125
95
- // / Returns the number of SME tiles that fit into the a vector type.
126
+ // / Returns the number of SME tiles that fit into the (2D-scalable) vector type
127
+ // / `type`.
96
128
int getNumberOfSMETilesForVectorType (VectorType type) {
97
129
assert (isMultipleOfSMETileVectorType (type));
98
130
int64_t vectorRows = type.getDimSize (0 );
@@ -102,8 +134,9 @@ int getNumberOfSMETilesForVectorType(VectorType type) {
102
134
return (vectorRows * vectorCols) / (minNumElts * minNumElts);
103
135
}
104
136
105
- // / Legalize `vector.outerproduct` operations to fit within SME tiles.
106
- struct LegalizeVectorOuterProductOp
137
+ // / Legalize `vector.outerproduct` operations to fit within SME tiles by
138
+ // / decomposing them into tile-sized operations.
139
+ struct LegalizeVectorOuterProductOpsByDecomposition
107
140
: public OneToNOpConversionPattern<vector::OuterProductOp> {
108
141
using OneToNOpConversionPattern::OneToNOpConversionPattern;
109
142
@@ -112,7 +145,8 @@ struct LegalizeVectorOuterProductOp
112
145
OneToNPatternRewriter &rewriter) const override {
113
146
auto vectorType = outerProductOp.getResultVectorType ();
114
147
if (!isMultipleOfSMETileVectorType (vectorType))
115
- return failure ();
148
+ return rewriter.notifyMatchFailure (
149
+ outerProductOp, MATCH_FAILURE_NOT_SME_TILE_TYPE_MULTIPLE);
116
150
117
151
Value mask;
118
152
Operation *rootOp = outerProductOp;
@@ -124,7 +158,8 @@ struct LegalizeVectorOuterProductOp
124
158
}
125
159
126
160
if (!isSupportedMaskOp (mask))
127
- return failure ();
161
+ return rewriter.notifyMatchFailure (outerProductOp,
162
+ MATCH_FAILURE_UNSUPPORTED_MASK_OP);
128
163
129
164
ValueRange accSMETiles = adaptor.getAcc ();
130
165
auto tileType = getSMETileTypeForElement (vectorType.getElementType ());
@@ -159,7 +194,7 @@ struct LegalizeVectorOuterProductOp
159
194
// conversion adding target materializations in the `vector.mask` region
160
195
// (invalid). This pattern matches on `vector.mask` then calls into the
161
196
// `vector.outerproduct` pattern to work around this issue.
162
- struct LegalizeMaskedVectorOuterProductOp
197
+ struct LegalizeMaskedVectorOuterProductOpsByDecomposition
163
198
: public OneToNOpConversionPattern<vector::MaskOp> {
164
199
using OneToNOpConversionPattern::OneToNOpConversionPattern;
165
200
@@ -168,16 +203,18 @@ struct LegalizeMaskedVectorOuterProductOp
168
203
OneToNPatternRewriter &rewriter) const override {
169
204
if (auto outerProductOp =
170
205
llvm::dyn_cast<vector::OuterProductOp>(maskOp.getMaskableOp ())) {
171
- LegalizeVectorOuterProductOp pattern (*getTypeConverter (), getContext ());
206
+ LegalizeVectorOuterProductOpsByDecomposition pattern (*getTypeConverter (),
207
+ getContext ());
172
208
return static_cast <RewritePattern &>(pattern).matchAndRewrite (
173
209
outerProductOp, rewriter);
174
210
}
175
211
return failure ();
176
212
}
177
213
};
178
214
179
- // / Legalize `vector.transfer_read` operations to fit within SME tiles.
180
- struct LegalizeTransferReadOp
215
+ // / Legalize `vector.transfer_read` operations to fit within SME tiles by
216
+ // / decomposing them into tile-sized operations.
217
+ struct LegalizeTransferReadOpsByDecomposition
181
218
: public OneToNOpConversionPattern<vector::TransferReadOp> {
182
219
using OneToNOpConversionPattern::OneToNOpConversionPattern;
183
220
@@ -186,15 +223,18 @@ struct LegalizeTransferReadOp
186
223
OneToNPatternRewriter &rewriter) const override {
187
224
auto vectorType = readOp.getVectorType ();
188
225
if (!isMultipleOfSMETileVectorType (vectorType))
189
- return failure ();
226
+ return rewriter.notifyMatchFailure (
227
+ readOp, MATCH_FAILURE_NOT_SME_TILE_TYPE_MULTIPLE);
190
228
191
229
auto mask = readOp.getMask ();
192
230
if (!isSupportedMaskOp (mask))
193
- return failure ();
231
+ return rewriter.notifyMatchFailure (readOp,
232
+ MATCH_FAILURE_UNSUPPORTED_MASK_OP);
194
233
195
234
auto permutationMap = readOp.getPermutationMap ();
196
235
if (!permutationMap.isPermutation ())
197
- return failure ();
236
+ return rewriter.notifyMatchFailure (readOp,
237
+ MATCH_FAILURE_NON_PERMUTATION_MAP);
198
238
199
239
// Note: For 2D vector types the only non-identity permutation is a simple
200
240
// tranpose [1, 0].
@@ -220,8 +260,9 @@ struct LegalizeTransferReadOp
220
260
}
221
261
};
222
262
223
- // / Legalize `vector.transfer_write` operations to fit within SME tiles.
224
- struct LegalizeTransferWriteOp
263
+ // / Legalize `vector.transfer_write` operations to fit within SME tiles by
264
+ // / decomposing them into tile-sized operations.
265
+ struct LegalizeTransferWriteOpsByDecomposition
225
266
: public OneToNOpConversionPattern<vector::TransferWriteOp> {
226
267
using OneToNOpConversionPattern::OneToNOpConversionPattern;
227
268
@@ -230,15 +271,18 @@ struct LegalizeTransferWriteOp
230
271
OneToNPatternRewriter &rewriter) const override {
231
272
auto vectorType = writeOp.getVectorType ();
232
273
if (!isMultipleOfSMETileVectorType (vectorType))
233
- return failure ();
274
+ return rewriter.notifyMatchFailure (
275
+ writeOp, MATCH_FAILURE_NOT_SME_TILE_TYPE_MULTIPLE);
234
276
235
277
auto mask = writeOp.getMask ();
236
278
if (!isSupportedMaskOp (mask))
237
- return failure ();
279
+ return rewriter.notifyMatchFailure (writeOp,
280
+ MATCH_FAILURE_UNSUPPORTED_MASK_OP);
238
281
239
282
auto permutationMap = writeOp.getPermutationMap ();
240
283
if (!permutationMap.isPermutation ())
241
- return failure ();
284
+ return rewriter.notifyMatchFailure (writeOp,
285
+ MATCH_FAILURE_NON_PERMUTATION_MAP);
242
286
243
287
// Note: For 2D vector types the only non-identity permutation is a simple
244
288
// tranpose [1, 0].
@@ -289,9 +333,11 @@ struct VectorLegalizationPass
289
333
});
290
334
291
335
// Note: High benefit to ensure masked outer products are lowered first.
292
- patterns.add <LegalizeMaskedVectorOuterProductOp>(converter, context, 1024 );
293
- patterns.add <LegalizeVectorOuterProductOp, LegalizeTransferReadOp,
294
- LegalizeTransferWriteOp>(converter, context);
336
+ patterns.add <LegalizeMaskedVectorOuterProductOpsByDecomposition>(
337
+ converter, context, 1024 );
338
+ patterns.add <LegalizeVectorOuterProductOpsByDecomposition,
339
+ LegalizeTransferReadOpsByDecomposition,
340
+ LegalizeTransferWriteOpsByDecomposition>(converter, context);
295
341
populateFuncTypeConversionPatterns (converter, patterns);
296
342
scf::populateSCFStructuralOneToNTypeConversions (converter, patterns);
297
343
0 commit comments