Skip to content

[mlir][ArmSME] Move vector.print -> ArmSME lowering to VectorToArmSME #74063

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 4, 2023

Conversation

MacDue
Copy link
Member

@MacDue MacDue commented Dec 1, 2023

This moves the SME tile vector.print lowering from -convert-arm-sme-to-scf to -convert-vector-to-arm-sme. This seems like a more logical place, as this is lowering a vector op to ArmSME, and it also prevents vector.print from blocking tile allocation.

This moves the SME tile vector.print lowering from `-convert-arm-sme-to-scf`
to `-convert-vector-to-arm-sme`. This seems like a more logical place,
as this is lowering a vector op to ArmSME, and it also prevents
vector.print from blocking tile allocation.
@llvmbot
Copy link
Member

llvmbot commented Dec 1, 2023

@llvm/pr-subscribers-mlir-sme

@llvm/pr-subscribers-mlir

Author: Benjamin Maxwell (MacDue)

Changes

This moves the SME tile vector.print lowering from -convert-arm-sme-to-scf to -convert-vector-to-arm-sme. This seems like a more logical place, as this is lowering a vector op to ArmSME, and it also prevents vector.print from blocking tile allocation.


Full diff: https://github.com/llvm/llvm-project/pull/74063.diff

4 Files Affected:

  • (modified) mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp (+4-66)
  • (modified) mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp (+67-7)
  • (modified) mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir (-22)
  • (modified) mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir (+22)
diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
index 69c68663070b6d5..fece03040dbb881 100644
--- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -447,75 +447,13 @@ struct TileStoreOpConversion : public OpRewritePattern<arm_sme::TileStoreOp> {
   }
 };
 
-/// Lowers `vector.print` of a tile into a loop over the rows of the tile,
-/// extracting them via `arm_sme.move_tile_slice_to_vector`, then printing with
-/// a 1D `vector.print`.
-///
-///  BEFORE:
-///  ```mlir
-///  vector.print %tile : vector<[4]x[4]xf32>
-///  ```
-///  AFTER:
-///  ```mlir
-///  %c0 = arith.constant 0 : index
-///  %c1 = arith.constant 1 : index
-///  %c4 = arith.constant 4 : index
-///  %vscale = vector.vscale
-///  %svl_s = arith.muli %c4, %vscale : index
-///  scf.for %i = %c0 to %svl_s step %c1 {
-///    %tile_slice = arm_sme.move_tile_slice_to_vector %tile[%i]
-///                     : vector<[4]xf32> from vector<[4]x[4]xf32>
-///    vector.print %tile_slice : vector<[4]xf32>
-///  }
-///  ```
-struct TileVectorPrintOpConversion : public OpRewritePattern<vector::PrintOp> {
-  using OpRewritePattern<vector::PrintOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(vector::PrintOp printOp,
-                                PatternRewriter &rewriter) const override {
-    if (!printOp.getSource())
-      return failure();
-
-    VectorType vectorType = dyn_cast<VectorType>(printOp.getPrintType());
-    if (!vectorType || !arm_sme::isValidSMETileVectorType(vectorType))
-      return failure();
-
-    auto loc = printOp.getLoc();
-
-    // Create a loop over the rows of the tile.
-    auto vscale = rewriter.create<vector::VectorScaleOp>(loc);
-    auto minTileRows =
-        rewriter.create<arith::ConstantIndexOp>(loc, vectorType.getDimSize(0));
-    auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
-    auto upperBound = rewriter.create<arith::MulIOp>(loc, minTileRows, vscale);
-    auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
-    auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
-    {
-      // Loop body.
-      rewriter.setInsertionPointToStart(forOp.getBody());
-      // Extract the current row from the tile.
-      Value rowIndex = forOp.getInductionVar();
-      // FIXME: Forward tile IDs.
-      // For now, if you vector.print a SME tile you need to do
-      // -allocate-arm-sme-tiles after -convert-arm-sme-to-scf.
-      auto tileSlice = rewriter.create<arm_sme::MoveTileSliceToVectorOp>(
-          loc, printOp.getSource(), rowIndex);
-      // Print the row with a 1D vector.print.
-      rewriter.create<vector::PrintOp>(loc, tileSlice,
-                                       printOp.getPunctuation());
-    }
-
-    rewriter.eraseOp(printOp);
-    return success();
-  }
-};
-
 } // namespace
 
 void mlir::populateArmSMEToSCFConversionPatterns(RewritePatternSet &patterns) {
-  patterns.add<TileLoadOpConversion, TileLoadOpWithMaskAndPadZeroConversion,
-               TileLoadOpWithMaskAndPadNonZeroConversion, TileStoreOpConversion,
-               TileVectorPrintOpConversion>(patterns.getContext());
+  patterns
+      .add<TileLoadOpConversion, TileLoadOpWithMaskAndPadZeroConversion,
+           TileLoadOpWithMaskAndPadNonZeroConversion, TileStoreOpConversion>(
+          patterns.getContext());
 }
 
 namespace {
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 3016c7b0a84772d..4b3fd26c6d59ec7 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -666,15 +666,75 @@ struct VectorInsertToArmSMELowering
   }
 };
 
