@@ -46,6 +46,8 @@ static constexpr StringLiteral kMatchFailureUnsupportedMaskOp(
46
46
" op mask is unsupported for legalization/decomposition" );
47
47
static constexpr StringLiteral
48
48
kMatchFailureNonPermutationMap (" op affine map is not a permutation" );
49
+ static constexpr StringLiteral kMatchFailureNotIllegalToLegal (
50
+ " expected transpose from illegal type to legal type" );
49
51
50
52
// / An SMESubTile represents a single SME-sized sub-tile from decomposing a
51
53
// / larger vector type. The (`row`, `col`) are the position of the tile in the
@@ -416,6 +418,17 @@ struct FoldExtractFromVectorOfSMELikeCreateMasks
416
418
}
417
419
};
418
420
421
+ // / A vector type where no fixed dimension comes after a scalable dimension.
422
+ bool isLegalVectorType (VectorType vType) {
423
+ bool seenFixedDim = false ;
424
+ for (bool scalableFlag : llvm::reverse (vType.getScalableDims ())) {
425
+ seenFixedDim |= !scalableFlag;
426
+ if (seenFixedDim && scalableFlag)
427
+ return false ;
428
+ }
429
+ return true ;
430
+ }
431
+
419
432
// / Lifts an illegal vector.transpose and vector.transfer_read to a
420
433
// / memref.subview + memref.transpose, followed by a legal read.
421
434
// /
@@ -448,16 +461,6 @@ struct LiftIllegalVectorTransposeToMemory
448
461
: public OpRewritePattern<vector::TransposeOp> {
449
462
using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
450
463
451
- static bool isIllegalVectorType (VectorType vType) {
452
- bool seenFixedDim = false ;
453
- for (bool scalableFlag : llvm::reverse (vType.getScalableDims ())) {
454
- seenFixedDim |= !scalableFlag;
455
- if (seenFixedDim && scalableFlag)
456
- return true ;
457
- }
458
- return false ;
459
- }
460
-
461
464
static Value getExtensionSource (Operation *op) {
462
465
if (isa_and_present<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp>(op))
463
466
return op->getOperand (0 );
@@ -468,9 +471,9 @@ struct LiftIllegalVectorTransposeToMemory
468
471
PatternRewriter &rewriter) const override {
469
472
auto sourceType = transposeOp.getSourceVectorType ();
470
473
auto resultType = transposeOp.getResultVectorType ();
471
- if (! isIllegalVectorType (sourceType) || isIllegalVectorType (resultType))
472
- return rewriter.notifyMatchFailure (
473
- transposeOp, " expected transpose from illegal type to legal type " );
474
+ if (isLegalVectorType (sourceType) || ! isLegalVectorType (resultType))
475
+ return rewriter.notifyMatchFailure (transposeOp,
476
+ kMatchFailureNotIllegalToLegal );
474
477
475
478
// Look through extend for transfer_read.
476
479
Value maybeRead = transposeOp.getVector ();
@@ -556,6 +559,59 @@ struct LiftIllegalVectorTransposeToMemory
556
559
}
557
560
};
558
561
562
+ // / A rewrite to turn unit dim transpose-like vector.shape_casts into
563
+ // / vector.transposes. The shape_cast has to be from an illegal vector type to a
564
+ // / legal one (as defined by isLegalVectorType).
565
+ // /
566
+ // / The reasoning for this is if we've got to this pass and we still have
567
+ // / shape_casts of illegal types, then they likely will not cancel out. Turning
568
+ // / them into transposes gives LiftIllegalVectorTransposeToMemory a chance to
569
+ // / eliminate them.
570
+ // /
571
+ // / Example:
572
+ // /
573
+ // / BEFORE:
574
+ // / ```mlir
575
+ // / %0 = vector.shape_cast %a : vector<[4]x1xf32> to vector<1x[4]xf32>
576
+ // / ```
577
+ // /
578
+ // / AFTER:
579
+ // / ```mlir
580
+ // / %0 = vector.transpose %0, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
581
+ // / ```
582
+ struct ConvertIllegalShapeCastOpsToTransposes
583
+ : public OpRewritePattern<vector::ShapeCastOp> {
584
+ using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
585
+
586
+ LogicalResult matchAndRewrite (vector::ShapeCastOp shapeCastOp,
587
+ PatternRewriter &rewriter) const override {
588
+ auto sourceType = shapeCastOp.getSourceVectorType ();
589
+ auto resultType = shapeCastOp.getResultVectorType ();
590
+ if (isLegalVectorType (sourceType) || !isLegalVectorType (resultType))
591
+ return rewriter.notifyMatchFailure (shapeCastOp,
592
+ kMatchFailureNotIllegalToLegal );
593
+
594
+ // Note: If we know that `sourceType` is an illegal vector type (and 2D)
595
+ // then dim 0 is scalable and dim 1 is fixed.
596
+ if (sourceType.getRank () != 2 || sourceType.getDimSize (1 ) != 1 )
597
+ return rewriter.notifyMatchFailure (
598
+ shapeCastOp, " expected source to be a 2D scalable vector with a "
599
+ " trailing unit dim" );
600
+
601
+ auto loc = shapeCastOp.getLoc ();
602
+ auto transpose = rewriter.create <vector::TransposeOp>(
603
+ loc, shapeCastOp.getSource (), ArrayRef<int64_t >{1 , 0 });
604
+
605
+ if (resultType.getRank () == 1 )
606
+ rewriter.replaceOpWithNewOp <vector::ShapeCastOp>(shapeCastOp, resultType,
607
+ transpose);
608
+ else
609
+ rewriter.replaceOp (shapeCastOp, transpose);
610
+
611
+ return success ();
612
+ }
613
+ };
614
+
559
615
struct VectorLegalizationPass
560
616
: public arm_sme::impl::VectorLegalizationBase<VectorLegalizationPass> {
561
617
void runOnOperation () override {
@@ -576,7 +632,8 @@ struct VectorLegalizationPass
576
632
});
577
633
578
634
patterns.add <FoldExtractFromVectorOfSMELikeCreateMasks,
579
- LiftIllegalVectorTransposeToMemory>(context);
635
+ LiftIllegalVectorTransposeToMemory,
636
+ ConvertIllegalShapeCastOpsToTransposes>(context);
580
637
// Note: High benefit to ensure masked outer products are lowered first.
581
638
patterns.add <LegalizeMaskedVectorOuterProductOpsByDecomposition>(
582
639
converter, context, 1024 );
0 commit comments