-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][ArmSME] Move vector.print -> ArmSME lowering to VectorToArmSME #74063
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
This moves the SME tile vector.print lowering from `-convert-arm-sme-to-scf` to `-convert-vector-to-arm-sme`. This seems like a more logical place, as this is lowering a vector op to ArmSME, and it also prevents vector.print from blocking tile allocation.
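@llvm/pr-subscribers-mlir-sme @llvm/pr-subscribers-mlir Author: Benjamin Maxwell (MacDue) Changes: This moves the SME tile vector.print lowering from `-convert-arm-sme-to-scf` to `-convert-vector-to-arm-sme`. This seems like a more logical place, as this is lowering a vector op to ArmSME, and it also prevents vector.print from blocking tile allocation. Full diff: https://github.com/llvm/llvm-project/pull/74063.diff 4 Files Affected: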
diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
index 69c68663070b6d5..fece03040dbb881 100644
--- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -447,75 +447,13 @@ struct TileStoreOpConversion : public OpRewritePattern<arm_sme::TileStoreOp> {
}
};
-/// Lowers `vector.print` of a tile into a loop over the rows of the tile,
-/// extracting them via `arm_sme.move_tile_slice_to_vector`, then printing with
-/// a 1D `vector.print`.
-///
-/// BEFORE:
-/// ```mlir
-/// vector.print %tile : vector<[4]x[4]xf32>
-/// ```
-/// AFTER:
-/// ```mlir
-/// %c0 = arith.constant 0 : index
-/// %c1 = arith.constant 1 : index
-/// %c4 = arith.constant 4 : index
-/// %vscale = vector.vscale
-/// %svl_s = arith.muli %c4, %vscale : index
-/// scf.for %i = %c0 to %svl_s step %c1 {
-/// %tile_slice = arm_sme.move_tile_slice_to_vector %tile[%i]
-/// : vector<[4]xf32> from vector<[4]x[4]xf32>
-/// vector.print %tile_slice : vector<[4]xf32>
-/// }
-/// ```
-struct TileVectorPrintOpConversion : public OpRewritePattern<vector::PrintOp> {
- using OpRewritePattern<vector::PrintOp>::OpRewritePattern;
-
- LogicalResult matchAndRewrite(vector::PrintOp printOp,
- PatternRewriter &rewriter) const override {
- if (!printOp.getSource())
- return failure();
-
- VectorType vectorType = dyn_cast<VectorType>(printOp.getPrintType());
- if (!vectorType || !arm_sme::isValidSMETileVectorType(vectorType))
- return failure();
-
- auto loc = printOp.getLoc();
-
- // Create a loop over the rows of the tile.
- auto vscale = rewriter.create<vector::VectorScaleOp>(loc);
- auto minTileRows =
- rewriter.create<arith::ConstantIndexOp>(loc, vectorType.getDimSize(0));
- auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
- auto upperBound = rewriter.create<arith::MulIOp>(loc, minTileRows, vscale);
- auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
- auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
- {
- // Loop body.
- rewriter.setInsertionPointToStart(forOp.getBody());
- // Extract the current row from the tile.
- Value rowIndex = forOp.getInductionVar();
- // FIXME: Forward tile IDs.
- // For now, if you vector.print a SME tile you need to do
- // -allocate-arm-sme-tiles after -convert-arm-sme-to-scf.
- auto tileSlice = rewriter.create<arm_sme::MoveTileSliceToVectorOp>(
- loc, printOp.getSource(), rowIndex);
- // Print the row with a 1D vector.print.
- rewriter.create<vector::PrintOp>(loc, tileSlice,
- printOp.getPunctuation());
- }
-
- rewriter.eraseOp(printOp);
- return success();
- }
-};
-
} // namespace
void mlir::populateArmSMEToSCFConversionPatterns(RewritePatternSet &patterns) {
- patterns.add<TileLoadOpConversion, TileLoadOpWithMaskAndPadZeroConversion,
- TileLoadOpWithMaskAndPadNonZeroConversion, TileStoreOpConversion,
- TileVectorPrintOpConversion>(patterns.getContext());
+ patterns
+ .add<TileLoadOpConversion, TileLoadOpWithMaskAndPadZeroConversion,
+ TileLoadOpWithMaskAndPadNonZeroConversion, TileStoreOpConversion>(
+ patterns.getContext());
}
namespace {
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 3016c7b0a84772d..4b3fd26c6d59ec7 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -666,15 +666,75 @@ struct VectorInsertToArmSMELowering
}
};
+/// Lowers `vector.print` of a tile into a loop over the rows of the tile,
+/// extracting them via `arm_sme.move_tile_slice_to_vector`, then printing with
+/// a 1D `vector.print`.
+///
+/// BEFORE:
+/// ```mlir
+/// vector.print %tile : vector<[4]x[4]xf32>
+/// ```
+/// AFTER:
+/// ```mlir
+/// %c0 = arith.constant 0 : index
+/// %c1 = arith.constant 1 : index
+/// %c4 = arith.constant 4 : index
+/// %vscale = vector.vscale
+/// %svl_s = arith.muli %c4, %vscale : index
+/// scf.for %i = %c0 to %svl_s step %c1 {
+/// %tile_slice = arm_sme.move_tile_slice_to_vector %tile[%i]
+/// : vector<[4]xf32> from vector<[4]x[4]xf32>
+/// vector.print %tile_slice : vector<[4]xf32>
+/// }
+/// ```
+struct VectorPrintToArmSMELowering : public OpRewritePattern<vector::PrintOp> {
+ using OpRewritePattern<vector::PrintOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::PrintOp printOp,
+ PatternRewriter &rewriter) const override {
+ if (!printOp.getSource())
+ return failure();
+
+ VectorType vectorType = dyn_cast<VectorType>(printOp.getPrintType());
+ if (!vectorType || !arm_sme::isValidSMETileVectorType(vectorType))
+ return failure();
+
+ auto loc = printOp.getLoc();
+
+ // Create a loop over the rows of the tile.
+ auto vscale = rewriter.create<vector::VectorScaleOp>(loc);
+ auto minTileRows =
+ rewriter.create<arith::ConstantIndexOp>(loc, vectorType.getDimSize(0));
+ auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ auto upperBound = rewriter.create<arith::MulIOp>(loc, minTileRows, vscale);
+ auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+ auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
+ {
+ // Loop body.
+ rewriter.setInsertionPointToStart(forOp.getBody());
+ // Extract the current row from the tile.
+ Value rowIndex = forOp.getInductionVar();
+ auto tileSlice = rewriter.create<arm_sme::MoveTileSliceToVectorOp>(
+ loc, printOp.getSource(), rowIndex);
+ // Print the row with a 1D vector.print.
+ rewriter.create<vector::PrintOp>(loc, tileSlice,
+ printOp.getPunctuation());
+ }
+
+ rewriter.eraseOp(printOp);
+ return success();
+ }
+};
+
} // namespace
void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
MLIRContext &ctx) {
- patterns.add<BroadcastOpToArmSMELowering, ConstantOpToArmSMELowering,
- SplatOpToArmSMELowering, TransferReadToArmSMELowering,
- TransferWriteToArmSMELowering, TransposeOpToArmSMELowering,
- VectorLoadToArmSMELowering, VectorStoreToArmSMELowering,
- VectorOuterProductToArmSMELowering,
- VectorExtractToArmSMELowering, VectorInsertToArmSMELowering>(
- &ctx);
+ patterns
+ .add<BroadcastOpToArmSMELowering, ConstantOpToArmSMELowering,
+ SplatOpToArmSMELowering, TransferReadToArmSMELowering,
+ TransferWriteToArmSMELowering, TransposeOpToArmSMELowering,
+ VectorLoadToArmSMELowering, VectorStoreToArmSMELowering,
+ VectorOuterProductToArmSMELowering, VectorExtractToArmSMELowering,
+ VectorInsertToArmSMELowering, VectorPrintToArmSMELowering>(&ctx);
}
diff --git a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
index fc28645a7acf7c0..efefc6c49e08f04 100644
--- a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
+++ b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
@@ -160,25 +160,3 @@ func.func @arm_sme_tile_store_hor_with_mask(%tile : vector<[4]x[4]xi32>, %dest :
arm_sme.tile_store %tile, %dest[%c0, %c0], %mask : memref<?x?xi32>, vector<[4]x[4]xi32>
return
}
-
-//===----------------------------------------------------------------------===//
-// vector.print
-//===----------------------------------------------------------------------===//
-
-// -----
-
-func.func @arm_sme_tile_print(%tile: vector<[4]x[4]xf32>)
-{
- vector.print %tile : vector<[4]x[4]xf32>
- return
-}
-// CHECK-LABEL: func.func @arm_sme_tile_print(
-// CHECK-SAME: %[[TILE:.*]]: vector<[4]x[4]xf32>) {
-// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
-// CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
-// CHECK-DAG: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[C4]], %[[VSCALE]] : index
-// CHECK-NEXT: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] {
-// CHECK-NEXT: %[[TILE_SLICE:.*]] = arm_sme.move_tile_slice_to_vector %[[TILE]][%[[TILE_SLICE_INDEX]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
-// CHECK-NEXT: vector.print %[[TILE_SLICE]] : vector<[4]xf32>
diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
index 2491b2e2468cdaf..5bc147c60f3a664 100644
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
@@ -736,3 +736,25 @@ func.func @vector_outerproduct_f64(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64
%result = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[2]xf64>, vector<[2]xf64>
"prevent.dce"(%result) : (vector<[2]x[2]xf64>) -> ()
}
+
+//===----------------------------------------------------------------------===//
+// vector.print
+//===----------------------------------------------------------------------===//
+
+// -----
+
+func.func @vector_print_tile(%tile: vector<[4]x[4]xf32>)
+{
+ vector.print %tile : vector<[4]x[4]xf32>
+ return
+}
+// CHECK-LABEL: func.func @vector_print_tile(
+// CHECK-SAME: %[[TILE:.*]]: vector<[4]x[4]xf32>) {
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+// CHECK-DAG: %[[VSCALE:.*]] = vector.vscale
+// CHECK-DAG: %[[NUM_TILE_SLICES:.*]] = arith.muli %[[VSCALE]], %[[C4]] : index
+// CHECK-NEXT: scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] {
+// CHECK-NEXT: %[[TILE_SLICE:.*]] = arm_sme.move_tile_slice_to_vector %[[TILE]][%[[TILE_SLICE_INDEX]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
+// CHECK-NEXT: vector.print %[[TILE_SLICE]] : vector<[4]xf32>
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM cheers
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks!
This was moved to VectorToArmSME in llvm#74063, so this is no longer needed. VectorToArmSME uses a greedy rewriter, so a similar legality rule is not needed there. See: https://github.com/llvm/llvm-project/blob/bbb8a0df7367068e1cf2fc54edd376beb976b430/mlir/lib/Conversion/VectorToArmSME/VectorToArmSMEPass.cpp#L35
…74875) This was moved to VectorToArmSME in #74063, so this is no longer needed. VectorToArmSME uses a greedy rewriter, so a similar legality rule is not needed there. See: https://github.com/llvm/llvm-project/blob/bbb8a0df7367068e1cf2fc54edd376beb976b430/mlir/lib/Conversion/VectorToArmSME/VectorToArmSMEPass.cpp#L35
This moves the SME tile vector.print lowering from `-convert-arm-sme-to-scf` to `-convert-vector-to-arm-sme`. This seems like a more logical place, as this is lowering a vector op to ArmSME, and it also prevents vector.print from blocking tile allocation.