Skip to content

Commit d04d335

Browse files
committed
fixup! [mlir][ArmSME] Remove ConvertIllegalShapeCastOpsToTransposes
Add LowerColumnTransferReadToLoops. Note, this is to address Ben's comment here: * https://github.com/llvm/llvm-project/pull/139706/files#r2088605443
1 parent de24b85 commit d04d335

File tree

2 files changed

+170
-3
lines changed

2 files changed

+170
-3
lines changed

mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp

Lines changed: 114 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -867,6 +867,116 @@ struct LowerIllegalTransposeStoreViaZA
867867
}
868868
};
869869

870+
/// Lower `vector.transfer_read` of a scalable column to `scf::for`
871+
///
872+
/// Lowers a "read" of a scalable column from a MemRef for which there is no
873+
/// hardware pperation that we could use to a loop over the rows to read and
874+
/// loads one element at a time.
875+
///
876+
/// BEFORE:
877+
/// ```
878+
/// %res = vector.transfer_read %mem[%a, %b] (...)
879+
/// : memref<?x?xf32>, vector<[4]x1xf32>
880+
/// ```
881+
///
882+
/// AFTER:
883+
/// ```
884+
/// %cst = arith.constant (...) : vector<[4]xf32>
885+
/// %vscale = vector.vscale
886+
/// %c4_vscale = arith.muli %vscale, %c4 : index
887+
/// %scf = scf.for %lb = %c0 to %c4_vscale step %c1 iter_args(%arg4 = %cst)
888+
/// -> (vector<[4]xf32>) {
889+
///
890+
/// %load = memref.load %mem[%arg3 + %a, %b] : memref<?x?xf32>
891+
/// %vec = vector.insert %load, %cst [%arg3] : f32 into vector<[4]xf32>
892+
/// scf.yield %vec : vector<[4]xf32>
893+
/// }
894+
/// %res = vector.shape_cast %scf : vector<[4]xf32> to vector<[4]x1xf32>
895+
/// ```
896+
///
897+
/// TODO: This transformation isn't specific to SME - move it to the SVE
898+
/// dialect.
899+
/// TODO: Check the in_bounds attribute and generate vector.maskedload if
900+
/// required.
901+
struct LowerColumnTransferReadToLoops
902+
: public OpRewritePattern<vector::TransferReadOp> {
903+
using OpRewritePattern::OpRewritePattern;
904+
905+
LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
906+
PatternRewriter &rewriter) const override {
907+
// NOTE: This is a fairly low-level transformation, so we shouldn't be
908+
// adding support for Tensors without good rationale.
909+
if (readOp.hasPureTensorSemantics())
910+
return rewriter.notifyMatchFailure(
911+
readOp, "Tensor semantics are unsupported (either bufferize or "
912+
"extend this pattern)");
913+
914+
auto resType = readOp.getVectorType();
915+
916+
if (resType.getRank() != 2)
917+
return rewriter.notifyMatchFailure(readOp,
918+
"Only 2D vectors are supported!");
919+
920+
if (resType.getShape()[1] != 1)
921+
return rewriter.notifyMatchFailure(
922+
readOp, "The trailing output dim is != 1 (not supported ATM)");
923+
924+
if (!resType.getScalableDims()[0] || resType.getScalableDims()[1])
925+
return rewriter.notifyMatchFailure(
926+
readOp, "Expected the leading dim to be scalable and the trailing "
927+
"dim to be fixed.");
928+
929+
// Create new result type - similar to the original vector with the
930+
// trailing unit dim collapsed.
931+
int64_t numRows = resType.getShape()[0];
932+
VectorType newResType = VectorType::get(numRows, resType.getElementType(),
933+
/*scalableDims=*/{true});
934+
935+
// Create a loop over all rows and load one element at a time.
936+
auto loc = readOp.getLoc();
937+
auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
938+
auto createVscaleMultiple =
939+
vector::makeVscaleConstantBuilder(rewriter, loc);
940+
auto upperBound = createVscaleMultiple(numRows);
941+
auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
942+
Value init = rewriter.create<arith::ConstantOp>(
943+
loc, newResType, DenseElementsAttr::get(newResType, 0.0f));
944+
945+
scf::ForOp loadLoop;
946+
{
947+
OpBuilder::InsertionGuard g(rewriter);
948+
loadLoop = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step,
949+
ValueRange{init});
950+
rewriter.setInsertionPointToStart(loadLoop.getBody());
951+
952+
auto tileSliceIndex = loadLoop.getInductionVar();
953+
954+
auto idx0 = rewriter.create<arith::AddIOp>(loc, tileSliceIndex,
955+
readOp.getIndices()[0]);
956+
auto idx1 = readOp.getIndices()[1];
957+
958+
Value scalar = rewriter.create<memref::LoadOp>(
959+
loc, readOp.getBase(), SmallVector<Value>({idx0, idx1}));
960+
961+
Operation *updateInit = rewriter.create<vector::InsertOp>(
962+
loc, scalar, loadLoop.getRegionIterArg(0), tileSliceIndex);
963+
964+
rewriter.create<scf::YieldOp>(loc, updateInit->getResult(0));
965+
}
966+
967+
// The read operation has been "legalized", but since the original result
968+
// type was a 2D vector, we need to cast before returning the result. This
969+
// ShapeCast should cancel-out with some other ShapeCast (i.e. it's a
970+
// no-op).
971+
auto sc = rewriter.create<vector::ShapeCastOp>(
972+
loc, readOp.getResult().getType(), loadLoop.getResult(0));
973+
974+
rewriter.replaceOp(readOp, sc);
975+
976+
return success();
977+
}
978+
};
979+
870980
struct VectorLegalizationPass
871981
: public arm_sme::impl::VectorLegalizationBase<VectorLegalizationPass> {
872982
void runOnOperation() override {
@@ -888,9 +998,10 @@ struct VectorLegalizationPass
888998

889999
// Apply preprocessing patterns.
8901000
RewritePatternSet rewritePatterns(context);
891-
rewritePatterns.add<FoldExtractFromVectorOfSMELikeCreateMasks,
892-
LiftIllegalVectorTransposeToMemory,
893-
LowerIllegalTransposeStoreViaZA>(context);
1001+
rewritePatterns
1002+
.add<FoldExtractFromVectorOfSMELikeCreateMasks,
1003+
LowerColumnTransferReadToLoops, LiftIllegalVectorTransposeToMemory,
1004+
LowerIllegalTransposeStoreViaZA>(context);
8941005
if (failed(
8951006
applyPatternsGreedily(getOperation(), std::move(rewritePatterns))))
8961007
return signalPassFailure();

mlir/test/Dialect/ArmSME/vector-legalization.mlir

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -611,3 +611,59 @@ func.func @vector_mask_without_maskable_op(%mask: vector<16x2xi1>, %vec: vector<
611611
%0 = vector.mask %mask { vector.yield %vec : vector<16x16xf32> } : vector<16x2xi1> -> vector<16x16xf32>
612612
return %0 : vector<16x16xf32>
613613
}
614+
615+
// -----
616+
617+
//=============================================================================
618+
// 1D examples - to be moved to the SVE dialect
619+
//=============================================================================
620+
621+
/// TODO: Handle in_bounds
622+
623+
// CHECK-LABEL: func.func @xfer_read_scalable_column(
624+
// CHECK-SAME: %[[IDX_0:[a-zA-Z0-9]+]]: index,
625+
// CHECK-SAME: %[[IDX_1:[a-zA-Z0-9]+]]: index,
626+
// CHECK-SAME: %[[PAD:.*]]: f32,
627+
// CHECK-SAME: %[[SRC:.*]]: memref<?x?xf32>) -> vector<[4]x1xf32> {
628+
func.func @xfer_read_scalable_column(%a: index, %b: index, %pad: f32, %src: memref<?x?xf32>) -> (vector<[4]x1xf32>) {
629+
// CHECK: %[[INIT:.*]] = arith.constant dense<0.000000e+00> : vector<[4]xf32>
630+
// CHECK: %[[STEP:.*]] = arith.constant 1 : index
631+
// CHECK: %[[C4:.*]] = arith.constant 4 : index
632+
// CHECK: %[[LB:.*]] = arith.constant 0 : index
633+
// CHECK: %[[VSCALE:.*]] = vector.vscale
634+
// CHECK: %[[C4_VSCALE:.*]] = arith.muli %[[VSCALE]], %[[C4]] : index
635+
636+
// <scf.for>
637+
// CHECK: %[[SCF:.*]] = scf.for %[[IND_VAR:.*]] = %[[LB]] to %[[C4_VSCALE]] step %[[STEP]] iter_args(%[[SCF_RES:.*]] = %[[INIT]]) -> (vector<[4]xf32>) {
638+
// CHECK: %[[IDX_0_UPDATED:.*]] = arith.addi %[[IND_VAR]], %[[IDX_0]] : index
639+
// CHECK: %[[VAL_10:.*]] = memref.load %[[SRC]][%[[IDX_0_UPDATED]], %[[IDX_1]]] : memref<?x?xf32>
640+
// CHECK: %[[RES_UPDATED:.*]] = vector.insert %[[VAL_10]], %[[SCF_RES]] [%[[IND_VAR]]] : f32 into vector<[4]xf32>
641+
// CHECK: scf.yield %[[RES_UPDATED]] : vector<[4]xf32>
642+
// CHECK: }
643+
644+
// <shape-cast>
645+
// CHECK: %[[SC:.*]] = vector.shape_cast %[[SCF]] : vector<[4]xf32> to vector<[4]x1xf32>
646+
// CHECK: return %[[SC]]
647+
%read = vector.transfer_read %src[%a, %b], %pad : memref<?x?xf32>, vector<[4]x1xf32>
648+
return %read : vector<[4]x1xf32>
649+
}
650+
651+
// -----
652+
653+
// CHECK-LABEL: func.func @negative_xfer_read_scalable_column_x2
654+
func.func @negative_xfer_read_scalable_column_x2(%a: index, %b: index, %pad: f32, %src: memref<?x?xf32>) -> (vector<[4]x2xf32>) {
655+
// CHECK-NOT: scf.for
656+
// CHECK-NOT: memref.load
657+
%read = vector.transfer_read %src[%a, %b], %pad : memref<?x?xf32>, vector<[4]x2xf32>
658+
return %read : vector<[4]x2xf32>
659+
}
660+
661+
// -----
662+
663+
// CHECK-LABEL: func.func @negative_xfer_read_scalable_column_scalable_trailing_dim
664+
func.func @negative_xfer_read_scalable_column_scalable_trailing_dim(%a: index, %b: index, %pad: f32, %src: memref<?x?xf32>) -> (vector<4x[1]xf32>) {
665+
// CHECK-NOT: scf.for
666+
// CHECK-NOT: memref.load
667+
%read = vector.transfer_read %src[%a, %b], %pad : memref<?x?xf32>, vector<4x[1]xf32>
668+
return %read : vector<4x[1]xf32>
669+
}

0 commit comments

Comments
 (0)