+/// Lowers `vector.print` of a tile into a loop over the rows of the tile,
+/// extracting them via `arm_sme.move_tile_slice_to_vector`, then printing with
+/// a 1D `vector.print`.
+///
+///  BEFORE:
+///  ```mlir
+///  vector.print %tile : vector<[4]x[4]xf32>
+///  ```
+///  AFTER:
+///  ```mlir
+///  %c0 = arith.constant 0 : index
+///  %c1 = arith.constant 1 : index
+///  %c4 = arith.constant 4 : index
+///  %vscale = vector.vscale
+///  %svl_s = arith.muli %c4, %vscale : index
+///  scf.for %i = %c0 to %svl_s step %c1 {
+///    %tile_slice = arm_sme.move_tile_slice_to_vector %tile[%i]
+///                     : vector<[4]xf32> from vector<[4]x[4]xf32>
+///    vector.print %tile_slice : vector<[4]xf32>
+///  }
+///  ```
+struct VectorPrintToArmSMELowering : public OpRewritePattern<vector::PrintOp> {
+  using OpRewritePattern<vector::PrintOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::PrintOp printOp,
+                                PatternRewriter &rewriter) const override {
+    if (!printOp.getSource())
+      return failure();
+
+    VectorType vectorType = dyn_cast<VectorType>(printOp.getPrintType());
+    if (!vectorType || !arm_sme::isValidSMETileVectorType(vectorType))
+      return failure();
+
+    auto loc = printOp.getLoc();
+
+    // Create a loop over the rows of the tile.
+    auto vscale = rewriter.create<vector::VectorScaleOp>(loc);
+    auto minTileRows =
+        rewriter.create<arith::ConstantIndexOp>(loc, vectorType.getDimSize(0));
+    auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    auto upperBound = rewriter.create<arith::MulIOp>(loc, minTileRows, vscale);
+    auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+    auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
+    {
+      // Loop body.
+      rewriter.setInsertionPointToStart(forOp.getBody());
+      // Extract the current row from the tile.
+      Value rowIndex = forOp.getInductionVar();
+      auto tileSlice = rewriter.create<arm_sme::MoveTileSliceToVectorOp>(
+          loc, printOp.getSource(), rowIndex);
+      // Print the row with a 1D vector.print.
+      rewriter.create<vector::PrintOp>(loc, tileSlice,
+                                       printOp.getPunctuation());
+    }
+
+    rewriter.eraseOp(printOp);
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
                                           MLIRContext &ctx) {
-  patterns.add<BroadcastOpToArmSMELowering, ConstantOpToArmSMELowering,
-               SplatOpToArmSMELowering, TransferReadToArmSMELowering,
-               TransferWriteToArmSMELowering, TransposeOpToArmSMELowering,
-               VectorLoadToArmSMELowering, VectorStoreToArmSMELowering,
-               VectorOuterProductToArmSMELowering,
-               VectorExtractToArmSMELowering, VectorInsertToArmSMELowering>(
-      &ctx);
+  patterns
+      .add<BroadcastOpToArmSMELowering, ConstantOpToArmSMELowering,
+           SplatOpToArmSMELowering, TransferReadToArmSMELowering,
+           TransferWriteToArmSMELowering, TransposeOpToArmSMELowering,
+           VectorLoadToArmSMELowering, VectorStoreToArmSMELowering,
+           VectorOuterProductToArmSMELowering, VectorExtractToArmSMELowering,
+           VectorInsertToArmSMELowering, VectorPrintToArmSMELowering>(&ctx);
 }
