Skip to content

Commit aad4e7e

Browse files
committed
Fixups
1 parent abcbd07 commit aad4e7e

File tree

2 files changed

+74
-28
lines changed

2 files changed

+74
-28
lines changed

mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,8 @@ scf::ForOp createLoopOverTileSlices(
5656
PatternRewriter &rewriter, Location loc, Value initTile,
5757
std::function<Value(OpBuilder &, Location, Value, Value)> makeLoopBody);
5858

59-
/// Returns true if `vType` is a multiple of an SME tile size. Note returns
60-
/// false if the `vType` exactly matches the size of an SME tile.
59+
/// Returns true if `vType` is a multiple of an SME tile size. Returns false if
60+
/// the `vType` exactly matches the size of an SME tile.
6161
bool isMultipleOfSMETileVectorType(VectorType vType);
6262

6363
/// Creates a vector type for the SME tile of `elementType`.

mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp

Lines changed: 72 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,11 @@
1+
//===- VectorLegalization.cpp - Legalize vectors for lowering to ArmSME ---===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
19
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
210
#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
311
#include "mlir/Dialect/ArmSME/Utils/Utils.h"
@@ -19,15 +27,24 @@ using namespace mlir::arm_sme;
1927

2028
namespace {
2129

30+
// Common match failure reasons.
31+
static constexpr StringLiteral MATCH_FAILURE_NOT_SME_TILE_TYPE_MULTIPLE(
32+
"op vector size is not multiple of SME tiles");
33+
static constexpr StringLiteral MATCH_FAILURE_UNSUPPORTED_MASK_OP(
34+
"op mask is unsupported for legalization/decomposition");
35+
static constexpr StringLiteral
36+
MATCH_FAILURE_NON_PERMUTATION_MAP("op affine map is not a permutation");
37+
2238
struct SMETile {
2339
// Note: The units of (row, col) are vscale (as SME tiles are scalable).
2440
int row{0};
2541
int col{0};
2642
VectorType type;
2743
};
2844

29-
/// Adds a constant scalable offset to `indices`. i.e. for 2D:
30-
/// { indices[0] + offset[0] * vscale, indices[1] + offset[1] * vscale }
45+
/// Adds a constant scalable offset to `indices` (which are of equal length).
46+
/// For example, in the 2D case this would return:
47+
/// { indices[0] + offset[0] * vscale, indices[1] + offset[1] * vscale }
3148
SmallVector<Value, 2> addConstantScalableOffset(OpBuilder &builder,
3249
Location loc,
3350
ValueRange indices,
@@ -42,8 +59,20 @@ SmallVector<Value, 2> addConstantScalableOffset(OpBuilder &builder,
4259
});
4360
}
4461

