-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][ArmSME] Add custom vector.print lowering for SME tiles #66691
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-llvm @llvm/pr-subscribers-mlir-vector ChangesThis adds a custom lowering for SME that loops over each row of the tile, extracting it via an SME MOVA, then printing with a normal 1D vector.print. This patch is "done" but needs to be split into individual tested patches. Currently this:
Patch is 20.27 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/66691.diff 9 Files Affected:
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
index 7f02e723f3d91c2..10fee9251cd3e9e 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
@@ -469,15 +469,16 @@ def MOPVector : ScalableVectorOfLengthAndType<[16, 8, 4, 2],
def LDSTPredicate : ScalableVectorOfLengthAndType<[16, 8, 4, 2, 1], [I1]>;
class ArmSME_IntrOp<string mnemonic, list<int> overloadedOperands = [],
- list<Trait> traits = []>
+ list<Trait> traits = [], int numResults = 0,
+ list<int> overloadedResults = []>
: LLVM_IntrOpBase<
/*Dialect dialect=*/ArmSME_Dialect,
/*string opName=*/"intr." # mnemonic,
/*string enumName=*/"aarch64_sme_" # !subst(".", "_", mnemonic),
- /*list<int> overloadedResults=*/[],
+ /*list<int> overloadedResults=*/overloadedResults,
/*list<int> overloadedOperands=*/overloadedOperands,
/*list<Trait> traits=*/traits,
- /*int numResults=*/0>;
+ /*int numResults=*/numResults>;
// Zero
def LLVM_aarch64_sme_zero : ArmSME_IntrOp<"zero">,
@@ -548,7 +549,7 @@ def LLVM_aarch64_sme_str
Arguments<(ins Arg<I32, "Index">,
Arg<LLVM_AnyPointer, "Store address", [MemWrite]>)>;
-// Vector to tile
+// Vector to tile slice
class LLVM_aarch64_sme_write<string direction>
: ArmSME_IntrOp<"write." # direction, /*overloadedOperands=*/[3],
[AllShapesMatch<["pg", "vector"]>]>,
@@ -557,9 +558,22 @@ class LLVM_aarch64_sme_write<string direction>
Arg<SVEPredicate, "Vector predicate">:$pg,
Arg<SVEVector, "Vector operand">:$vector)>;
+// Tile slice to vector
+class LLVM_aarch64_sme_read<string direction>
+ : ArmSME_IntrOp<"read." # direction, /*overloadedOperands=*/[],
+ [AllShapesMatch<["pg", "res"]>],
+ /*numResults*/1, /*overloadedResults*/[0]>,
+ Arguments<(ins Arg<SVEVector, "Vector operand">:$vector,
+ Arg<SVEPredicate, "Vector predicate">:$pg,
+ Arg<I32, "Virtual tile ID">,
+ Arg<I32, "Tile slice">)>;
+
def LLVM_aarch64_sme_write_horiz : LLVM_aarch64_sme_write<"horiz">;
def LLVM_aarch64_sme_write_vert : LLVM_aarch64_sme_write<"vert">;
+def LLVM_aarch64_sme_read_horiz : LLVM_aarch64_sme_read<"horiz">;
+def LLVM_aarch64_sme_read_vert : LLVM_aarch64_sme_read<"vert">;
+
def LLVM_aarch64_sme_za_enable : ArmSME_IntrOp<"za.enable">;
def LLVM_aarch64_sme_za_disable : ArmSME_IntrOp<"za.disable">;
diff --git a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
index 9e8ad48b3c2db94..0941592497beaae 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Utils/Utils.h
@@ -34,6 +34,12 @@ bool isValidSMETileElementType(Type type);
/// otherwise.
bool isValidSMETileVectorType(VectorType vType);
+/// Extends or truncates `tile`, which should be an `arm_sme::GetTileID` or
+/// `arm_sme::CastVectorToTile` op returning an 8/16/32/64/128-bit scalar
+/// integer, to an i32 that can be passed as the `tile` parameter to the SME
+/// intrinsics. Or returns `tile` if already i32.
+Value castTileIDToI32(Value tile, Location loc, RewriterBase &rewriter);
+
} // namespace arm_sme
} // namespace mlir
diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
index 4028a7ad0870b51..7c07672ce4a41fa 100644
--- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
+++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp
@@ -190,11 +190,93 @@ struct TileStoreOpConversion : public OpRewritePattern<arm_sme::TileStoreOp> {
}
};
+/// Lowers `vector.print` of a tile into a loop over the rows of the tile,
+/// extracting them via a MOVA, then printing with a 1D `vector.print`.
+///
+/// BEFORE:
+/// ```mlir
+/// vector.print %tile : vector<[4]x[4]xf32>
+/// ```
+/// AFTER:
+/// ```mlir
+/// %c0 = arith.constant 0 : index
+/// %c1 = arith.constant 1 : index
+/// %c4 = arith.constant 4 : index
+/// %ptrue = arith.constant dense<true> : vector<[4]xi1>
+/// %tile_id = arm_sme.cast_vector_to_tile %tile : vector<[4]x[4]xf32> to i32
+/// %vscale = vector.vscale
+/// %svl_s = arith.muli %c4, %vscale : index
+/// %cst = arith.constant dense<0.000000e+00> : vector<[4]xf32>
+/// scf.for %i = %c0 to %svl_s step %c1 {
+/// %slice_idx = arith.index_cast %i : index to i32
+/// %tile_slice = "arm_sme.intr.read.horiz"
+/// (%cst, %ptrue, %tile_id, %slice_idx)
+/// : (vector<[4]xf32>, vector<[4]xi1>, i32, i32) -> vector<[4]xf32>
+/// vector.print %tile_slice : vector<[4]xf32>
+/// }
+/// ```
+struct TileVectorPrintOpConversion : public OpRewritePattern<vector::PrintOp> {
+ using OpRewritePattern<vector::PrintOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::PrintOp printOp,
+ PatternRewriter &rewriter) const override {
+ if (!printOp.getSource())
+ return failure();
+
+ VectorType vectorType = dyn_cast<VectorType>(printOp.getPrintType());
+ if (!vectorType || !arm_sme::isValidSMETileVectorType(vectorType))
+ return failure();
+
+ auto loc = printOp.getLoc();
+
+ // Create an 'all true' predicate for each tile row.
+ auto predicateType =
+ VectorType::get(vectorType.getDimSize(1), rewriter.getI1Type(), true);
+ auto rowType = VectorType::get(vectorType.getDimSize(1),
+ vectorType.getElementType(), true);
+ auto allTruePredicate = rewriter.create<arith::ConstantOp>(
+ loc, DenseElementsAttr::get(predicateType, true));
+
+ // Cast tile to i32 tile ID.
+ auto tileId =
+ rewriter.create<arm_sme::CastVectorToTile>(loc, printOp.getSource());
+ auto tileIdI32 = castTileIDToI32(tileId, loc, rewriter);
+ // Zero destination/fallback for tile slice extraction.
+ auto zeroVector = rewriter.create<arith::ConstantOp>(
+ loc, rowType, rewriter.getZeroAttr(rowType));
+
+ // Create a loop over the rows of the tile.
+ auto vscale = rewriter.create<vector::VectorScaleOp>(loc);
+ auto minTileRows =
+ rewriter.create<arith::ConstantIndexOp>(loc, vectorType.getDimSize(0));
+ auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ auto upperBound = rewriter.create<arith::MulIOp>(loc, minTileRows, vscale);
+ auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+ auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
+ {
+ // Loop body.
+ rewriter.setInsertionPointToStart(forOp.getBody());
+ // Extract the current row from the tile.
+ auto rowIndex = forOp.getInductionVar();
+ auto rowIndexI32 = rewriter.create<arith::IndexCastOp>(
+ loc, rewriter.getI32Type(), rowIndex);
+ auto tileSlice = rewriter.create<arm_sme::aarch64_sme_read_horiz>(
+ loc, rowType, zeroVector, allTruePredicate, tileIdI32, rowIndexI32);
+ // Print it with a 1D print.
+ rewriter.create<vector::PrintOp>(loc, tileSlice,
+ printOp.getPunctuation());
+ }
+
+ rewriter.eraseOp(printOp);
+ return success();
+ }
+};
+
} // namespace
void mlir::populateArmSMEToSCFConversionPatterns(RewritePatternSet &patterns) {
- patterns.add<TileLoadOpConversion, TileStoreOpConversion>(
- patterns.getContext());
+ patterns.add<TileLoadOpConversion, TileStoreOpConversion,
+ TileVectorPrintOpConversion>(patterns.getContext());
}
namespace {
@@ -208,6 +290,12 @@ struct ConvertArmSMEToSCFPass
target.addLegalDialect<arm_sme::ArmSMEDialect, vector::VectorDialect,
arith::ArithDialect, scf::SCFDialect>();
target.addIllegalOp<arm_sme::TileLoadOp, arm_sme::TileStoreOp>();
+ target.addDynamicallyLegalOp<vector::PrintOp>([](vector::PrintOp op) {
+ if (!op.getSource())
+ return true;
+ VectorType vectorType = dyn_cast<VectorType>(op.getPrintType());
+ return !vectorType || !arm_sme::isValidSMETileVectorType(vectorType);
+ });
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
signalPassFailure();
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp b/mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp
index 97c38b546349510..30ca414dde49d92 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/EnableArmStreaming.cpp
@@ -52,6 +52,8 @@ using namespace mlir::arm_sme;
static constexpr char kArmStreamingAttr[] = "arm_streaming";
static constexpr char kArmLocallyStreamingAttr[] = "arm_locally_streaming";
static constexpr char kArmZAAttr[] = "arm_za";
+static constexpr char kArmEnableStreamingIgnore[] =
+ "enable_arm_streaming_ignore";
namespace {
struct EnableArmStreamingPass
@@ -61,7 +63,9 @@ struct EnableArmStreamingPass
this->enableZA = enableZA;
}
void runOnOperation() override {
- std::string attr;
+ if (getOperation()->getAttr(kArmEnableStreamingIgnore))
+ return;
+ StringRef attr;
switch (mode) {
case ArmStreaming::Default:
attr = kArmStreamingAttr;
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
index 6c8843fbb4546e6..e4d1292358eb6d6 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
@@ -49,23 +49,6 @@ struct DisableZAPattern : public OpRewritePattern<func::ReturnOp> {
}
};
-/// Extends or truncates `tile`, which should be an `arm_sme::GetTileID` or
-/// `arm_sme::CastVectorToTile` op returning an 8/16/32/64/128-bit scalar
-/// integer, to an i32 that can be passed as the `tile` parameter to the SME
-/// intrinsics. Or returns `tile` if already i32.
-Value castTileIDToI32(Value tile, Location loc,
- ConversionPatternRewriter &rewriter) {
- assert((isa<arm_sme::GetTileID, arm_sme::CastVectorToTile>(
- tile.getDefiningOp())) &&
- "expected ArmSME GetTileID or CastVectorToTile op!");
- unsigned tileElementWidth = tile.getType().getIntOrFloatBitWidth();
- if (tileElementWidth < 32)
- return rewriter.create<arith::ExtUIOp>(loc, rewriter.getI32Type(), tile);
- if (tileElementWidth > 32)
- return rewriter.create<arith::TruncIOp>(loc, rewriter.getI32Type(), tile);
- return tile;
-}
-
/// Lower 'arm_sme.zero' to SME intrinsics.
///
/// BEFORE:
diff --git a/mlir/lib/Dialect/ArmSME/Utils/Utils.cpp b/mlir/lib/Dialect/ArmSME/Utils/Utils.cpp
index b8a47951cc7bbba..f17077ff8565d59 100644
--- a/mlir/lib/Dialect/ArmSME/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/ArmSME/Utils/Utils.cpp
@@ -12,6 +12,7 @@
#include "mlir/Dialect/ArmSME/Utils/Utils.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
using namespace mlir;
@@ -42,3 +43,16 @@ bool mlir::arm_sme::isValidSMETileVectorType(VectorType vType) {
return true;
}
+
+Value mlir::arm_sme::castTileIDToI32(Value tile, Location loc,
+ RewriterBase &rewriter) {
+ assert((isa<arm_sme::GetTileID, arm_sme::CastVectorToTile>(
+ tile.getDefiningOp())) &&
+ "expected ArmSME GetTileID or CastVectorToTile op!");
+ unsigned tileElementWidth = tile.getType().getIntOrFloatBitWidth();
+ if (tileElementWidth < 32)
+ return rewriter.create<arith::ExtUIOp>(loc, rewriter.getI32Type(), tile);
+ if (tileElementWidth > 32)
+ return rewriter.create<arith::TruncIOp>(loc, rewriter.getI32Type(), tile);
+ return tile;
+}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir
index 00f1f6fd3fa8e19..4265ca0f599281c 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f32.mlir
@@ -16,7 +16,7 @@
llvm.func @printCString(!llvm.ptr<i8>)
-func.func @printTileBegin() {
+func.func @printTileBegin() attributes { enable_arm_streaming_ignore } {
%0 = llvm.mlir.addressof @str_tile_begin : !llvm.ptr<array<11 x i8>>
%1 = llvm.mlir.constant(0 : index) : i64
%2 = llvm.getelementptr %0[%1, %1]
@@ -25,7 +25,7 @@ func.func @printTileBegin() {
return
}
-func.func @printTileEnd() {
+func.func @printTileEnd() attributes { enable_arm_streaming_ignore } {
%0 = llvm.mlir.addressof @str_tile_end : !llvm.ptr<array<9 x i8>>
%1 = llvm.mlir.constant(0 : index) : i64
%2 = llvm.getelementptr %0[%1, %1]
@@ -41,20 +41,8 @@ func.func @test_outerproduct_no_accumulator_4x4xf32() {
%vector = arith.sitofp %vector_i32 : vector<[4]xi32> to vector<[4]xf32>
%tile = vector.outerproduct %vector, %vector : vector<[4]xf32>, vector<[4]xf32>
- // Calculate the size of a 32-bit tile, e.g. ZA{n}.s.
- %vscale = vector.vscale
- %min_elts_s = arith.constant 4 : index
- %svl_s = arith.muli %min_elts_s, %vscale : index
- %za_s_size = arith.muli %svl_s, %svl_s : index
-
- // Allocate memory.
- %mem = memref.alloca(%za_s_size) : memref<?xf32>
-
- // Store the tile to memory.
- vector.store %tile, %mem[%c0] : memref<?xf32>, vector<[4]x[4]xf32>
-
- // Reload and print. The smallest SVL is 128-bits so the tile will be at
- // least 4x4xf32.
+ // Print the tile. The smallest SVL is 128-bits so the tile will be at least
+ // 4x4xf32.
//
// WITHOUT-ACC: TILE BEGIN
// WITHOUT-ACC-NEXT: ( 0, 0, 0, 0
@@ -63,10 +51,7 @@ func.func @test_outerproduct_no_accumulator_4x4xf32() {
// WITHOUT-ACC-NEXT: ( 0, 3, 6, 9
// WITHOUT-ACC: TILE END
func.call @printTileBegin() : () -> ()
- scf.for %i = %c0 to %za_s_size step %svl_s {
- %tileslice = vector.load %mem[%i] : memref<?xf32>, vector<[4]xf32>
- vector.print %tileslice : vector<[4]xf32>
- }
+ vector.print %tile : vector<[4]x[4]xf32>
func.call @printTileEnd() : () -> ()
return
@@ -81,20 +66,8 @@ func.func @test_outerproduct_with_accumulator_4x4xf32() {
%vector = arith.sitofp %vector_i32 : vector<[4]xi32> to vector<[4]xf32>
%tile = vector.outerproduct %vector, %vector, %acc : vector<[4]xf32>, vector<[4]xf32>
- // Calculate the size of a 32-bit tile, e.g. ZA{n}.s.
- %vscale = vector.vscale
- %min_elts_s = arith.constant 4 : index
- %svl_s = arith.muli %min_elts_s, %vscale : index
- %za_s_size = arith.muli %svl_s, %svl_s : index
-
- // Allocate memory.
- %mem = memref.alloca(%za_s_size) : memref<?xf32>
-
- // Store the tile to memory.
- vector.store %tile, %mem[%c0] : memref<?xf32>, vector<[4]x[4]xf32>
-
- // Reload and print. The smallest SVL is 128-bits so the tile will be at
- // least 4x4xf32.
+ // Print the tile. The smallest SVL is 128-bits so the tile will be at least
+ // 4x4xf32.
//
// WITH-ACC: TILE BEGIN
// WITH-ACC-NEXT: ( 10, 10, 10, 10
@@ -103,10 +76,7 @@ func.func @test_outerproduct_with_accumulator_4x4xf32() {
// WITH-ACC-NEXT: ( 10, 13, 16, 19
// WITH-ACC: TILE END
func.call @printTileBegin() : () -> ()
- scf.for %i = %c0 to %za_s_size step %svl_s {
- %tileslice = vector.load %mem[%i] : memref<?xf32>, vector<[4]xf32>
- vector.print %tileslice : vector<[4]xf32>
- }
+ vector.print %tile : vector<[4]x[4]xf32>
func.call @printTileEnd() : () -> ()
return
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir
index 2c2a06fa8db26e1..cb2c6b98a4eef3a 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f64.mlir
@@ -13,7 +13,7 @@
llvm.func @printCString(!llvm.ptr<i8>)
-func.func @printTileBegin() {
+func.func @printTileBegin() attributes { enable_arm_streaming_ignore } {
%0 = llvm.mlir.addressof @str_tile_begin : !llvm.ptr<array<11 x i8>>
%1 = llvm.mlir.constant(0 : index) : i64
%2 = llvm.getelementptr %0[%1, %1]
@@ -22,7 +22,7 @@ func.func @printTileBegin() {
return
}
-func.func @printTileEnd() {
+func.func @printTileEnd() attributes { enable_arm_streaming_ignore } {
%0 = llvm.mlir.addressof @str_tile_end : !llvm.ptr<array<9 x i8>>
%1 = llvm.mlir.constant(0 : index) : i64
%2 = llvm.getelementptr %0[%1, %1]
@@ -32,7 +32,6 @@ func.func @printTileEnd() {
}
func.func @test_outerproduct_with_accumulator_2x2xf64() {
- %c0 = arith.constant 0 : index
%f1 = arith.constant 1.0 : f64
%f2 = arith.constant 2.0 : f64
%f10 = arith.constant 10.0 : f64
@@ -44,30 +43,15 @@ func.func @test_outerproduct_with_accumulator_2x2xf64() {
%tile = vector.outerproduct %a, %b, %c : vector<[2]xf64>, vector<[2]xf64>
- // Calculate the size of a 64-bit tile, e.g. ZA{n}.d.
- %vscale = vector.vscale
- %min_elts_d = arith.constant 2 : index
- %svl_d = arith.muli %min_elts_d, %vscale : index
- %za_d_size = arith.muli %svl_d, %svl_d : index
-
- // Allocate memory.
- %mem = memref.alloca(%za_d_size) : memref<?xf64>
-
- // Store the tile to memory.
- vector.store %tile, %mem[%c0] : memref<?xf64>, vector<[2]x[2]xf64>
-
- // Reload and print. The smallest SVL is 128-bits so the tile will be at
- // least 2x2xf64.
+ // Print the tile. The smallest SVL is 128-bits so the tile will be at least
+ // 2x2xf64.
//
// CHECK: TILE BEGIN
// CHECK-NEXT: ( 12, 12
// CHECK-NEXT: ( 12, 12
// CHECK: TILE END
func.call @printTileBegin() : () -> ()
- scf.for %i = %c0 to %za_d_size step %svl_d {
- %tileslice = vector.load %mem[%i] : memref<?xf64>, vector<[2]xf64>
- vector.print %tileslice : vector<[2]xf64>
- }
+ vector.print %tile : vector<[2]x[2]xf64>
func.call @printTileEnd() : () -> ()
return
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/tile_fill.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/tile_fill.mlir
index a407b13b541839f..1eaabbad68f3af8 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/tile_fill.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/tile_fill.mlir
@@ -13,7 +13,7 @@
llvm.func @printCString(!llvm.ptr<i8>)
-func.func @printTileBegin() {
+func.func @printTileBegin() attributes { enable_arm_streaming_ignore } {
%0 = llvm.mlir.addressof @str_tile_begin : !llvm.ptr<array<11 x i8>>
%1 = llvm.mlir.constant(0 : index) : i64
%2 = llvm.getelementptr %0[%1, %1]
@@ -22,7 +22,7 @@ func.func @printTileBegin() {
return
}
-func.func @printTileEnd() {
+func.func @printTileEnd() attributes { enable_arm_streaming_ignore } {
%0 = llvm.mlir.addressof @str_tile_end : !llvm.ptr<array<9 x i8>>
%1 = llvm.mlir.constant(0 : index) : i64
%2 = llvm.getelementptr %0[%1, %1]
@@ -32,29 +32,15 @@ func.func @printTileEnd() {
}
func.func @entry() -> i32 {
- %c0 = arith.constant 0 : index
- %c1_index = arith.constant 1 : index
-
- %min_elts_s = arith.constant 4 : index
- %vscale = vector.vscale
-
- // "svl" refers to the Streaming Vector Length and "svl_s" the number of
- // 32-bit elements in a vector of SVL bits.
- %svl_s = arith.muli %min_elts_s, %vscale : index
-
- // Allocate memory.
- %tilesize = arith.muli %svl_s, %svl_s : index
- %mem = memref.alloca(%tilesize) : memref<?xi32>
-
// Fill a tile with '123'. This will get lowered to a 1-d vector splat of
// '123' and a loop that writes this vector to each tile slice in the ZA
// tile.
%tile = arith.constant dense<123> : vector<[4]x[4]xi32>
- // Store tile to memory so it can be dumped.
- vector.store %tile, %mem[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
-
- // Dump "mem". The smallest SVL is 128-bits so the tile will be at least
+ func.call @printTileBegin() : () -> ()
+ vector.print %tile : vector<[4]x[4]xi32>
+ func.call @printTileEnd() : () -> ()
+ // Print the tile. The smallest SVL is 128-bits so the tile will be at least
// 4x4xi32.
//
// CHECK: TILE BEGIN
@@ -63,12 +49,6 @@ func.func @entry() -> i32 {
// CHECK-NEXT: ( 123, 123, 123, 123
// CHECK-NEXT: ( 123, 123, 123, 123
// CHECK: TILE END
- func.call @printTileBegin() : () ...
[truncated]
|
Thanks Ben this is really great! Cleans up the tests nicely.
Are you planning to break this up into separate PRs?
I like this. Long term the plan is to leverage the backend ABI support via attributes so we don't have to implement the ABI in MLIR, but that will depend on having the ABI support routines, D154045 adds these to compiler-rt but it hasn't landed yet, and that would also introduce a compiler-rt dependency on MLIR which needs some consideration. An attribute to disable streaming-mode is a nice stop gap! |
I plan on splitting adding the intrinsic and the |
069e578
to
04c77e6
Compare
04c77e6
to
91daa66
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Awesome, thanks!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM cheers
This adds a custom lowering for SME that loops over each row of the tile, extracting it via an SME MOVA, then printing with a normal 1D vector.print. This makes writing SME integration tests easier and less verbose.
91daa66
to
387684b
Compare
is this ready to land? |
I think so :) |
…6691) This adds a custom lowering for SME that loops over each row of the tile, extracting it via an SME MOVA, then printing with a normal 1D vector.print. This makes writing SME integration tests easier and less verbose. Depends on: llvm#66910, llvm#66911
This adds a custom lowering for SME that loops over each row of the tile, extracting it via an SME MOVA, then printing with a normal 1D vector.print.
This makes writing SME integration tests easier and less verbose.
Depends on: #66910, #66911