17
17
#include " mlir/Dialect/ArmSME/Transforms/Passes.h"
18
18
#include " mlir/Dialect/ArmSME/Utils/Utils.h"
19
19
#include " mlir/Dialect/Func/IR/FuncOps.h"
20
- #include " mlir/Dialect/Func/Transforms/OneToNFuncConversions .h"
20
+ #include " mlir/Dialect/Func/Transforms/FuncConversions .h"
21
21
#include " mlir/Dialect/Index/IR/IndexDialect.h"
22
22
#include " mlir/Dialect/Index/IR/IndexOps.h"
23
23
#include " mlir/Dialect/MemRef/IR/MemRef.h"
24
24
#include " mlir/Dialect/SCF/IR/SCF.h"
25
25
#include " mlir/Dialect/SCF/Transforms/Patterns.h"
26
26
#include " mlir/Dialect/Utils/IndexingUtils.h"
27
27
#include " mlir/Dialect/Vector/Utils/VectorUtils.h"
28
- #include " mlir/Transforms/OneToNTypeConversion.h"
28
+ #include " mlir/Transforms/DialectConversion.h"
29
+ #include " mlir/Transforms/GreedyPatternRewriteDriver.h"
29
30
30
31
#define DEBUG_TYPE " arm-sme-vector-legalization"
31
32
@@ -172,12 +173,12 @@ int getNumberOfSMETilesForVectorType(VectorType type) {
172
173
// / Legalize `arith.constant dense<value>` splat operations to fit within SME
173
174
// / tiles by decomposing them into tile-sized operations.
174
175
struct LegalizeArithConstantOpsByDecomposition
175
- : public OneToNOpConversionPattern <arith::ConstantOp> {
176
- using OneToNOpConversionPattern::OneToNOpConversionPattern ;
176
+ : public OpConversionPattern <arith::ConstantOp> {
177
+ using OpConversionPattern::OpConversionPattern ;
177
178
178
179
LogicalResult
179
180
matchAndRewrite (arith::ConstantOp constantOp, OpAdaptor adaptor,
180
- OneToNPatternRewriter &rewriter) const override {
181
+ ConversionPatternRewriter &rewriter) const override {
181
182
auto vectorType = dyn_cast<VectorType>(constantOp.getType ());
182
183
auto denseAttr = dyn_cast<DenseElementsAttr>(constantOp.getValueAttr ());
183
184
if (!vectorType || !denseAttr || !denseAttr.isSplat ())
@@ -191,8 +192,8 @@ struct LegalizeArithConstantOpsByDecomposition
191
192
auto tileCount = getNumberOfSMETilesForVectorType (vectorType);
192
193
auto tileSplat = rewriter.create <arith::ConstantOp>(
193
194
constantOp.getLoc (), denseAttr.resizeSplat (smeTileType));
194
- rewriter. replaceOp (constantOp, SmallVector<Value>(tileCount, tileSplat),
195
- adaptor. getResultMapping () );
195
+ SmallVector<Value> repl (tileCount, tileSplat);
196
+ rewriter. replaceOpWithMultiple (constantOp, {repl} );
196
197
197
198
return success ();
198
199
}
@@ -201,12 +202,13 @@ struct LegalizeArithConstantOpsByDecomposition
201
202
// / Legalize `vector.outerproduct` operations to fit within SME tiles by
202
203
// / decomposing them into tile-sized operations.
203
204
struct LegalizeVectorOuterProductOpsByDecomposition
204
- : public OneToNOpConversionPattern <vector::OuterProductOp> {
205
- using OneToNOpConversionPattern::OneToNOpConversionPattern ;
205
+ : public OpConversionPattern <vector::OuterProductOp> {
206
+ using OpConversionPattern::OpConversionPattern ;
206
207
207
208
LogicalResult
208
- matchAndRewrite (vector::OuterProductOp outerProductOp, OpAdaptor adaptor,
209
- OneToNPatternRewriter &rewriter) const override {
209
+ matchAndRewrite (vector::OuterProductOp outerProductOp,
210
+ OneToNOpAdaptor adaptor,
211
+ ConversionPatternRewriter &rewriter) const override {
210
212
auto vectorType = outerProductOp.getResultVectorType ();
211
213
if (!isMultipleOfSMETileVectorType (vectorType))
212
214
return rewriter.notifyMatchFailure (outerProductOp,
@@ -219,6 +221,7 @@ struct LegalizeVectorOuterProductOpsByDecomposition
219
221
auto maskOp = outerProductOp.getMaskingOp ();
220
222
mask = maskOp.getMask ();
221
223
rootOp = maskOp;
224
+ rewriter.setInsertionPoint (rootOp);
222
225
}
223
226
224
227
if (!isSupportedMaskOp (mask))
@@ -248,7 +251,7 @@ struct LegalizeVectorOuterProductOpsByDecomposition
248
251
resultSMETiles.push_back (maskedOuterProduct->getResult (0 ));
249
252
}
250
253
251
- rewriter.replaceOp (rootOp, resultSMETiles, adaptor. getResultMapping () );
254
+ rewriter.replaceOpWithMultiple (rootOp, { resultSMETiles} );
252
255
return success ();
253
256
}
254
257
};
@@ -259,12 +262,12 @@ struct LegalizeVectorOuterProductOpsByDecomposition
259
262
// (invalid). This pattern matches on `vector.mask` then calls into the
260
263
// `vector.outerproduct` pattern to work around this issue.
261
264
struct LegalizeMaskedVectorOuterProductOpsByDecomposition
262
- : public OneToNOpConversionPattern <vector::MaskOp> {
263
- using OneToNOpConversionPattern::OneToNOpConversionPattern ;
265
+ : public OpConversionPattern <vector::MaskOp> {
266
+ using OpConversionPattern::OpConversionPattern ;
264
267
265
268
LogicalResult
266
- matchAndRewrite (vector::MaskOp maskOp, OpAdaptor adaptor,
267
- OneToNPatternRewriter &rewriter) const override {
269
+ matchAndRewrite (vector::MaskOp maskOp, OneToNOpAdaptor adaptor,
270
+ ConversionPatternRewriter &rewriter) const override {
268
271
if (auto outerProductOp = llvm::dyn_cast_or_null<vector::OuterProductOp>(
269
272
maskOp.getMaskableOp ())) {
270
273
LegalizeVectorOuterProductOpsByDecomposition pattern (*getTypeConverter (),
@@ -279,12 +282,12 @@ struct LegalizeMaskedVectorOuterProductOpsByDecomposition
279
282
// / Legalize `vector.transfer_read` operations to fit within SME tiles by
280
283
// / decomposing them into tile-sized operations.
281
284
struct LegalizeTransferReadOpsByDecomposition
282
- : public OneToNOpConversionPattern <vector::TransferReadOp> {
283
- using OneToNOpConversionPattern::OneToNOpConversionPattern ;
285
+ : public OpConversionPattern <vector::TransferReadOp> {
286
+ using OpConversionPattern::OpConversionPattern ;
284
287
285
288
LogicalResult
286
- matchAndRewrite (vector::TransferReadOp readOp, OpAdaptor adaptor,
287
- OneToNPatternRewriter &rewriter) const override {
289
+ matchAndRewrite (vector::TransferReadOp readOp, OneToNOpAdaptor adaptor,
290
+ ConversionPatternRewriter &rewriter) const override {
288
291
auto vectorType = readOp.getVectorType ();
289
292
if (!isMultipleOfSMETileVectorType (vectorType))
290
293
return rewriter.notifyMatchFailure (readOp,
@@ -319,20 +322,20 @@ struct LegalizeTransferReadOpsByDecomposition
319
322
resultSMETiles.push_back (smeRead);
320
323
}
321
324
322
- rewriter.replaceOp (readOp, resultSMETiles, adaptor. getResultMapping () );
325
+ rewriter.replaceOpWithMultiple (readOp, { resultSMETiles} );
323
326
return success ();
324
327
}
325
328
};
326
329
327
330
// / Legalize `vector.transfer_write` operations to fit within SME tiles by
328
331
// / decomposing them into tile-sized operations.
329
332
struct LegalizeTransferWriteOpsByDecomposition
330
- : public OneToNOpConversionPattern <vector::TransferWriteOp> {
331
- using OneToNOpConversionPattern::OneToNOpConversionPattern ;
333
+ : public OpConversionPattern <vector::TransferWriteOp> {
334
+ using OpConversionPattern::OpConversionPattern ;
332
335
333
336
LogicalResult
334
- matchAndRewrite (vector::TransferWriteOp writeOp, OpAdaptor adaptor,
335
- OneToNPatternRewriter &rewriter) const override {
337
+ matchAndRewrite (vector::TransferWriteOp writeOp, OneToNOpAdaptor adaptor,
338
+ ConversionPatternRewriter &rewriter) const override {
336
339
auto vectorType = writeOp.getVectorType ();
337
340
if (!isMultipleOfSMETileVectorType (vectorType))
338
341
return rewriter.notifyMatchFailure (writeOp,
@@ -409,12 +412,12 @@ struct LegalizeTransferWriteOpsByDecomposition
409
412
// / }
410
413
// / ```
411
414
struct LegalizeMultiTileTransferWriteAsStoreLoop
412
- : public OneToNOpConversionPattern <vector::TransferWriteOp> {
413
- using OneToNOpConversionPattern::OneToNOpConversionPattern ;
415
+ : public OpConversionPattern <vector::TransferWriteOp> {
416
+ using OpConversionPattern::OpConversionPattern ;
414
417
415
418
LogicalResult
416
- matchAndRewrite (vector::TransferWriteOp writeOp, OpAdaptor adaptor,
417
- OneToNPatternRewriter &rewriter) const override {
419
+ matchAndRewrite (vector::TransferWriteOp writeOp, OneToNOpAdaptor adaptor,
420
+ ConversionPatternRewriter &rewriter) const override {
418
421
if (writeOp.hasPureTensorSemantics ())
419
422
return rewriter.notifyMatchFailure (
420
423
writeOp, " TODO: tensor semantics are unsupported" );
@@ -936,10 +939,16 @@ struct VectorLegalizationPass
936
939
return success ();
937
940
});
938
941
939
- patterns.add <FoldExtractFromVectorOfSMELikeCreateMasks,
940
- LiftIllegalVectorTransposeToMemory,
941
- ConvertIllegalShapeCastOpsToTransposes,
942
- LowerIllegalTransposeStoreViaZA>(context);
942
+ // Apply preprocessing patterns.
943
+ RewritePatternSet rewritePatterns (context);
944
+ rewritePatterns.add <FoldExtractFromVectorOfSMELikeCreateMasks,
945
+ LiftIllegalVectorTransposeToMemory,
946
+ ConvertIllegalShapeCastOpsToTransposes,
947
+ LowerIllegalTransposeStoreViaZA>(context);
948
+ if (failed (
949
+ applyPatternsGreedily (getOperation (), std::move (rewritePatterns))))
950
+ return signalPassFailure ();
951
+
943
952
// Note: These two patterns are added with a high benefit to ensure:
944
953
// - Masked outer products are handled before unmasked ones
945
954
// - Multi-tile writes are lowered as a store loop (if possible)
@@ -950,11 +959,20 @@ struct VectorLegalizationPass
950
959
LegalizeVectorOuterProductOpsByDecomposition,
951
960
LegalizeTransferReadOpsByDecomposition,
952
961
LegalizeTransferWriteOpsByDecomposition>(converter, context);
953
- populateFuncTypeConversionPatterns (converter, patterns);
954
- scf::populateSCFStructuralOneToNTypeConversions (converter, patterns);
955
-
956
- if (failed (applyPartialOneToNConversion (getOperation (), converter,
957
- std::move (patterns))))
962
+ populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
963
+ converter);
964
+ populateCallOpTypeConversionPattern (patterns, converter);
965
+ populateReturnOpTypeConversionPattern (patterns, converter);
966
+ scf::populateSCFStructuralTypeConversions (converter, patterns);
967
+
968
+ ConversionTarget target (getContext ());
969
+ target.markUnknownOpDynamicallyLegal (
970
+ [&](Operation *op) { return converter.isLegal (op); });
971
+ target.addDynamicallyLegalOp <func::FuncOp>([&](func::FuncOp op) {
972
+ return converter.isSignatureLegal (op.getFunctionType ());
973
+ });
974
+ if (failed (applyPartialConversion (getOperation (), target,
975
+ std::move (patterns))))
958
976
return signalPassFailure ();
959
977
}
960
978
};
0 commit comments