45-
/// Remaps indices (e.g. from a load/store) for a larger vector type to indices
46-
/// for one of the SME tiles it will decompose into.
62+
/// Remaps `indices` (e.g. from a load/store) for a larger vector type to
63+
/// indices for one of the SME tiles it will decompose into.
64+
///
65+
/// For example, if you were to decompose an 8x8 load into four 4x4 tiles, the
66+
/// indices for each tile would need to be remapped as follows:
67+
///
68+
/// initial indices = [a,b], initial size = 8x8, target size = 4x4
69+
/// ┌─────────────┬─────────────┐
70+
/// │[a,b] │[a,b+4] │
71+
/// │ │ │
72+
/// ├─────────────┼─────────────┤
73+
/// │[a+4,b] │[a+4,b+4] │
74+
/// │ │ │
75+
/// └─────────────┴─────────────┘
4776
SmallVector<Value, 2> remapIndicesForSMETile(OpBuilder &builder, Location loc,
4877
ValueRange indices,
4978
SMETile tileTile) {
@@ -64,7 +93,7 @@ Value extractSMEMask(OpBuilder &builder, Location loc, Value mask,
6493
if (!mask)
6594
return Value{};
6695
auto createMask = mask.getDefiningOp<vector::CreateMaskOp>();
67-
// The the operands of `vector.create_mask` (from a 2D perspective) are the
96+
// The operands of `vector.create_mask` (from a 2D perspective) are the
6897
// coordinates where the mask ends. So we subtract where this tile starts,
6998
// from the mask operands to get the parameters for this tile.
7099
auto tileMaskDims = addConstantScalableOffset(
@@ -75,7 +104,9 @@ Value extractSMEMask(OpBuilder &builder, Location loc, Value mask,
75104
}
76105

77106
/// Constructs an iterator that returns each SME tile (with coordinates)
78-
/// contained within a VectorType.
107+
/// contained within a VectorType. For example, if decomposing an [8]x[8] into
108+
/// [4]x[4] tiles, the iterator would yield the tiles: (0, 0), (0, 4), (4, 0),
109+
/// (4, 4).
79110
auto decomposeToSMETiles(OpBuilder &builder, VectorType type,
80111
VectorType smeTileType,
81112
bool transposeIndices = false) {
@@ -92,7 +123,8 @@ auto decomposeToSMETiles(OpBuilder &builder, VectorType type,
92123
});
93124
}
94125

95-
/// Returns the number of SME tiles that fit into the a vector type.
126+
/// Returns the number of SME tiles that fit into the (2D-scalable) vector type
127+
/// `type`.
96128
int getNumberOfSMETilesForVectorType(VectorType type) {
97129
assert(isMultipleOfSMETileVectorType(type));
98130
int64_t vectorRows = type.getDimSize(0);
@@ -102,8 +134,9 @@ int getNumberOfSMETilesForVectorType(VectorType type) {
102134
return (vectorRows * vectorCols) / (minNumElts * minNumElts);
103135
}
104136

105-
/// Legalize `vector.outerproduct` operations to fit within SME tiles.
106-
struct LegalizeVectorOuterProductOp
137+
/// Legalize `vector.outerproduct` operations to fit within SME tiles by
138+
/// decomposing them into tile-sized operations.
139+
struct LegalizeVectorOuterProductOpsByDecomposition
107140
: public OneToNOpConversionPattern<vector::OuterProductOp> {
108141
using OneToNOpConversionPattern::OneToNOpConversionPattern;
109142

@@ -112,7 +145,8 @@ struct LegalizeVectorOuterProductOp
112145
OneToNPatternRewriter &rewriter) const override {
113146
auto vectorType = outerProductOp.getResultVectorType();
114147
if (!isMultipleOfSMETileVectorType(vectorType))
115-
return failure();
148+
return rewriter.notifyMatchFailure(
149+
outerProductOp, MATCH_FAILURE_NOT_SME_TILE_TYPE_MULTIPLE);
116150

117151
Value mask;
118152
Operation *rootOp = outerProductOp;
@@ -124,7 +158,8 @@ struct LegalizeVectorOuterProductOp
124158
}
125159

126160
if (!isSupportedMaskOp(mask))
127-
return failure();
161+
return rewriter.notifyMatchFailure(outerProductOp,
162+
MATCH_FAILURE_UNSUPPORTED_MASK_OP);
128163

129164
ValueRange accSMETiles = adaptor.getAcc();
130165
auto tileType = getSMETileTypeForElement(vectorType.getElementType());
@@ -159,7 +194,7 @@ struct LegalizeVectorOuterProductOp
159194
// conversion adding target materializations in the `vector.mask` region
160195
// (invalid). This pattern matches on `vector.mask` then calls into the
161196
// `vector.outerproduct` pattern to work around this issue.
162-
struct LegalizeMaskedVectorOuterProductOp
197+
struct LegalizeMaskedVectorOuterProductOpsByDecomposition
163198
: public OneToNOpConversionPattern<vector::MaskOp> {
164199
using OneToNOpConversionPattern::OneToNOpConversionPattern;
165200

@@ -168,16 +203,18 @@ struct LegalizeMaskedVectorOuterProductOp
168203
OneToNPatternRewriter &rewriter) const override {
169204
if (auto outerProductOp =
170205
llvm::dyn_cast<vector::OuterProductOp>(maskOp.getMaskableOp())) {
171-
LegalizeVectorOuterProductOp pattern(*getTypeConverter(), getContext());
206+
LegalizeVectorOuterProductOpsByDecomposition pattern(*getTypeConverter(),
207+
getContext());
172208
return static_cast<RewritePattern &>(pattern).matchAndRewrite(
173209
outerProductOp, rewriter);
174210
}
175211
return failure();
176212
}
177213
};
178214

179-
/// Legalize `vector.transfer_read` operations to fit within SME tiles.
180-
struct LegalizeTransferReadOp
215+
/// Legalize `vector.transfer_read` operations to fit within SME tiles by
216+
/// decomposing them into tile-sized operations.
217+
struct LegalizeTransferReadOpsByDecomposition
181218
: public OneToNOpConversionPattern<vector::TransferReadOp> {
182219
using OneToNOpConversionPattern::OneToNOpConversionPattern;
183220

@@ -186,15 +223,18 @@ struct LegalizeTransferReadOp
186223
OneToNPatternRewriter &rewriter) const override {
187224
auto vectorType = readOp.getVectorType();
188225
if (!isMultipleOfSMETileVectorType(vectorType))
189-
return failure();
226+
return rewriter.notifyMatchFailure(
227+
readOp, MATCH_FAILURE_NOT_SME_TILE_TYPE_MULTIPLE);
190228

191229
auto mask = readOp.getMask();
192230
if (!isSupportedMaskOp(mask))
193-
return failure();
231+
return rewriter.notifyMatchFailure(readOp,
232+
MATCH_FAILURE_UNSUPPORTED_MASK_OP);
194233

195234
auto permutationMap = readOp.getPermutationMap();
196235
if (!permutationMap.isPermutation())
197-
return failure();
236+
return rewriter.notifyMatchFailure(readOp,
237+
MATCH_FAILURE_NON_PERMUTATION_MAP);
198238

199239
// Note: For 2D vector types the only non-identity permutation is a simple
200240
// transpose [1, 0].
@@ -220,8 +260,9 @@ struct LegalizeTransferReadOp
220260
}
221261
};
222262

223-
/// Legalize `vector.transfer_write` operations to fit within SME tiles.
224-
struct LegalizeTransferWriteOp
263+
/// Legalize `vector.transfer_write` operations to fit within SME tiles by
264+
/// decomposing them into tile-sized operations.
265+
struct LegalizeTransferWriteOpsByDecomposition
225266
: public OneToNOpConversionPattern<vector::TransferWriteOp> {
226267
using OneToNOpConversionPattern::OneToNOpConversionPattern;
227268

@@ -230,15 +271,18 @@ struct LegalizeTransferWriteOp
230271
OneToNPatternRewriter &rewriter) const override {
231272
auto vectorType = writeOp.getVectorType();
232273
if (!isMultipleOfSMETileVectorType(vectorType))
233-
return failure();
274+
return rewriter.notifyMatchFailure(
275+
writeOp, MATCH_FAILURE_NOT_SME_TILE_TYPE_MULTIPLE);
234276

235277
auto mask = writeOp.getMask();
236278
if (!isSupportedMaskOp(mask))
237-
return failure();
279+
return rewriter.notifyMatchFailure(writeOp,
280+
MATCH_FAILURE_UNSUPPORTED_MASK_OP);
238281

239282
auto permutationMap = writeOp.getPermutationMap();
240283
if (!permutationMap.isPermutation())
241-
return failure();
284+
return rewriter.notifyMatchFailure(writeOp,
285+
MATCH_FAILURE_NON_PERMUTATION_MAP);
242286

243287
// Note: For 2D vector types the only non-identity permutation is a simple
244288
// transpose [1, 0].
@@ -289,9 +333,11 @@ struct VectorLegalizationPass
289333
});
290334

291335
// Note: High benefit to ensure masked outer products are lowered first.
292-
patterns.add<LegalizeMaskedVectorOuterProductOp>(converter, context, 1024);
293-
patterns.add<LegalizeVectorOuterProductOp, LegalizeTransferReadOp,
294-
LegalizeTransferWriteOp>(converter, context);
336+
patterns.add<LegalizeMaskedVectorOuterProductOpsByDecomposition>(
337+
converter, context, 1024);
338+
patterns.add<LegalizeVectorOuterProductOpsByDecomposition,
339+
LegalizeTransferReadOpsByDecomposition,
340+
LegalizeTransferWriteOpsByDecomposition>(converter, context);
295341
populateFuncTypeConversionPatterns(converter, patterns);
296342
scf::populateSCFStructuralOneToNTypeConversions(converter, patterns);
297343

0 commit comments

Comments
 (0)