[mlir][ArmSME] Add custom vector.print lowering for SME tiles #66691

Merged: @MacDue merged 3 commits into llvm:main on Sep 26, 2023

Conversation

@MacDue (Member) commented Sep 18, 2023

This adds a custom lowering for SME that loops over each row of the tile, extracting it via an SME MOVA, then printing with a normal 1D vector.print.

This makes writing SME integration tests easier and less verbose.

Depends on: #66910, #66911
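
The row-by-row idea in the description can be modelled outside MLIR. The following is a minimal sketch in plain C++, not the pattern from the patch: `Tile`, `readHorizontalSlice`, `print1D`, and `printTile` are hypothetical names invented for illustration, with `readHorizontalSlice` standing in for the MOVA tile-slice read and `print1D` for a 1D `vector.print`. The loop bound mirrors the lowering, where the number of rows is the minimum row count multiplied by `vscale`.

```cpp
#include <cassert>
#include <cstddef>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical stand-in for an SME tile: a square matrix whose side length is
// minRows * vscale (for example 4 * vscale for 32-bit elements).
struct Tile {
  std::vector<std::vector<float>> rows;
};

// Models the tile-slice read (MOVA): copies one horizontal slice out of the tile.
std::vector<float> readHorizontalSlice(const Tile &tile, std::size_t row) {
  return tile.rows.at(row);
}

// Models a 1D print: "( a, b, c )" followed by a newline.
std::string print1D(const std::vector<float> &v) {
  std::ostringstream os;
  os << "( ";
  for (std::size_t i = 0; i < v.size(); ++i) {
    if (i != 0)
      os << ", ";
    os << v[i];
  }
  os << " )\n";
  return os.str();
}

// Models the lowered loop: one slice read and one 1D print per tile row.
std::string printTile(const Tile &tile, std::size_t minRows, std::size_t vscale) {
  const std::size_t upperBound = minRows * vscale; // number of rows at runtime
  assert(tile.rows.size() == upperBound && "tile must have minRows * vscale rows");
  std::string out;
  for (std::size_t i = 0; i < upperBound; ++i)
    out += print1D(readHorizontalSlice(tile, i));
  return out;
}
```

With a 4x4 tile holding the outer product of (0, 1, 2, 3) with itself and `vscale = 1`, `printTile(tile, 4, 1)` yields four lines, one per row, matching the shape of the `CHECK` lines in the integration tests.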

@MacDue (Member, Author) commented Sep 18, 2023

cc @c-rhodes, @banach-space

@llvmbot (Member) commented Sep 18, 2023

@llvm/pr-subscribers-mlir-llvm
@llvm/pr-subscribers-mlir-sme
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-vector

Changes

This adds a custom lowering for SME that loops over each row of the tile, extracting it via an SME MOVA, then printing with a normal 1D vector.print.

This patch is "done" but needs to be split into individual tested patches.

Currently this:

  • Registers the "arm_sme.intr.read.horiz" intrinsic, aka move tile slice to vector, aka MOVA
  • Adds a new "enable_arm_streaming_ignore" attribute to disable the "enable_arm_streaming" pass for certain functions in a file
    • This prevents ABI issues when calling helpers, as MLIR SME currently does not implement the ABI conventions needed for nested streaming mode calls
  • Adds a lowering for vector.print of SME tiles
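
The effect of the "enable_arm_streaming_ignore" attribute can be sketched with a small standalone model. This is a simplification under stated assumptions, not the pass itself: `FuncModel` and `enableArmStreaming` are hypothetical names, attributes are modelled as a plain set of strings, and the `Locally` enumerator is assumed from the `arm_locally_streaming` attribute name in the diff. The only behaviour taken directly from the patch is the early return when the ignore attribute is present.

```cpp
#include <set>
#include <string>

// Hypothetical stand-in for a function op: a name plus a set of unit attributes.
struct FuncModel {
  std::string name;
  std::set<std::string> attrs;
};

// Streaming modes; `Locally` is an assumed name for the locally-streaming mode.
enum class ArmStreaming { Default, Locally };

// Models the pass: functions tagged with enable_arm_streaming_ignore are left
// untouched; all other functions receive the streaming (and optionally ZA)
// attributes.
void enableArmStreaming(FuncModel &func, ArmStreaming mode, bool enableZA) {
  if (func.attrs.count("enable_arm_streaming_ignore"))
    return; // helper functions opt out, avoiding nested streaming-mode calls
  func.attrs.insert(mode == ArmStreaming::Default ? "arm_streaming"
                                                  : "arm_locally_streaming");
  if (enableZA)
    func.attrs.insert("arm_za");
}
```

In this model a helper such as `printTileBegin` carrying the ignore attribute keeps exactly the attributes it started with, while a test function without it gains `arm_streaming`.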

Patch is 20.27 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/66691.diff

9 Files Affected:

  • (modified) mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td (+18-4)
  • (modified) mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h (+6)
  • (modified) mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp (+90-2)
  • (modified) mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp (+5-1)
  • (modified) mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp (-17)
  • (modified) mlir/lib/Dialect/ArmSME/Utils/Utils.cpp (+14)
  • (modified) mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir (+8-38)
  • (modified) mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir (+5-21)
  • (modified) mlir/test/Integration/Dialect/Vector/CPU/ArmSME/tile_fill.mlir (+6-26)
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
index 7f02e723f3d91c2..10fee9251cd3e9e 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
@@ -469,15 +469,16 @@ def MOPVector : ScalableVectorOfLengthAndType<[16, 8, 4, 2],
 def LDSTPredicate : ScalableVectorOfLengthAndType<[16, 8, 4, 2, 1], [I1]>;
 
 class ArmSME_IntrOp<string mnemonic, list<int> overloadedOperands = [],
-                    list<Trait> traits = []>
+                    list<Trait> traits = [], int numResults = 0,
+                    list<int> overloadedResults = []>
     : LLVM_IntrOpBase<
           /*Dialect dialect=*/ArmSME_Dialect,
           /*string opName=*/"intr." # mnemonic,
           /*string enumName=*/"aarch64_sme_" # !subst(".", "_", mnemonic),
-          /*list<int> overloadedResults=*/[],
+          /*list<int> overloadedResults=*/overloadedResults,
           /*list<int> overloadedOperands=*/overloadedOperands,
           /*list<Trait> traits=*/traits,
-          /*int numResults=*/0>;
+          /*int numResults=*/numResults>;
 
 // Zero
 def LLVM_aarch64_sme_zero : ArmSME_IntrOp<"zero">,
@@ -548,7 +549,7 @@ def LLVM_aarch64_sme_str
       Arguments<(ins Arg<I32, "Index">,
                  Arg<LLVM_AnyPointer, "Store address", [MemWrite]>)>;
 
-// Vector to tile
+// Vector to tile slice
 class LLVM_aarch64_sme_write<string direction>
     : ArmSME_IntrOp<"write." # direction, /*overloadedOperands=*/[3],
                     [AllShapesMatch<["pg", "vector"]>]>,
@@ -557,9 +558,22 @@ class LLVM_aarch64_sme_write<string direction>
                      Arg<SVEPredicate, "Vector predicate">:$pg,
                      Arg<SVEVector, "Vector operand">:$vector)>;
 
+// Tile slice to vector
+class LLVM_aarch64_sme_read<string direction>
+    : ArmSME_IntrOp<"read." # direction, /*overloadedOperands=*/[],
+                    [AllShapesMatch<["pg", "res"]>],
+                    /*numResults*/1, /*overloadedResults*/[0]>,
+      Arguments<(ins Arg<SVEVector, "Vector operand">:$vector,
+                     Arg<SVEPredicate, "Vector predicate">:$pg,
+                     Arg<I32, "Virtual tile ID">,
+                     Arg<I32, "Tile slice">)>;
+
 def LLVM_aarch64_sme_write_horiz : LLVM_aarch64_sme_write<"horiz">;
 def LLVM_aarch64_sme_write_vert : LLVM_aarch64_sme_write<"vert">;
 
+def LLVM_aarch64_sme_read_horiz : LLVM_aarch64_sme_read<"horiz">;
+def LLVM_aarch64_sme_read_vert : LLVM_aarch64_sme_read<"vert">;
+
 def LLVM_aarch64_sme_za_enable : ArmSME_IntrOp<"za.enable">;
 def LLVM_aarch64_sme_za_disable : ArmSME_IntrOp<"za.disable">;
 
diff --git a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
index 9e8ad48b3c2db94..0941592497beaae 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
@@ -34,6 +34,12 @@ bool isValidSMETileElementType(Type type);
 /// otherwise.
 bool isValidSMETileVectorType(VectorType vType);
 
+/// Extends or truncates `tile`, which should be an `arm_sme::GetTileID` or
+/// `arm_sme::CastVectorToTile` op returning an 8/16/32/64/128-bit scalar
+/// integer, to an i32 that can be passed as the `tile` parameter to the SME
+/// intrinsics. Or returns `tile` if already i32.
+Value castTileIDToI32(Value tile, Location loc, RewriterBase &rewriter);
+
 } // namespace arm_sme
 } // namespace mlir
 
diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
index 4028a7ad0870b51..7c07672ce4a41fa 100644
--- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -190,11 +190,93 @@ struct TileStoreOpConversion : public OpRewritePattern<arm_sme::TileStoreOp> {
   }
 };
 
+/// Lowers `vector.print` of a tile into a loop over the rows of the tile,
+/// extracting them via a MOVA, then printing with a 1D `vector.print`.
+///
+///  BEFORE:
+///  ```mlir
+///  vector.print %tile : vector<[4]x[4]xf32>
+///  ```
+///  AFTER:
+///  ```mlir
+///  %c0 = arith.constant 0 : index
+///  %c1 = arith.constant 1 : index
+///  %c4 = arith.constant 4 : index
+///  %ptrue = arith.constant dense<true> : vector<[4]xi1>
+///  %tile_id = arm_sme.cast_vector_to_tile %tile : vector<[4]x[4]xf32> to i32
+///  %vscale = vector.vscale
+///  %svl_s = arith.muli %c4, %vscale : index
+///  %cst = arith.constant dense<0.000000e+00> : vector<[4]xf32>
+///  scf.for %i = %c0 to %svl_s step %c1 {
+///    %slice_idx = arith.index_cast %i : index to i32
+///    %tile_slice = "arm_sme.intr.read.horiz"
+///        (%cst, %ptrue, %tile_id, %slice_idx)
+///      : (vector<[4]xf32>, vector<[4]xi1>, i32, i32) -> vector<[4]xf32>
+///    vector.print %tile_slice : vector<[4]xf32>
+///  }
+///  ```
+struct TileVectorPrintOpConversion : public OpRewritePattern<vector::PrintOp> {
+  using OpRewritePattern<vector::PrintOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::PrintOp printOp,
+                                PatternRewriter &rewriter) const override {
+    if (!printOp.getSource())
+      return failure();
+
+    VectorType vectorType = dyn_cast<VectorType>(printOp.getPrintType());
+    if (!vectorType || !arm_sme::isValidSMETileVectorType(vectorType))
+      return failure();
+
+    auto loc = printOp.getLoc();
+
+    // Create an 'all true' predicate for each tile row.
+    auto predicateType =
+        VectorType::get(vectorType.getDimSize(1), rewriter.getI1Type(), true);
+    auto rowType = VectorType::get(vectorType.getDimSize(1),
+                                   vectorType.getElementType(), true);
+    auto allTruePredicate = rewriter.create<arith::ConstantOp>(
+        loc, DenseElementsAttr::get(predicateType, true));
+
+    // Cast tile to i32 tile ID.
+    auto tileId =
+        rewriter.create<arm_sme::CastVectorToTile>(loc, printOp.getSource());
+    auto tileIdI32 = castTileIDToI32(tileId, loc, rewriter);
+    // Zero destination/fallback for tile slice extraction.
+    auto zeroVector = rewriter.create<arith::ConstantOp>(
+        loc, rowType, rewriter.getZeroAttr(rowType));
+
+    // Create a loop over the rows of the tile.
+    auto vscale = rewriter.create<vector::VectorScaleOp>(loc);
+    auto minTileRows =
+        rewriter.create<arith::ConstantIndexOp>(loc, vectorType.getDimSize(0));
+    auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    auto upperBound = rewriter.create<arith::MulIOp>(loc, minTileRows, vscale);
+    auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+    auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
+    {
+      // Loop body.
+      rewriter.setInsertionPointToStart(forOp.getBody());
+      // Extract the current row from the tile.
+      auto rowIndex = forOp.getInductionVar();
+      auto rowIndexI32 = rewriter.create<arith::IndexCastOp>(
+          loc, rewriter.getI32Type(), rowIndex);
+      auto tileSlice = rewriter.create<arm_sme::aarch64_sme_read_horiz>(
+          loc, rowType, zeroVector, allTruePredicate, tileIdI32, rowIndexI32);
+      // Print it with a 1D print.
+      rewriter.create<vector::PrintOp>(loc, tileSlice,
+                                       printOp.getPunctuation());
+    }
+
+    rewriter.eraseOp(printOp);
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::populateArmSMEToSCFConversionPatterns(RewritePatternSet &patterns) {
-  patterns.add<TileLoadOpConversion, TileStoreOpConversion>(
-      patterns.getContext());
+  patterns.add<TileLoadOpConversion, TileStoreOpConversion,
+               TileVectorPrintOpConversion>(patterns.getContext());
 }
 
 namespace {
@@ -208,6 +290,12 @@ struct ConvertArmSMEToSCFPass
     target.addLegalDialect<arm_sme::ArmSMEDialect, vector::VectorDialect,
                            arith::ArithDialect, scf::SCFDialect>();
     target.addIllegalOp<arm_sme::TileLoadOp, arm_sme::TileStoreOp>();
+    target.addDynamicallyLegalOp<vector::PrintOp>([](vector::PrintOp op) {
+      if (!op.getSource())
+        return true;
+      VectorType vectorType = dyn_cast<VectorType>(op.getPrintType());
+      return !vectorType || !arm_sme::isValidSMETileVectorType(vectorType);
+    });
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))
       signalPassFailure();
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp b/mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp
index 97c38b546349510..30ca414dde49d92 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp
@@ -52,6 +52,8 @@ using namespace mlir::arm_sme;
 static constexpr char kArmStreamingAttr[] = "arm_streaming";
 static constexpr char kArmLocallyStreamingAttr[] = "arm_locally_streaming";
 static constexpr char kArmZAAttr[] = "arm_za";
+static constexpr char kArmEnableStreamingIgnore[] =
+    "enable_arm_streaming_ignore";
 
 namespace {
 struct EnableArmStreamingPass
@@ -61,7 +63,9 @@ struct EnableArmStreamingPass
     this->enableZA = enableZA;
   }
   void runOnOperation() override {
-    std::string attr;
+    if (getOperation()->getAttr(kArmEnableStreamingIgnore))
+      return;
+    StringRef attr;
     switch (mode) {
     case ArmStreaming::Default:
       attr = kArmStreamingAttr;
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
index 6c8843fbb4546e6..e4d1292358eb6d6 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
@@ -49,23 +49,6 @@ struct DisableZAPattern : public OpRewritePattern<func::ReturnOp> {
   }
 };
 
-/// Extends or truncates `tile`, which should be an `arm_sme::GetTileID` or
-/// `arm_sme::CastVectorToTile` op returning an 8/16/32/64/128-bit scalar
-/// integer, to an i32 that can be passed as the `tile` parameter to the SME
-/// intrinsics. Or returns `tile` if already i32.
-Value castTileIDToI32(Value tile, Location loc,
-                      ConversionPatternRewriter &rewriter) {
-  assert((isa<arm_sme::GetTileID, arm_sme::CastVectorToTile>(
-             tile.getDefiningOp())) &&
-         "expected ArmSME GetTileID or CastVectorToTile op!");
-  unsigned tileElementWidth = tile.getType().getIntOrFloatBitWidth();
-  if (tileElementWidth < 32)
-    return rewriter.create<arith::ExtUIOp>(loc, rewriter.getI32Type(), tile);
-  if (tileElementWidth > 32)
-    return rewriter.create<arith::TruncIOp>(loc, rewriter.getI32Type(), tile);
-  return tile;
-}
-
 /// Lower 'arm_sme.zero' to SME intrinsics.
 ///
 ///  BEFORE:
diff --git a/mlir/lib/Dialect/ArmSME/Utils/Utils.cpp b/mlir/lib/Dialect/ArmSME/Utils/Utils.cpp
index b8a47951cc7bbba..f17077ff8565d59 100644
--- a/mlir/lib/Dialect/ArmSME/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/ArmSME/Utils/Utils.cpp
@@ -12,6 +12,7 @@
 
 #include "mlir/Dialect/ArmSME/Utils/Utils.h"
 
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/ArmSME/IR/ArmSME.h"
 
 using namespace mlir;
@@ -42,3 +43,16 @@ bool mlir::arm_sme::isValidSMETileVectorType(VectorType vType) {
 
   return true;
 }
+
+Value mlir::arm_sme::castTileIDToI32(Value tile, Location loc,
+                                     RewriterBase &rewriter) {
+  assert((isa<arm_sme::GetTileID, arm_sme::CastVectorToTile>(
+             tile.getDefiningOp())) &&
+         "expected ArmSME GetTileID or CastVectorToTile op!");
+  unsigned tileElementWidth = tile.getType().getIntOrFloatBitWidth();
+  if (tileElementWidth < 32)
+    return rewriter.create<arith::ExtUIOp>(loc, rewriter.getI32Type(), tile);
+  if (tileElementWidth > 32)
+    return rewriter.create<arith::TruncIOp>(loc, rewriter.getI32Type(), tile);
+  return tile;
+}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir
index 00f1f6fd3fa8e19..4265ca0f599281c 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir
@@ -16,7 +16,7 @@
 
 llvm.func @printCString(!llvm.ptr<i8>)
 
-func.func @printTileBegin() {
+func.func @printTileBegin() attributes { enable_arm_streaming_ignore } {
   %0 = llvm.mlir.addressof @str_tile_begin : !llvm.ptr<array<11 x i8>>
   %1 = llvm.mlir.constant(0 : index) : i64
   %2 = llvm.getelementptr %0[%1, %1]
@@ -25,7 +25,7 @@ func.func @printTileBegin() {
   return
 }
 
-func.func @printTileEnd() {
+func.func @printTileEnd() attributes { enable_arm_streaming_ignore } {
   %0 = llvm.mlir.addressof @str_tile_end : !llvm.ptr<array<9 x i8>>
   %1 = llvm.mlir.constant(0 : index) : i64
   %2 = llvm.getelementptr %0[%1, %1]
@@ -41,20 +41,8 @@ func.func @test_outerproduct_no_accumulator_4x4xf32() {
   %vector = arith.sitofp %vector_i32 : vector<[4]xi32> to vector<[4]xf32>
   %tile = vector.outerproduct %vector, %vector : vector<[4]xf32>, vector<[4]xf32>
 
-  // Calculate the size of a 32-bit tile, e.g. ZA{n}.s.
-  %vscale = vector.vscale
-  %min_elts_s = arith.constant 4 : index
-  %svl_s = arith.muli %min_elts_s, %vscale : index
-  %za_s_size = arith.muli %svl_s, %svl_s : index
-
-  // Allocate memory.
-  %mem = memref.alloca(%za_s_size) : memref<?xf32>
-
-  // Store the tile to memory.
-  vector.store %tile, %mem[%c0] : memref<?xf32>, vector<[4]x[4]xf32>
-
-  // Reload and print. The smallest SVL is 128-bits so the tile will be at
-  // least 4x4xf32.
+  // Print the tile. The smallest SVL is 128-bits so the tile will be at least
+  // 4x4xf32.
   //
   // WITHOUT-ACC:      TILE BEGIN
   // WITHOUT-ACC-NEXT: ( 0, 0, 0, 0
@@ -63,10 +51,7 @@ func.func @test_outerproduct_no_accumulator_4x4xf32() {
   // WITHOUT-ACC-NEXT: ( 0, 3, 6, 9
   // WITHOUT-ACC:      TILE END
   func.call @printTileBegin() : () -> ()
-  scf.for %i = %c0 to %za_s_size step %svl_s {
-    %tileslice = vector.load %mem[%i] : memref<?xf32>, vector<[4]xf32>
-    vector.print %tileslice : vector<[4]xf32>
-  }
+  vector.print %tile : vector<[4]x[4]xf32>
   func.call @printTileEnd() : () -> ()
 
   return
@@ -81,20 +66,8 @@ func.func @test_outerproduct_with_accumulator_4x4xf32() {
   %vector = arith.sitofp %vector_i32 : vector<[4]xi32> to vector<[4]xf32>
   %tile = vector.outerproduct %vector, %vector, %acc : vector<[4]xf32>, vector<[4]xf32>
 
-  // Calculate the size of a 32-bit tile, e.g. ZA{n}.s.
-  %vscale = vector.vscale
-  %min_elts_s = arith.constant 4 : index
-  %svl_s = arith.muli %min_elts_s, %vscale : index
-  %za_s_size = arith.muli %svl_s, %svl_s : index
-
-  // Allocate memory.
-  %mem = memref.alloca(%za_s_size) : memref<?xf32>
-
-  // Store the tile to memory.
-  vector.store %tile, %mem[%c0] : memref<?xf32>, vector<[4]x[4]xf32>
-
-  // Reload and print. The smallest SVL is 128-bits so the tile will be at
-  // least 4x4xf32.
+  // Print the tile. The smallest SVL is 128-bits so the tile will be at least
+  // 4x4xf32.
   //
   // WITH-ACC:      TILE BEGIN
   // WITH-ACC-NEXT: ( 10, 10, 10, 10
@@ -103,10 +76,7 @@ func.func @test_outerproduct_with_accumulator_4x4xf32() {
   // WITH-ACC-NEXT: ( 10, 13, 16, 19
   // WITH-ACC:      TILE END
   func.call @printTileBegin() : () -> ()
-  scf.for %i = %c0 to %za_s_size step %svl_s {
-    %tileslice = vector.load %mem[%i] : memref<?xf32>, vector<[4]xf32>
-    vector.print %tileslice : vector<[4]xf32>
-  }
+  vector.print %tile : vector<[4]x[4]xf32>
   func.call @printTileEnd() : () -> ()
 
   return
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir
index 2c2a06fa8db26e1..cb2c6b98a4eef3a 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir
@@ -13,7 +13,7 @@
 
 llvm.func @printCString(!llvm.ptr<i8>)
 
-func.func @printTileBegin() {
+func.func @printTileBegin() attributes { enable_arm_streaming_ignore } {
   %0 = llvm.mlir.addressof @str_tile_begin : !llvm.ptr<array<11 x i8>>
   %1 = llvm.mlir.constant(0 : index) : i64
   %2 = llvm.getelementptr %0[%1, %1]
@@ -22,7 +22,7 @@ func.func @printTileBegin() {
   return
 }
 
-func.func @printTileEnd() {
+func.func @printTileEnd() attributes { enable_arm_streaming_ignore } {
   %0 = llvm.mlir.addressof @str_tile_end : !llvm.ptr<array<9 x i8>>
   %1 = llvm.mlir.constant(0 : index) : i64
   %2 = llvm.getelementptr %0[%1, %1]
@@ -32,7 +32,6 @@ func.func @printTileEnd() {
 }
 
 func.func @test_outerproduct_with_accumulator_2x2xf64() {
-  %c0 = arith.constant 0 : index
   %f1 = arith.constant 1.0 : f64
   %f2 = arith.constant 2.0 : f64
   %f10 = arith.constant 10.0 : f64
@@ -44,30 +43,15 @@ func.func @test_outerproduct_with_accumulator_2x2xf64() {
 
   %tile = vector.outerproduct %a, %b, %c : vector<[2]xf64>, vector<[2]xf64>
 
-  // Calculate the size of a 64-bit tile, e.g. ZA{n}.d.
-  %vscale = vector.vscale
-  %min_elts_d = arith.constant 2 : index
-  %svl_d = arith.muli %min_elts_d, %vscale : index
-  %za_d_size = arith.muli %svl_d, %svl_d : index
-
-  // Allocate memory.
-  %mem = memref.alloca(%za_d_size) : memref<?xf64>
-
-  // Store the tile to memory.
-  vector.store %tile, %mem[%c0] : memref<?xf64>, vector<[2]x[2]xf64>
-
-  // Reload and print. The smallest SVL is 128-bits so the tile will be at
-  // least 2x2xf64.
+  // Print the tile. The smallest SVL is 128-bits so the tile will be at least
+  // 2x2xf64.
   //
   // CHECK:      TILE BEGIN
   // CHECK-NEXT: ( 12, 12
   // CHECK-NEXT: ( 12, 12
   // CHECK:      TILE END
   func.call @printTileBegin() : () -> ()
-  scf.for %i = %c0 to %za_d_size step %svl_d {
-    %tileslice = vector.load %mem[%i] : memref<?xf64>, vector<[2]xf64>
-    vector.print %tileslice : vector<[2]xf64>
-  }
+  vector.print %tile : vector<[2]x[2]xf64>
   func.call @printTileEnd() : () -> ()
 
   return
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/tile_fill.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/tile_fill.mlir
index a407b13b541839f..1eaabbad68f3af8 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/tile_fill.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/tile_fill.mlir
@@ -13,7 +13,7 @@
 
 llvm.func @printCString(!llvm.ptr<i8>)
 
-func.func @printTileBegin() {
+func.func @printTileBegin() attributes { enable_arm_streaming_ignore } {
   %0 = llvm.mlir.addressof @str_tile_begin : !llvm.ptr<array<11 x i8>>
   %1 = llvm.mlir.constant(0 : index) : i64
   %2 = llvm.getelementptr %0[%1, %1]
@@ -22,7 +22,7 @@ func.func @printTileBegin() {
   return
 }
 
-func.func @printTileEnd() {
+func.func @printTileEnd() attributes { enable_arm_streaming_ignore } {
   %0 = llvm.mlir.addressof @str_tile_end : !llvm.ptr<array<9 x i8>>
   %1 = llvm.mlir.constant(0 : index) : i64
   %2 = llvm.getelementptr %0[%1, %1]
@@ -32,29 +32,15 @@ func.func @printTileEnd() {
 }
 
 func.func @entry() -> i32 {
-  %c0 = arith.constant 0 : index
-  %c1_index = arith.constant 1 : index
-
-  %min_elts_s = arith.constant 4 : index
-  %vscale = vector.vscale
-
-  // "svl" refers to the Streaming Vector Length and "svl_s" the number of
-  // 32-bit elements in a vector of SVL bits.
-  %svl_s = arith.muli %min_elts_s, %vscale : index
-
-  // Allocate memory.
-  %tilesize = arith.muli %svl_s, %svl_s : index
-  %mem = memref.alloca(%tilesize) : memref<?xi32>
-
   // Fill a tile with '123'. This will get lowered to a 1-d vector splat of
   // '123' and a loop that writes this vector to each tile slice in the ZA
   // tile.
   %tile = arith.constant dense<123> : vector<[4]x[4]xi32>
 
-  // Store tile to memory so it can be dumped.
-  vector.store %tile, %mem[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
-
-  // Dump "mem". The smallest SVL is 128-bits so the tile will be at least
+  func.call @printTileBegin() : () -> ()
+  vector.print %tile : vector<[4]x[4]xi32>
+  func.call @printTileEnd() : () -> ()
+  // Print the tile. The smallest SVL is 128-bits so the tile will be at least
   // 4x4xi32.
   //
   // CHECK:      TILE BEGIN
@@ -63,12 +49,6 @@ func.func @entry() -> i32 {
   // CHECK-NEXT: ( 123, 123, 123, 123
   // CHECK-NEXT: ( 123, 123, 123, 123
   // CHECK:      TILE END
-  func.call @printTileBegin() : () ...
[truncated]
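
The `castTileIDToI32` helper that the diff moves into Utils.cpp picks between a zero-extend, a truncate, or no cast based on the bit width of the tile ID. A minimal standalone sketch of that decision follows. It is an illustration, not the MLIR helper: `Cast`, `castNeededForI32`, and `tileIdToI32` are hypothetical names, and the value model only covers widths up to 64 bits, whereas the real helper also accepts 128-bit IDs.

```cpp
#include <cstdint>

// Which cast is needed to turn a tile ID of the given width into an i32.
enum class Cast { None, ZeroExtend, Truncate };

// Mirrors the width comparison in castTileIDToI32: narrower IDs are
// zero-extended, wider IDs are truncated, and i32 IDs pass through unchanged.
Cast castNeededForI32(unsigned tileIdBitWidth) {
  if (tileIdBitWidth < 32)
    return Cast::ZeroExtend;
  if (tileIdBitWidth > 32)
    return Cast::Truncate;
  return Cast::None;
}

// Applies the same conversion to a concrete value (widths up to 64 bits only).
std::uint32_t tileIdToI32(std::uint64_t raw, unsigned bitWidth) {
  // Keep only the low `bitWidth` bits, i.e. the value as the narrower type holds it.
  const std::uint64_t value =
      bitWidth >= 64 ? raw : raw & ((std::uint64_t{1} << bitWidth) - 1);
  // Zero-extension leaves the value unchanged; truncation keeps the low 32 bits.
  return static_cast<std::uint32_t>(value);
}
```

For example, `castNeededForI32(8)` is `Cast::ZeroExtend` and `castNeededForI32(64)` is `Cast::Truncate`, matching the two branches in the helper.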

@c-rhodes (Collaborator) commented

This adds a custom lowering for SME that loops over each row of the tile, extracting it via an SME MOVA, then printing with a normal 1D vector.print.

Thanks Ben, this is really great! It cleans up the tests nicely.

This patch is "done" but needs to be split into individual tested patches.

Are you planning to break this up into separate PRs?

* Adds a new "enable_arm_streaming_ignore" attribute to disable the "enable_arm_streaming" pass for certain functions in a file
  
  * This prevents ABI issues when calling helpers, as MLIR SME currently does not implement the ABI conventions needed for nested streaming mode calls

I like this. Long term, the plan is to leverage the backend ABI support via attributes so we don't have to implement the ABI in MLIR. That depends on having the ABI support routines: D154045 adds these to compiler-rt, but it hasn't landed yet, and it would also make MLIR depend on compiler-rt, which needs some consideration. An attribute to disable streaming mode is a nice stopgap!

@MacDue (Member, Author) commented Sep 19, 2023

Are you planning to break this up into separate PRs?

I plan to split the intrinsic and the enable_arm_streaming_ignore attribute out into two separate PRs (with the tests you suggest). Once those are merged, I'll rebase and undraft this PR.

@dcaballe (Contributor) left a comment

Awesome, thanks!

@c-rhodes (Collaborator) left a comment

LGTM cheers

@c-rhodes (Collaborator) commented

is this ready to land?

@MacDue merged commit 174cd61 into llvm:main on Sep 26, 2023

@MacDue (Member, Author) commented Sep 26, 2023

is this ready to land?

I think so :)

legrosbuffle pushed a commit to legrosbuffle/llvm-project that referenced this pull request Sep 29, 2023