[mlir][ArmSME] Remove ConvertIllegalShapeCastOpsToTransposes (#139706)
As a follow-up to PR #135841 (see discussion for background), this patch removes the `ConvertIllegalShapeCastOpsToTransposes` pattern from the SME legalization pass. This change unblocks folding for `ShapeCastOp` involving scalable vectors.

Originally, the `ConvertIllegalShapeCastOpsToTransposes` pattern was introduced to rewrite certain `vector.shape_cast` ops that could not be lowered otherwise. Based on local end-to-end testing, this workaround is no longer required, and the pattern can now be safely removed. This patch also removes a special case from `ShapeCastOp::fold`, simplifying the fold logic.

As a side effect of removing `ConvertIllegalShapeCastOpsToTransposes`, we lose the mechanism that enabled lowering of certain ops like:

```mlir
%res = vector.transfer_read %mem[%a, %b] (...) : memref<?x?xf32>, vector<[4]x1xf32>
```

Previously, such cases were handled by:

* rewriting a nearby `vector.shape_cast` to a `vector.transpose` (via `ConvertIllegalShapeCastOpsToTransposes`), and
* then lowering the result with `LiftIllegalVectorTransposeToMemory`.

This patch introduces a new dedicated pattern, `LowerColumnTransferReadToLoops`, that directly handles illegal `vector.transfer_read` ops involving leading scalable dimensions.
1 parent c08502d · commit c92b580
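For illustration, here is the kind of fold this unblocks, as a minimal sketch (`%x` and the exact shapes are hypothetical): an order-preserving `vector.transpose` feeding a `vector.shape_cast` can now fold away even when the vectors are scalable.

```mlir
// Hypothetical input: the [1, 0] transpose only moves a unit dim, so it is
// order-preserving even though the vector type is scalable.
%t = vector.transpose %x, [1, 0] : vector<1x[4]xf32> to vector<[4]x1xf32>
%r = vector.shape_cast %t : vector<[4]x1xf32> to vector<[4]xf32>

// After ShapeCastOp::fold, the shape_cast consumes %x directly:
%r = vector.shape_cast %x : vector<1x[4]xf32> to vector<[4]xf32>
```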

File tree

4 files changed: +207 −142 lines changed

mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp

Lines changed: 114 additions & 57 deletions
````diff
@@ -724,59 +724,6 @@ struct LiftIllegalVectorTransposeToMemory
   }
 };
 
-/// A rewrite to turn unit dim transpose-like vector.shape_casts into
-/// vector.transposes. The shape_cast has to be from an illegal vector type to a
-/// legal one (as defined by isLegalVectorType).
-///
-/// The reasoning for this is if we've got to this pass and we still have
-/// shape_casts of illegal types, then they likely will not cancel out. Turning
-/// them into transposes gives LiftIllegalVectorTransposeToMemory a chance to
-/// eliminate them.
-///
-/// Example:
-///
-/// BEFORE:
-/// ```mlir
-/// %0 = vector.shape_cast %a : vector<[4]x1xf32> to vector<1x[4]xf32>
-/// ```
-///
-/// AFTER:
-/// ```mlir
-/// %0 = vector.transpose %0, [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
-/// ```
-struct ConvertIllegalShapeCastOpsToTransposes
-    : public OpRewritePattern<vector::ShapeCastOp> {
-  using OpRewritePattern<vector::ShapeCastOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
-                                PatternRewriter &rewriter) const override {
-    auto sourceType = shapeCastOp.getSourceVectorType();
-    auto resultType = shapeCastOp.getResultVectorType();
-    if (isLegalVectorType(sourceType) || !isLegalVectorType(resultType))
-      return rewriter.notifyMatchFailure(shapeCastOp,
-                                         kMatchFailureNotIllegalToLegal);
-
-    // Note: If we know that `sourceType` is an illegal vector type (and 2D)
-    // then dim 0 is scalable and dim 1 is fixed.
-    if (sourceType.getRank() != 2 || sourceType.getDimSize(1) != 1)
-      return rewriter.notifyMatchFailure(
-          shapeCastOp, "expected source to be a 2D scalable vector with a "
-                       "trailing unit dim");
-
-    auto loc = shapeCastOp.getLoc();
-    auto transpose = rewriter.create<vector::TransposeOp>(
-        loc, shapeCastOp.getSource(), ArrayRef<int64_t>{1, 0});
-
-    if (resultType.getRank() == 1)
-      rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(shapeCastOp, resultType,
-                                                       transpose);
-    else
-      rewriter.replaceOp(shapeCastOp, transpose);
-
-    return success();
-  }
-};
-
 /// Rewrites an illegal/unsupported SVE transfer_write(transpose) to instead use
 /// the ZA state. This workaround rewrite to support these transposes when ZA is
 /// available.
````
````diff
@@ -920,6 +867,116 @@ struct LowerIllegalTransposeStoreViaZA
   }
 };
 
+/// Lower `vector.transfer_read` of a scalable column to `scf::for`
+///
+/// Lowers a "read" of a scalable column from a MemRef, for which there is no
+/// hardware operation that we could use, to a loop over the rows that loads
+/// one element at a time.
+///
+/// BEFORE:
+/// ```
+/// %res = vector.transfer_read %mem[%a, %b] (...)
+///   : memref<?x?xf32>, vector<[4]x1xf32>
+/// ```
+///
+/// AFTER:
+/// ```
+/// %cst = arith.constant (...) : vector<[4]xf32>
+/// %vscale = vector.vscale
+/// %c4_vscale = arith.muli %vscale, %c4 : index
+/// %scf = scf.for %lb = %c0 to %c4_vscale step %c1 iter_args(%arg4 = %cst)
+///   -> (vector<[4]xf32>) {
+///
+///   %load = memref.load %mem[%arg3 + %a, %b] : memref<?x?xf32>
+///   %vec = vector.insert %load, %cst [%arg3] : f32 into vector<[4]xf32>
+///   scf.yield %vec : vector<[4]xf32>
+/// }
+/// %res = vector.shape_cast %scf : vector<[4]xf32> to vector<[4]x1xf32>
+/// ```
+///
+/// TODO: This transformation isn't specific to SME - move it to the SVE
+/// dialect.
+/// TODO: Check the in_bounds attribute and generate vector.maskedload if
+/// required.
+struct LowerColumnTransferReadToLoops
+    : public OpRewritePattern<vector::TransferReadOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
+                                PatternRewriter &rewriter) const override {
+    // NOTE: This is a fairly low-level transformation, so we shouldn't be
+    // adding support for Tensors without good rationale.
+    if (readOp.hasPureTensorSemantics())
+      return rewriter.notifyMatchFailure(
+          readOp, "Tensor semantics are unsupported (either bufferize or "
+                  "extend this pattern)");
+
+    auto resType = readOp.getVectorType();
+
+    if (resType.getRank() != 2)
+      return rewriter.notifyMatchFailure(readOp,
+                                         "Only 2D vectors are supported!");
+
+    if (resType.getShape()[1] != 1)
+      return rewriter.notifyMatchFailure(
+          readOp, "The trailing output dim is != 1 (not supported ATM)");
+
+    if (!resType.getScalableDims()[0] || resType.getScalableDims()[1])
+      return rewriter.notifyMatchFailure(
+          readOp, "Expected the leading dim to be scalable and the trailing "
+                  "dim to be fixed.");
+
+    // Create new result type - similar to the original vector with the
+    // trailing unit dim collapsed.
+    int64_t numRows = resType.getShape()[0];
+    VectorType newResType = VectorType::get(numRows, resType.getElementType(),
+                                            /*scalableDims=*/{true});
+
+    // Create a loop over all rows and load one element at a time.
+    auto loc = readOp.getLoc();
+    auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    auto createVscaleMultiple =
+        vector::makeVscaleConstantBuilder(rewriter, loc);
+    auto upperBound = createVscaleMultiple(numRows);
+    auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+    Value init = rewriter.create<arith::ConstantOp>(
+        loc, newResType, DenseElementsAttr::get(newResType, 0.0f));
+
+    scf::ForOp loadLoop;
+    {
+      OpBuilder::InsertionGuard g(rewriter);
+      loadLoop = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step,
+                                             ValueRange{init});
+      rewriter.setInsertionPointToStart(loadLoop.getBody());
+
+      auto tileSliceIndex = loadLoop.getInductionVar();
+
+      auto idx0 = rewriter.create<arith::AddIOp>(loc, tileSliceIndex,
+                                                 readOp.getIndices()[0]);
+      auto idx1 = readOp.getIndices()[1];
+
+      Value scalar = rewriter.create<memref::LoadOp>(
+          loc, readOp.getBase(), SmallVector<Value>({idx0, idx1}));
+
+      Operation *updateInit = rewriter.create<vector::InsertOp>(
+          loc, scalar, loadLoop.getRegionIterArg(0), tileSliceIndex);
+
+      rewriter.create<scf::YieldOp>(loc, updateInit->getResult(0));
+    }
+
+    // The read operation has been "legalized", but since the original result
+    // type was a 2D vector, we need to cast before returning the result. This
+    // ShapeCast should cancel-out with some other ShapeCast (i.e. it's a
+    // no-op).
+    auto sc = rewriter.create<vector::ShapeCastOp>(
+        loc, readOp.getResult().getType(), loadLoop.getResult(0));
+
+    rewriter.replaceOp(readOp, sc);
+
+    return success();
+  }
+};
+
 struct VectorLegalizationPass
     : public arm_sme::impl::VectorLegalizationBase<VectorLegalizationPass> {
   void runOnOperation() override {
````
```diff
@@ -941,10 +998,10 @@ struct VectorLegalizationPass
 
     // Apply preprocessing patterns.
     RewritePatternSet rewritePatterns(context);
-    rewritePatterns.add<FoldExtractFromVectorOfSMELikeCreateMasks,
-                        LiftIllegalVectorTransposeToMemory,
-                        ConvertIllegalShapeCastOpsToTransposes,
-                        LowerIllegalTransposeStoreViaZA>(context);
+    rewritePatterns
+        .add<FoldExtractFromVectorOfSMELikeCreateMasks,
+             LowerColumnTransferReadToLoops, LiftIllegalVectorTransposeToMemory,
+             LowerIllegalTransposeStoreViaZA>(context);
     if (failed(
             applyPatternsGreedily(getOperation(), std::move(rewritePatterns))))
       return signalPassFailure();
```
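For context, a sketch of how the new pattern might be exercised, assuming the pass is exposed as `-arm-sme-vector-legalization` (the RUN line and function below are illustrative, modeled on the test file further down):

```mlir
// RUN: mlir-opt %s -arm-sme-vector-legalization -cse -canonicalize \
// RUN:   -split-input-file | FileCheck %s

func.func @read_scalable_column(%a: index, %b: index, %pad: f32,
                                %mem: memref<?x?xf32>) -> vector<[4]x1xf32> {
  // Matched by LowerColumnTransferReadToLoops: a 2D transfer_read with a
  // scalable leading dim and a fixed trailing unit dim.
  %res = vector.transfer_read %mem[%a, %b], %pad : memref<?x?xf32>, vector<[4]x1xf32>
  return %res : vector<[4]x1xf32>
}
```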

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 1 addition & 12 deletions
```diff
@@ -5856,18 +5856,7 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
 
   // shape_cast(transpose(x)) -> shape_cast(x)
   if (auto transpose = getSource().getDefiningOp<TransposeOp>()) {
-    // This folder does
-    //   shape_cast(transpose) -> shape_cast
-    // But another pattern, ConvertIllegalShapeCastOpsToTransposes, does
-    //   shape_cast -> shape_cast(transpose)
-    // i.e. the complete opposite. When paired, these 2 patterns can cause
-    // infinite cycles in pattern rewriting.
-    // ConvertIllegalShapeCastOpsToTransposes only matches on scalable
-    // vectors, so by disabling this folder for scalable vectors the
-    // cycle is avoided.
-    // TODO: Check if ConvertIllegalShapeCastOpsToTransposes is
-    // still needed. If it's not, then we can fold here.
-    if (!transpose.getType().isScalable() && isOrderPreserving(transpose)) {
+    if (isOrderPreserving(transpose)) {
       setOperand(transpose.getVector());
       return getResult();
     }
```
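With the special case gone, `isOrderPreserving` is the only remaining gate, for scalable and fixed vectors alike. For contrast, a hypothetical case the fold still rejects (names illustrative):

```mlir
// Does not fold: both dims are non-unit, so the [1, 0] transpose genuinely
// reorders elements and cannot be absorbed into the shape_cast.
%t = vector.transpose %v, [1, 0] : vector<2x4xf32> to vector<4x2xf32>
%r = vector.shape_cast %t : vector<4x2xf32> to vector<8xf32>
```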

mlir/test/Dialect/ArmSME/vector-legalization.mlir

Lines changed: 56 additions & 45 deletions
```diff
@@ -491,51 +491,6 @@ func.func @illegal_transpose_no_defining_source_op(%vec: vector<[4]x1xf32>) -> v
 
 // -----
 
-// CHECK-LABEL: @illegal_shape_cast_to_transpose_2d(
-// CHECK-SAME: %[[VEC:.*]]: vector<[4]x1xf32>)
-func.func @illegal_shape_cast_to_transpose_2d(%vec: vector<[4]x1xf32>) -> vector<1x[4]xf32> {
-  // CHECK: vector.transpose %[[VEC]], [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
-  %0 = vector.shape_cast %vec : vector<[4]x1xf32> to vector<1x[4]xf32>
-  return %0 : vector<1x[4]xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @illegal_shape_cast_to_transpose_1d(
-// CHECK-SAME: %[[VEC:.*]]: vector<[4]x1xf32>)
-func.func @illegal_shape_cast_to_transpose_1d(%vec: vector<[4]x1xf32>) -> vector<[4]xf32> {
-  // CHECK: %[[TRANSPOSE:.*]] = vector.transpose %[[VEC]], [1, 0] : vector<[4]x1xf32> to vector<1x[4]xf32>
-  // CHECK: vector.shape_cast %[[TRANSPOSE]] : vector<1x[4]xf32> to vector<[4]xf32>
-  %0 = vector.shape_cast %vec : vector<[4]x1xf32> to vector<[4]xf32>
-  return %0 : vector<[4]xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @lift_illegal_2d_shape_cast_to_memory
-func.func @lift_illegal_2d_shape_cast_to_memory(%a: index, %b: index, %memref: memref<?x?xf32>) -> vector<1x[4]xf32> {
-  // CHECK: vector.transfer_read {{.*}} : memref<?x?xf32, {{.*}}>, vector<1x[4]xf32>
-  // CHECK-NOT: vector.shape_cast
-  %pad = arith.constant 0.0 : f32
-  %illegalRead = vector.transfer_read %memref[%a, %b], %pad {in_bounds = [false, true]}: memref<?x?xf32>, vector<[4]x1xf32>
-  %cast = vector.shape_cast %illegalRead : vector<[4]x1xf32> to vector<1x[4]xf32>
-  return %cast : vector<1x[4]xf32>
-}
-
-// -----
-
-// CHECK-LABEL: @lift_illegal_1d_shape_cast_to_memory
-func.func @lift_illegal_1d_shape_cast_to_memory(%a: index, %b: index, %memref: memref<?x?xf32>) -> vector<[4]xf32> {
-  // CHECK: vector.transfer_read {{.*}} : memref<?x?xf32, {{.*}}>, vector<1x[4]xf32>
-  // CHECK-NOT: vector.shape_cast {{.*}} : vector<[4]x1xf32> to vector<[4]xf32>
-  %pad = arith.constant 0.0 : f32
-  %illegalRead = vector.transfer_read %memref[%a, %b], %pad {in_bounds = [false, true]}: memref<?x?xf32>, vector<[4]x1xf32>
-  %cast = vector.shape_cast %illegalRead : vector<[4]x1xf32> to vector<[4]xf32>
-  return %cast : vector<[4]xf32>
-}
-
-// -----
-
 // CHECK-LABEL: @multi_tile_splat
 func.func @multi_tile_splat() -> vector<[8]x[8]xi32>
 {
@@ -656,3 +611,59 @@ func.func @vector_mask_without_maskable_op(%mask: vector<16x2xi1>, %vec: vector<
   %0 = vector.mask %mask { vector.yield %vec : vector<16x16xf32> } : vector<16x2xi1> -> vector<16x16xf32>
   return %0 : vector<16x16xf32>
 }
+
+// -----
+
+//=============================================================================
+// 1D examples - to be moved to the SVE dialect
+//=============================================================================
+
+/// TODO: Handle in_bounds
+
+// CHECK-LABEL: func.func @xfer_read_scalable_column(
+// CHECK-SAME:    %[[IDX_0:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:    %[[IDX_1:[a-zA-Z0-9]+]]: index,
+// CHECK-SAME:    %[[PAD:.*]]: f32,
+// CHECK-SAME:    %[[SRC:.*]]: memref<?x?xf32>) -> vector<[4]x1xf32> {
+func.func @xfer_read_scalable_column(%a: index, %b: index, %pad: f32, %src: memref<?x?xf32>) -> (vector<[4]x1xf32>) {
+  // CHECK: %[[INIT:.*]] = arith.constant dense<0.000000e+00> : vector<[4]xf32>
+  // CHECK: %[[STEP:.*]] = arith.constant 1 : index
+  // CHECK: %[[C4:.*]] = arith.constant 4 : index
+  // CHECK: %[[LB:.*]] = arith.constant 0 : index
+  // CHECK: %[[VSCALE:.*]] = vector.vscale
+  // CHECK: %[[C4_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C4]] : index
+
+  // <scf.for>
+  // CHECK: %[[SCF:.*]] = scf.for %[[IND_VAR:.*]] = %[[LB]] to %[[C4_VSCALE]] step %[[STEP]] iter_args(%[[SCF_RES:.*]] = %[[INIT]]) -> (vector<[4]xf32>) {
+  // CHECK:   %[[IDX_0_UPDATED:.*]] = arith.addi %[[IND_VAR]], %[[IDX_0]] : index
+  // CHECK:   %[[VAL_10:.*]] = memref.load %[[SRC]][%[[IDX_0_UPDATED]], %[[IDX_1]]] : memref<?x?xf32>
+  // CHECK:   %[[RES_UPDATED:.*]] = vector.insert %[[VAL_10]], %[[SCF_RES]] [%[[IND_VAR]]] : f32 into vector<[4]xf32>
+  // CHECK:   scf.yield %[[RES_UPDATED]] : vector<[4]xf32>
+  // CHECK: }
+
+  // <shape-cast>
+  // CHECK: %[[SC:.*]] = vector.shape_cast %[[SCF]] : vector<[4]xf32> to vector<[4]x1xf32>
+  // CHECK: return %[[SC]]
+  %read = vector.transfer_read %src[%a, %b], %pad : memref<?x?xf32>, vector<[4]x1xf32>
+  return %read : vector<[4]x1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @negative_xfer_read_scalable_column_x2
+func.func @negative_xfer_read_scalable_column_x2(%a: index, %b: index, %pad: f32, %src: memref<?x?xf32>) -> (vector<[4]x2xf32>) {
+  // CHECK-NOT: scf.for
+  // CHECK-NOT: memref.load
+  %read = vector.transfer_read %src[%a, %b], %pad : memref<?x?xf32>, vector<[4]x2xf32>
+  return %read : vector<[4]x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @negative_xfer_read_scalable_column_scalable_trailing_dim
+func.func @negative_xfer_read_scalable_column_scalable_trailing_dim(%a: index, %b: index, %pad: f32, %src: memref<?x?xf32>) -> (vector<4x[1]xf32>) {
+  // CHECK-NOT: scf.for
+  // CHECK-NOT: memref.load
+  %read = vector.transfer_read %src[%a, %b], %pad : memref<?x?xf32>, vector<4x[1]xf32>
+  return %read : vector<4x[1]xf32>
+}
```
