Skip to content

Commit 31613de

Browse files
[mlir][ArmSME] Migrate arm-sme-vector-legalization to dialect conversion (#121101)
Use the regular dialect conversion driver instead of the 1:N dialect conversion driver. The 1:N dialect conversion driver will be removed soon.
1 parent f0d6017 commit 31613de

File tree

1 file changed

+56
-38
lines changed

1 file changed

+56
-38
lines changed

mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp

Lines changed: 56 additions & 38 deletions
Original file line number | Diff line number | Diff line change
@@ -17,15 +17,16 @@
1717
#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
1818
#include "mlir/Dialect/ArmSME/Utils/Utils.h"
1919
#include "mlir/Dialect/Func/IR/FuncOps.h"
20-
#include "mlir/Dialect/Func/Transforms/OneToNFuncConversions.h"
20+
#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
2121
#include "mlir/Dialect/Index/IR/IndexDialect.h"
2222
#include "mlir/Dialect/Index/IR/IndexOps.h"
2323
#include "mlir/Dialect/MemRef/IR/MemRef.h"
2424
#include "mlir/Dialect/SCF/IR/SCF.h"
2525
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
2626
#include "mlir/Dialect/Utils/IndexingUtils.h"
2727
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
28-
#include "mlir/Transforms/OneToNTypeConversion.h"
28+
#include "mlir/Transforms/DialectConversion.h"
29+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
2930

3031
#define DEBUG_TYPE "arm-sme-vector-legalization"
3132

@@ -172,12 +173,12 @@ int getNumberOfSMETilesForVectorType(VectorType type) {
172173
/// Legalize `arith.constant dense<value>` splat operations to fit within SME
173174
/// tiles by decomposing them into tile-sized operations.
174175
struct LegalizeArithConstantOpsByDecomposition
175-
: public OneToNOpConversionPattern<arith::ConstantOp> {
176-
using OneToNOpConversionPattern::OneToNOpConversionPattern;
176+
: public OpConversionPattern<arith::ConstantOp> {
177+
using OpConversionPattern::OpConversionPattern;
177178

178179
LogicalResult
179180
matchAndRewrite(arith::ConstantOp constantOp, OpAdaptor adaptor,
180-
OneToNPatternRewriter &rewriter) const override {
181+
ConversionPatternRewriter &rewriter) const override {
181182
auto vectorType = dyn_cast<VectorType>(constantOp.getType());
182183
auto denseAttr = dyn_cast<DenseElementsAttr>(constantOp.getValueAttr());
183184
if (!vectorType || !denseAttr || !denseAttr.isSplat())
@@ -191,8 +192,8 @@ struct LegalizeArithConstantOpsByDecomposition
191192
auto tileCount = getNumberOfSMETilesForVectorType(vectorType);
192193
auto tileSplat = rewriter.create<arith::ConstantOp>(
193194
constantOp.getLoc(), denseAttr.resizeSplat(smeTileType));
194-
rewriter.replaceOp(constantOp, SmallVector<Value>(tileCount, tileSplat),
195-
adaptor.getResultMapping());
195+
SmallVector<Value> repl(tileCount, tileSplat);
196+
rewriter.replaceOpWithMultiple(constantOp, {repl});
196197

197198
return success();
198199
}
@@ -201,12 +202,13 @@ struct LegalizeArithConstantOpsByDecomposition
201202
/// Legalize `vector.outerproduct` operations to fit within SME tiles by
202203
/// decomposing them into tile-sized operations.
203204
struct LegalizeVectorOuterProductOpsByDecomposition
204-
: public OneToNOpConversionPattern<vector::OuterProductOp> {
205-
using OneToNOpConversionPattern::OneToNOpConversionPattern;
205+
: public OpConversionPattern<vector::OuterProductOp> {
206+
using OpConversionPattern::OpConversionPattern;
206207

207208
LogicalResult
208-
matchAndRewrite(vector::OuterProductOp outerProductOp, OpAdaptor adaptor,
209-
OneToNPatternRewriter &rewriter) const override {
209+
matchAndRewrite(vector::OuterProductOp outerProductOp,
210+
OneToNOpAdaptor adaptor,
211+
ConversionPatternRewriter &rewriter) const override {
210212
auto vectorType = outerProductOp.getResultVectorType();
211213
if (!isMultipleOfSMETileVectorType(vectorType))
212214
return rewriter.notifyMatchFailure(outerProductOp,
@@ -219,6 +221,7 @@ struct LegalizeVectorOuterProductOpsByDecomposition
219221
auto maskOp = outerProductOp.getMaskingOp();
220222
mask = maskOp.getMask();
221223
rootOp = maskOp;
224+
rewriter.setInsertionPoint(rootOp);
222225
}
223226

224227
if (!isSupportedMaskOp(mask))
@@ -248,7 +251,7 @@ struct LegalizeVectorOuterProductOpsByDecomposition
248251
resultSMETiles.push_back(maskedOuterProduct->getResult(0));
249252
}
250253

251-
rewriter.replaceOp(rootOp, resultSMETiles, adaptor.getResultMapping());
254+
rewriter.replaceOpWithMultiple(rootOp, {resultSMETiles});
252255
return success();
253256
}
254257
};
@@ -259,12 +262,12 @@ struct LegalizeVectorOuterProductOpsByDecomposition
259262
// (invalid). This pattern matches on `vector.mask` then calls into the
260263
// `vector.outerproduct` pattern to work around this issue.
261264
struct LegalizeMaskedVectorOuterProductOpsByDecomposition
262-
: public OneToNOpConversionPattern<vector::MaskOp> {
263-
using OneToNOpConversionPattern::OneToNOpConversionPattern;
265+
: public OpConversionPattern<vector::MaskOp> {
266+
using OpConversionPattern::OpConversionPattern;
264267

265268
LogicalResult
266-
matchAndRewrite(vector::MaskOp maskOp, OpAdaptor adaptor,
267-
OneToNPatternRewriter &rewriter) const override {
269+
matchAndRewrite(vector::MaskOp maskOp, OneToNOpAdaptor adaptor,
270+
ConversionPatternRewriter &rewriter) const override {
268271
if (auto outerProductOp = llvm::dyn_cast_or_null<vector::OuterProductOp>(
269272
maskOp.getMaskableOp())) {
270273
LegalizeVectorOuterProductOpsByDecomposition pattern(*getTypeConverter(),
@@ -279,12 +282,12 @@ struct LegalizeMaskedVectorOuterProductOpsByDecomposition
279282
/// Legalize `vector.transfer_read` operations to fit within SME tiles by
280283
/// decomposing them into tile-sized operations.
281284
struct LegalizeTransferReadOpsByDecomposition
282-
: public OneToNOpConversionPattern<vector::TransferReadOp> {
283-
using OneToNOpConversionPattern::OneToNOpConversionPattern;
285+
: public OpConversionPattern<vector::TransferReadOp> {
286+
using OpConversionPattern::OpConversionPattern;
284287

285288
LogicalResult
286-
matchAndRewrite(vector::TransferReadOp readOp, OpAdaptor adaptor,
287-
OneToNPatternRewriter &rewriter) const override {
289+
matchAndRewrite(vector::TransferReadOp readOp, OneToNOpAdaptor adaptor,
290+
ConversionPatternRewriter &rewriter) const override {
288291
auto vectorType = readOp.getVectorType();
289292
if (!isMultipleOfSMETileVectorType(vectorType))
290293
return rewriter.notifyMatchFailure(readOp,
@@ -319,20 +322,20 @@ struct LegalizeTransferReadOpsByDecomposition
319322
resultSMETiles.push_back(smeRead);
320323
}
321324

322-
rewriter.replaceOp(readOp, resultSMETiles, adaptor.getResultMapping());
325+
rewriter.replaceOpWithMultiple(readOp, {resultSMETiles});
323326
return success();
324327
}
325328
};
326329

327330
/// Legalize `vector.transfer_write` operations to fit within SME tiles by
328331
/// decomposing them into tile-sized operations.
329332
struct LegalizeTransferWriteOpsByDecomposition
330-
: public OneToNOpConversionPattern<vector::TransferWriteOp> {
331-
using OneToNOpConversionPattern::OneToNOpConversionPattern;
333+
: public OpConversionPattern<vector::TransferWriteOp> {
334+
using OpConversionPattern::OpConversionPattern;
332335

333336
LogicalResult
334-
matchAndRewrite(vector::TransferWriteOp writeOp, OpAdaptor adaptor,
335-
OneToNPatternRewriter &rewriter) const override {
337+
matchAndRewrite(vector::TransferWriteOp writeOp, OneToNOpAdaptor adaptor,
338+
ConversionPatternRewriter &rewriter) const override {
336339
auto vectorType = writeOp.getVectorType();
337340
if (!isMultipleOfSMETileVectorType(vectorType))
338341
return rewriter.notifyMatchFailure(writeOp,
@@ -409,12 +412,12 @@ struct LegalizeTransferWriteOpsByDecomposition
409412
/// }
410413
/// ```
411414
struct LegalizeMultiTileTransferWriteAsStoreLoop
412-
: public OneToNOpConversionPattern<vector::TransferWriteOp> {
413-
using OneToNOpConversionPattern::OneToNOpConversionPattern;
415+
: public OpConversionPattern<vector::TransferWriteOp> {
416+
using OpConversionPattern::OpConversionPattern;
414417

415418
LogicalResult
416-
matchAndRewrite(vector::TransferWriteOp writeOp, OpAdaptor adaptor,
417-
OneToNPatternRewriter &rewriter) const override {
419+
matchAndRewrite(vector::TransferWriteOp writeOp, OneToNOpAdaptor adaptor,
420+
ConversionPatternRewriter &rewriter) const override {
418421
if (writeOp.hasPureTensorSemantics())
419422
return rewriter.notifyMatchFailure(
420423
writeOp, "TODO: tensor semantics are unsupported");
@@ -936,10 +939,16 @@ struct VectorLegalizationPass
936939
return success();
937940
});
938941

939-
patterns.add<FoldExtractFromVectorOfSMELikeCreateMasks,
940-
LiftIllegalVectorTransposeToMemory,
941-
ConvertIllegalShapeCastOpsToTransposes,
942-
LowerIllegalTransposeStoreViaZA>(context);
942+
// Apply preprocessing patterns.
943+
RewritePatternSet rewritePatterns(context);
944+
rewritePatterns.add<FoldExtractFromVectorOfSMELikeCreateMasks,
945+
LiftIllegalVectorTransposeToMemory,
946+
ConvertIllegalShapeCastOpsToTransposes,
947+
LowerIllegalTransposeStoreViaZA>(context);
948+
if (failed(
949+
applyPatternsGreedily(getOperation(), std::move(rewritePatterns))))
950+
return signalPassFailure();
951+
943952
// Note: These two patterns are added with a high benefit to ensure:
944953
// - Masked outer products are handled before unmasked ones
945954
// - Multi-tile writes are lowered as a store loop (if possible)
@@ -950,11 +959,20 @@ struct VectorLegalizationPass
950959
LegalizeVectorOuterProductOpsByDecomposition,
951960
LegalizeTransferReadOpsByDecomposition,
952961
LegalizeTransferWriteOpsByDecomposition>(converter, context);
953-
populateFuncTypeConversionPatterns(converter, patterns);
954-
scf::populateSCFStructuralOneToNTypeConversions(converter, patterns);
955-
956-
if (failed(applyPartialOneToNConversion(getOperation(), converter,
957-
std::move(patterns))))
962+
populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
963+
converter);
964+
populateCallOpTypeConversionPattern(patterns, converter);
965+
populateReturnOpTypeConversionPattern(patterns, converter);
966+
scf::populateSCFStructuralTypeConversions(converter, patterns);
967+
968+
ConversionTarget target(getContext());
969+
target.markUnknownOpDynamicallyLegal(
970+
[&](Operation *op) { return converter.isLegal(op); });
971+
target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
972+
return converter.isSignatureLegal(op.getFunctionType());
973+
});
974+
if (failed(applyPartialConversion(getOperation(), target,
975+
std::move(patterns))))
958976
return signalPassFailure();
959977
}
960978
};

0 commit comments

Comments
 (0)