diff --git a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
index fc28645a7acf7c0..efefc6c49e08f04 100644
--- a/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
+++ b/mlir/test/Conversion/ArmSMEToSCF/arm-sme-to-scf.mlir
@@ -160,25 +160,3 @@ func.func @arm_sme_tile_store_hor_with_mask(%tile : vector<[4]x[4]xi32>, %dest :
   arm_sme.tile_store %tile, %dest[%c0, %c0], %mask : memref<?x?xi32>, vector<[4]x[4]xi32>
   return
 }
-
-//===----------------------------------------------------------------------===//
-// vector.print
-//===----------------------------------------------------------------------===//
-
-// -----
-
-func.func @arm_sme_tile_print(%tile: vector<[4]x[4]xf32>)
-{
-  vector.print %tile : vector<[4]x[4]xf32>
-  return
-}
-// CHECK-LABEL:   func.func @arm_sme_tile_print(
-// CHECK-SAME:                                  %[[TILE:.*]]: vector<[4]x[4]xf32>) {
-// CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
-// CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
-// CHECK-DAG:     %[[C4:.*]] = arith.constant 4 : index
-// CHECK-DAG:     %[[VSCALE:.*]] = vector.vscale
-// CHECK-DAG:     %[[NUM_TILE_SLICES:.*]] = arith.muli %[[C4]], %[[VSCALE]] : index
-// CHECK-NEXT:      scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] {
-// CHECK-NEXT:        %[[TILE_SLICE:.*]] = arm_sme.move_tile_slice_to_vector %[[TILE]][%[[TILE_SLICE_INDEX]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
-// CHECK-NEXT:        vector.print %[[TILE_SLICE]] : vector<[4]xf32>
diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
index 2491b2e2468cdaf..5bc147c60f3a664 100644
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
@@ -736,3 +736,25 @@ func.func @vector_outerproduct_f64(%lhs : vector<[2]xf64>, %rhs : vector<[2]xf64
   %result = vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[2]xf64>, vector<[2]xf64>
   "prevent.dce"(%result) : (vector<[2]x[2]xf64>) -> ()
 }
+
+//===----------------------------------------------------------------------===//
+// vector.print
+//===----------------------------------------------------------------------===//
+
+// -----
+
+func.func @vector_print_tile(%tile: vector<[4]x[4]xf32>)
+{
+  vector.print %tile : vector<[4]x[4]xf32>
+  return
+}
+// CHECK-LABEL:   func.func @vector_print_tile(
+// CHECK-SAME:                                  %[[TILE:.*]]: vector<[4]x[4]xf32>) {
+// CHECK-DAG:     %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:     %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG:     %[[C4:.*]] = arith.constant 4 : index
+// CHECK-DAG:     %[[VSCALE:.*]] = vector.vscale
+// CHECK-DAG:     %[[NUM_TILE_SLICES:.*]] = arith.muli %[[VSCALE]], %[[C4]] : index
+// CHECK-NEXT:      scf.for %[[TILE_SLICE_INDEX:.*]] = %[[C0]] to %[[NUM_TILE_SLICES]] step %[[C1]] {
+// CHECK-NEXT:        %[[TILE_SLICE:.*]] = arm_sme.move_tile_slice_to_vector %[[TILE]][%[[TILE_SLICE_INDEX]]] : vector<[4]xf32> from vector<[4]x[4]xf32>
+// CHECK-NEXT:        vector.print %[[TILE_SLICE]] : vector<[4]xf32>

Copy link
Collaborator

@c-rhodes c-rhodes left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM cheers

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks!

@MacDue MacDue merged commit 10063c5 into llvm:main Dec 4, 2023
@MacDue MacDue deleted the arm_sme_move_vector_print_lowering branch December 4, 2023 09:42
MacDue added a commit to MacDue/llvm-project that referenced this pull request Dec 8, 2023
This was moved to VectorToArmSME in llvm#74063, so this is no longer needed.

VectorToArmSME uses a greedy rewriter, so a similar legality rule is not
needed there.

See: https://github.com/llvm/llvm-project/blob/bbb8a0df7367068e1cf2fc54edd376beb976b430/mlir/lib/Conversion/VectorToArmSME/VectorToArmSMEPass.cpp#L35
MacDue added a commit that referenced this pull request Dec 11, 2023
…74875)

This was moved to VectorToArmSME in #74063, so this is no longer needed.

VectorToArmSME uses a greedy rewriter, so a similar legality rule is not
needed there.

See:
https://github.com/llvm/llvm-project/blob/bbb8a0df7367068e1cf2fc54edd376beb976b430/mlir/lib/Conversion/VectorToArmSME/VectorToArmSMEPass.cpp#L35
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants