7
7
// ===----------------------------------------------------------------------===//
8
8
//
9
9
// This pass legalizes vector operations so they can be lowered to ArmSME.
10
- // Currently, this only implements the decomposition of vector operations that
11
- // use vector sizes larger than an SME tile, into multiple SME-sized operations.
12
10
//
13
11
// Note: In the context of this pass 'tile' always refers to an SME tile.
14
12
//
15
13
// ===----------------------------------------------------------------------===//
16
14
15
+ #include " mlir/Dialect/Arith/Utils/Utils.h"
17
16
#include " mlir/Dialect/ArmSME/IR/ArmSME.h"
18
17
#include " mlir/Dialect/ArmSME/Transforms/Passes.h"
19
18
#include " mlir/Dialect/ArmSME/Utils/Utils.h"
@@ -35,6 +34,10 @@ using namespace mlir::arm_sme;
35
34
36
35
namespace {
37
36
37
+ // ===----------------------------------------------------------------------===//
38
+ // Decomposition of vector operations larger than an SME tile
39
+ // ===----------------------------------------------------------------------===//
40
+
38
41
// Common match failure reasons.
39
42
// Match-failure reason text for ops whose vector size is not a multiple of
// SME tiles (presumably reported by the decomposition patterns in this file
// -- the use sites are outside this view).
static constexpr StringLiteral MATCH_FAILURE_NOT_SME_TILE_TYPE_MULTIPLE (
40
43
" op vector size is not multiple of SME tiles" );
@@ -338,13 +341,86 @@ struct LegalizeTransferWriteOpsByDecomposition
338
341
}
339
342
};
340
343
344
+ // ===----------------------------------------------------------------------===//
345
+ // ArmSME-specific fixup canonicalizations/folds
346
+ // ===----------------------------------------------------------------------===//
347
+
348
+ // / Folds an extract from a 3D `vector.create_mask` (which is a vector of
349
+ // / SME-like masks), into a compare and a 2D `vector.create_mask`. This is
350
+ // / necessary for the mask to be lowered to ArmSME.
351
+ // /
352
+ // / Example:
353
+ // /
354
+ // / BEFORE:
355
+ // / ```mlir
356
+ // / %mask = vector.create_mask %nonConstantDim, %a, %b : vector<4x[4]x[4]xi1>
357
+ // / %subMask = vector.extract %mask[2]
358
+ // / : vector<[4]x[4]xi1> from vector<4x[4]x[4]xi1>
359
+ // / ```
360
+ // /
361
+ // / AFTER:
362
+ // / ```mlir
363
+ // / %extractionInTrueRegion = arith.cmpi slt, %c2, %nonConstantDim : index
364
+ // / %newMaskFrontDim = arith.select %extractionInTrueRegion, %a, %c0 : index
365
+ // / %subMask = vector.create_mask %newMaskFrontDim, %b : vector<[4]x[4]xi1>
366
+ // / ```
367
+ struct FoldExtractFromVectorOfSMELikeCreateMasks
368
+ : public OpRewritePattern<vector::ExtractOp> {
369
+ using OpRewritePattern<vector::ExtractOp>::OpRewritePattern;
370
+
371
+ LogicalResult matchAndRewrite (vector::ExtractOp extractOp,
372
+ PatternRewriter &rewriter) const override {
373
+ auto loc = extractOp.getLoc ();
374
+ auto createMaskOp =
375
+ extractOp.getVector ().getDefiningOp <vector::CreateMaskOp>();
376
+ if (!createMaskOp)
377
+ return rewriter.notifyMatchFailure (
378
+ extractOp, " extract not from vector.create_mask op" );
379
+
380
+ VectorType extractedMaskType =
381
+ llvm::dyn_cast<VectorType>(extractOp.getResult ().getType ());
382
+ if (!extractedMaskType)
383
+ return rewriter.notifyMatchFailure (extractOp,
384
+ " extracted type is not a vector type" );
385
+
386
+ auto numScalable = llvm::count (extractedMaskType.getScalableDims (), true );
387
+ if (numScalable != 2 )
388
+ return rewriter.notifyMatchFailure (
389
+ extractOp, " expected extracted type to be an SME-like mask" );
390
+
391
+ // TODO: Support multiple extraction indices.
392
+ if (extractOp.getStaticPosition ().size () != 1 )
393
+ return rewriter.notifyMatchFailure (
394
+ extractOp, " only a single extraction index is supported" );
395
+
396
+ auto frontMaskDim = createMaskOp.getOperand (0 );
397
+ if (frontMaskDim.getDefiningOp <arith::ConstantOp>())
398
+ return rewriter.notifyMatchFailure (
399
+ extractOp,
400
+ " constant vector.create_masks dims should be folded elsewhere" );
401
+
402
+ auto zero = rewriter.create <arith::ConstantIndexOp>(loc, 0 );
403
+ auto extractionIndex = getValueOrCreateConstantIndexOp (
404
+ rewriter, loc, extractOp.getMixedPosition ()[0 ]);
405
+ auto extractionInTrueRegion = rewriter.create <arith::CmpIOp>(
406
+ loc, rewriter.getI1Type (), arith::CmpIPredicate::slt, extractionIndex,
407
+ frontMaskDim);
408
+ auto newMaskFrontDim = rewriter.create <arith::SelectOp>(
409
+ loc, extractionInTrueRegion, createMaskOp.getOperand (1 ), zero);
410
+
411
+ rewriter.replaceOpWithNewOp <vector::CreateMaskOp>(
412
+ extractOp, extractedMaskType,
413
+ ValueRange{newMaskFrontDim, createMaskOp.getOperand (2 )});
414
+ return success ();
415
+ }
416
+ };
417
+
341
418
struct VectorLegalizationPass
342
419
: public arm_sme::impl::VectorLegalizationBase<VectorLegalizationPass> {
343
420
void runOnOperation () override {
344
421
auto *context = &getContext ();
345
422
OneToNTypeConverter converter;
346
423
RewritePatternSet patterns (context);
347
-
348
424
converter.addConversion ([](Type type) { return type; });
349
425
converter.addConversion (
350
426
[](VectorType vectorType,
@@ -358,6 +434,7 @@ struct VectorLegalizationPass
358
434
return success ();
359
435
});
360
436
437
+ patterns.add <FoldExtractFromVectorOfSMELikeCreateMasks>(context);
361
438
// Note: High benefit to ensure masked outer products are lowered first.
362
439
patterns.add <LegalizeMaskedVectorOuterProductOpsByDecomposition>(
363
440
converter, context, 1024 );
0 commit comments