5
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
6
//
7
7
// ===----------------------------------------------------------------------===//
8
+ //
9
+ // This pass legalizes vector operations so they can be lowered to ArmSME.
10
+ // Currently, this only implements the decomposition of vector operations that
11
+ // use vector sizes larger than an SME tile, into multiple SME-sized operations.
12
+ //
13
+ // Note: In the context of this pass 'tile' always refers to an SME tile.
14
+ //
15
+ // ===----------------------------------------------------------------------===//
8
16
9
17
#include " mlir/Dialect/ArmSME/IR/ArmSME.h"
10
18
#include " mlir/Dialect/ArmSME/Transforms/Passes.h"
@@ -35,35 +43,49 @@ static constexpr StringLiteral MATCH_FAILURE_UNSUPPORTED_MASK_OP(
35
43
static constexpr StringLiteral
36
44
MATCH_FAILURE_NON_PERMUTATION_MAP (" op affine map is not a permutation" );
37
45
38
- struct SMETile {
46
+ // / An SMESubTile represents a single SME-sized sub-tile from decomposing a
47
+ // / larger vector type. The (`row`, `col`) are the position of the tile in the
48
+ // / original vector type. For example, an [8]x[8] vector type would have four
49
+ // / [4]x[4] sub-tiles.
50
+ // /
51
+ // / 8 x vscale
52
+ // / ┌─────────────┬─────────────┐
53
+ // / │(0,0) │(0,4) │
54
+ // / │ │ │
55
+ // / ├─────────────┼─────────────┤ 8 x vscale
56
+ // / │(4,0) │(4,4) │
57
+ // / │ │ │
58
+ // / └─────────────┴─────────────┘
59
+ struct SMESubTile {
39
60
// Note: The units of (row, col) are vscale (as SME tiles are scalable).
40
61
int row{0 };
41
62
int col{0 };
63
+ // The SME tile type.
42
64
VectorType type;
43
65
};
44
66
45
- // / Adds a constant scalable offset to `indices` (which are of equal length).
46
- // / For example, in the 2D case this would return:
67
+ // / Adds a constant elementwise scalable offset to `indices` (`indices` and
68
+ // / `scalableOffsets` must be of equal length). For example, in the 2D case this would return:
47
69
// { indices[0] + offset[0] * vscale, indices[1] + offset[1] * vscale }
48
70
SmallVector<Value, 2 > addConstantScalableOffset (OpBuilder &builder,
49
71
Location loc,
50
72
ValueRange indices,
51
- ArrayRef<int > scalableOffset ) {
73
+ ArrayRef<int > scalableOffsets ) {
52
74
auto vscale = builder.create <vector::VectorScaleOp>(loc);
53
75
return llvm::map_to_vector (
54
- llvm::zip_equal (indices, scalableOffset ), [&](auto pair) -> Value {
76
+ llvm::zip_equal (indices, scalableOffsets ), [&](auto pair) -> Value {
55
77
auto [index, base] = pair;
56
78
auto offset = builder.create <arith::MulIOp>(
57
79
loc, builder.create <arith::ConstantIndexOp>(loc, base), vscale);
58
80
return builder.create <arith::AddIOp>(loc, index, offset);
59
81
});
60
82
}
61
83
62
- // / Remaps `indices` (e.g. from a load/store) for a larger vector type to
63
- // / indices for one of the SME tiles it will decompose into.
84
+ // / Adjusts `indices` (e.g. from a load/store) for a larger vector type to
85
+ // / indices for one of the SME sub-tiles it will decompose into.
64
86
// /
65
87
// / For example, if you were to decompose an 8x8 load into four 4x4 tiles, the
66
- // / indices for each tile would need to be remapped as follows:
88
+ // / indices for each tile would need to be adjusted as follows:
67
89
// /
68
90
// / initial indices = [a,b], initial size = 8x8, target size = 4x4
69
91
// / ┌─────────────┬─────────────┐
@@ -73,11 +95,11 @@ SmallVector<Value, 2> addConstantScalableOffset(OpBuilder &builder,
73
95
// / │[a+4,b] │[a+4,b+4] │
74
96
// / │ │ │
75
97
// / └─────────────┴─────────────┘
76
- SmallVector<Value, 2 > remapIndicesForSMETile (OpBuilder &builder, Location loc,
77
- ValueRange indices,
78
- SMETile tileTile ) {
98
+ SmallVector<Value, 2 > getSMESubTileIndices (OpBuilder &builder, Location loc,
99
+ ValueRange indices,
100
+ SMESubTile smeTile ) {
79
101
return addConstantScalableOffset (builder, loc, indices,
80
- {tileTile .row , tileTile .col });
102
+ {smeTile .row , smeTile .col });
81
103
}
82
104
83
105
// / Returns true if `mask` is generated by an operation that can be decomposed
@@ -86,21 +108,21 @@ bool isSupportedMaskOp(Value mask) {
86
108
return !mask || mask.getDefiningOp <vector::CreateMaskOp>();
87
109
}
88
110
89
- // / Extracts a mask for an SME tile from the mask of a larger vector type.
111
+ // / Extracts a mask for an SME sub-tile from the mask of a larger vector type.
90
112
Value extractSMEMask (OpBuilder &builder, Location loc, Value mask,
91
- SMETile tileTile ) {
113
+ SMESubTile smeTile ) {
92
114
assert (isSupportedMaskOp (mask));
93
115
if (!mask)
94
116
return Value{};
95
117
auto createMask = mask.getDefiningOp <vector::CreateMaskOp>();
96
118
// The operands of `vector.create_mask` (from a 2D perspective) are the
97
119
// coordinates where the mask ends. So we subtract where this tile starts,
98
- // from the mask operands to get the parameters for this tile tile.
99
- auto tileMaskDims = addConstantScalableOffset (
100
- builder, loc, createMask.getOperands (), {-tileTile .row , -tileTile .col });
101
- auto createTileMask = builder.create <vector::CreateMaskOp>(
102
- loc, tileTile .type .clone (builder.getI1Type ()), tileMaskDims );
103
- return createTileMask .getResult ();
120
+ // from the mask operands to get the parameters for this sub-tile.
121
+ auto smeTileMaskDims = addConstantScalableOffset (
122
+ builder, loc, createMask.getOperands (), {-smeTile .row , -smeTile .col });
123
+ auto smeTileCreateMask = builder.create <vector::CreateMaskOp>(
124
+ loc, smeTile .type .clone (builder.getI1Type ()), smeTileMaskDims );
125
+ return smeTileCreateMask .getResult ();
104
126
}
105
127
106
128
// / Constructs an iterator that returns each SME tile (with coordinates)
@@ -110,7 +132,8 @@ Value extractSMEMask(OpBuilder &builder, Location loc, Value mask,
110
132
auto decomposeToSMETiles (OpBuilder &builder, VectorType type,
111
133
VectorType smeTileType,
112
134
bool transposeIndices = false ) {
113
- assert (isMultipleOfSMETileVectorType (type));
135
+ assert (isMultipleOfSMETileVectorType (type) &&
136
+ " `type` not multiple of SME tiles" );
114
137
return llvm::map_range (
115
138
StaticTileOffsetRange (type.getShape (), {smeTileType.getDimSize (0 ),
116
139
smeTileType.getDimSize (1 )}),
@@ -119,14 +142,15 @@ auto decomposeToSMETiles(OpBuilder &builder, VectorType type,
119
142
int col = int (indices[1 ]);
120
143
if (transposeIndices)
121
144
std::swap (row, col);
122
- return SMETile {row, col, smeTileType};
145
+ return SMESubTile {row, col, smeTileType};
123
146
});
124
147
}
125
148
126
149
// / Returns the number of SME tiles that fit into the (2D-scalable) vector type
127
150
// / `type`.
128
151
int getNumberOfSMETilesForVectorType (VectorType type) {
129
- assert (isMultipleOfSMETileVectorType (type));
152
+ assert (isMultipleOfSMETileVectorType (type) &&
153
+ " `type` not multiple of SME tiles" );
130
154
int64_t vectorRows = type.getDimSize (0 );
131
155
int64_t vectorCols = type.getDimSize (1 );
132
156
auto elementType = type.getElementType ();
@@ -162,25 +186,25 @@ struct LegalizeVectorOuterProductOpsByDecomposition
162
186
MATCH_FAILURE_UNSUPPORTED_MASK_OP);
163
187
164
188
ValueRange accSMETiles = adaptor.getAcc ();
165
- auto tileType = getSMETileTypeForElement (vectorType.getElementType ());
166
- VectorType sliceType = VectorType::Builder (tileType ).dropDim (0 );
189
+ auto smeTileType = getSMETileTypeForElement (vectorType.getElementType ());
190
+ VectorType sliceType = VectorType::Builder (smeTileType ).dropDim (0 );
167
191
168
192
SmallVector<Value> resultSMETiles;
169
- for (auto [index, tileTile ] :
170
- llvm::enumerate ( decomposeToSMETiles (rewriter, vectorType, tileType ))) {
193
+ for (auto [index, smeTile ] : llvm::enumerate (
194
+ decomposeToSMETiles (rewriter, vectorType, smeTileType ))) {
171
195
172
- auto tileMask = extractSMEMask (rewriter, loc, mask, tileTile );
196
+ auto smeMask = extractSMEMask (rewriter, loc, mask, smeTile );
173
197
auto lhs = rewriter.create <vector::ScalableExtractOp>(
174
- loc, sliceType, outerProductOp.getLhs (), tileTile .row );
198
+ loc, sliceType, outerProductOp.getLhs (), smeTile .row );
175
199
auto rhs = rewriter.create <vector::ScalableExtractOp>(
176
- loc, sliceType, outerProductOp.getRhs (), tileTile .col );
177
- auto tileOuterProduct = rewriter.create <vector::OuterProductOp>(
178
- loc, tileType , lhs, rhs,
200
+ loc, sliceType, outerProductOp.getRhs (), smeTile .col );
201
+ auto smeOuterProduct = rewriter.create <vector::OuterProductOp>(
202
+ loc, smeTileType , lhs, rhs,
179
203
!accSMETiles.empty () ? accSMETiles[index] : Value{},
180
204
outerProductOp.getKind ());
181
205
182
206
auto maskedOuterProduct =
183
- vector::maskOperation (rewriter, tileOuterProduct, tileMask );
207
+ vector::maskOperation (rewriter, smeOuterProduct, smeMask );
184
208
resultSMETiles.push_back (maskedOuterProduct->getResult (0 ));
185
209
}
186
210
@@ -241,18 +265,18 @@ struct LegalizeTransferReadOpsByDecomposition
241
265
bool transposed = !permutationMap.isIdentity ();
242
266
243
267
auto loc = readOp.getLoc ();
244
- auto tileType = getSMETileTypeForElement (vectorType.getElementType ());
268
+ auto smeTileType = getSMETileTypeForElement (vectorType.getElementType ());
245
269
246
270
SmallVector<Value> resultSMETiles;
247
- for (SMETile tileTile :
248
- decomposeToSMETiles (rewriter, vectorType, tileType , transposed)) {
249
- auto tileMask = extractSMEMask (rewriter, loc, mask, tileTile );
250
- auto transferRead = rewriter.create <vector::TransferReadOp>(
251
- loc, tileType , readOp.getSource (),
252
- remapIndicesForSMETile (rewriter, loc, readOp.getIndices (), tileTile ),
253
- readOp.getPermutationMapAttr (), readOp.getPadding (), tileMask ,
271
+ for (SMESubTile smeTile :
272
+ decomposeToSMETiles (rewriter, vectorType, smeTileType , transposed)) {
273
+ auto smeMask = extractSMEMask (rewriter, loc, mask, smeTile );
274
+ auto smeRead = rewriter.create <vector::TransferReadOp>(
275
+ loc, smeTileType , readOp.getSource (),
276
+ getSMESubTileIndices (rewriter, loc, readOp.getIndices (), smeTile ),
277
+ readOp.getPermutationMapAttr (), readOp.getPadding (), smeMask ,
254
278
readOp.getInBoundsAttr ());
255
- resultSMETiles.push_back (transferRead );
279
+ resultSMETiles.push_back (smeRead );
256
280
}
257
281
258
282
rewriter.replaceOp (readOp, resultSMETiles, adaptor.getResultMapping ());
@@ -289,19 +313,19 @@ struct LegalizeTransferWriteOpsByDecomposition
289
313
bool transposed = !permutationMap.isIdentity ();
290
314
291
315
auto loc = writeOp.getLoc ();
292
- auto tileType = getSMETileTypeForElement (vectorType.getElementType ());
316
+ auto smeTileType = getSMETileTypeForElement (vectorType.getElementType ());
293
317
auto inputSMETiles = adaptor.getVector ();
294
318
295
319
Value destTensorOrMemref = writeOp.getSource ();
296
- for (auto [index, tileTile ] : llvm::enumerate (
297
- decomposeToSMETiles ( rewriter, vectorType, tileType , transposed))) {
298
- auto tileMask = extractSMEMask (rewriter, loc, mask, tileTile );
299
- auto tileWrite = rewriter.create <vector::TransferWriteOp>(
320
+ for (auto [index, smeTile ] : llvm::enumerate ( decomposeToSMETiles (
321
+ rewriter, vectorType, smeTileType , transposed))) {
322
+ auto smeMask = extractSMEMask (rewriter, loc, mask, smeTile );
323
+ auto smeWrite = rewriter.create <vector::TransferWriteOp>(
300
324
loc, inputSMETiles[index], destTensorOrMemref,
301
- remapIndicesForSMETile (rewriter, loc, writeOp.getIndices (), tileTile ),
302
- writeOp.getPermutationMapAttr (), tileMask , writeOp.getInBoundsAttr ());
325
+ getSMESubTileIndices (rewriter, loc, writeOp.getIndices (), smeTile ),
326
+ writeOp.getPermutationMapAttr (), smeMask , writeOp.getInBoundsAttr ());
303
327
if (writeOp.hasPureTensorSemantics ())
304
- destTensorOrMemref = tileWrite .getResult ();
328
+ destTensorOrMemref = smeWrite .getResult ();
305
329
}
306
330
307
331
if (writeOp.hasPureTensorSemantics ())
@@ -326,9 +350,10 @@ struct VectorLegalizationPass
326
350
SmallVectorImpl<Type> &types) -> std::optional<LogicalResult> {
327
351
if (!isMultipleOfSMETileVectorType (vectorType))
328
352
return std::nullopt;
329
- auto tileTileCount = getNumberOfSMETilesForVectorType (vectorType);
330
- auto tileType = getSMETileTypeForElement (vectorType.getElementType ());
331
- types = SmallVector<Type>(tileTileCount, tileType);
353
+ auto smeTileTileCount = getNumberOfSMETilesForVectorType (vectorType);
354
+ auto smeTileType =
355
+ getSMETileTypeForElement (vectorType.getElementType ());
356
+ types = SmallVector<Type>(smeTileTileCount, smeTileType);
332
357
return success ();
333
358
});
334
359
0 commit comments