[mlir][ArmSME] Merge consecutive arm_sme.intr.zero ops #106215
Conversation
This merges consecutive SME zero intrinsics within a basic block, which avoids the backend eventually emitting multiple zero instructions when it could just use one. Note: This kind of peephole optimization could be implemented in the backend too.
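For illustration, here is a rough sketch of the effect at the intrinsic level, using the 32-bit tile masks that appear in the tests below (17 | 34 | 68 | 136 = 255). Before the merge, each arm_sme.zero lowers to its own intrinsic:

  "arm_sme.intr.zero"() <{tile_mask = 17 : i32}> : () -> ()
  "arm_sme.intr.zero"() <{tile_mask = 34 : i32}> : () -> ()
  "arm_sme.intr.zero"() <{tile_mask = 68 : i32}> : () -> ()
  "arm_sme.intr.zero"() <{tile_mask = 136 : i32}> : () -> ()

After the merge, a single intrinsic zeros all of the requested tiles:

  "arm_sme.intr.zero"() <{tile_mask = 255 : i32}> : () -> ()

(See the merge_consecutive_tile_zero_ops test in the diff for the exact expected output.)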
@llvm/pr-subscribers-mlir

Author: Benjamin Maxwell (MacDue)

Changes

This merges consecutive SME zero intrinsics within a basic block, which avoids the backend eventually emitting multiple zero instructions when it could just use one. Note: This kind of peephole optimization could be implemented in the backend too.

Full diff: https://github.com/llvm/llvm-project/pull/106215.diff

3 Files Affected:
diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
index 4d96091a637cf0..8cdf83e431b69b 100644
--- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
+++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
@@ -25,6 +25,7 @@
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/ScopeExit.h"
namespace mlir {
#define GEN_PASS_DEF_CONVERTARMSMETOLLVM
@@ -481,6 +482,9 @@ struct ZeroOpConversion : public ConvertArmSMEOpToLLVMPattern<arm_sme::ZeroOp> {
loc, rewriter.getI32IntegerAttr(zeroMask));
// Create a placeholder op to preserve dataflow.
+ // Note: Place the `get_tile` op at the start of the block. This ensures
+ // that if there are multiple `zero` ops the intrinsics will be consecutive.
+ rewriter.setInsertionPointToStart(zero->getBlock());
rewriter.replaceOpWithNewOp<arm_sme::GetTileOp>(zero, zero.getVectorType());
return success();
@@ -855,6 +859,37 @@ struct StreamingVLOpConversion
}
};
+/// Merges consecutive `arm_sme.intr.zero` operations in a block by bitwise
+/// or-ing the zero masks. Note: In the future the backend _should_ handle this.
+static void mergeConsecutiveTileZerosInBlock(Block *block) {
+ uint32_t mergedZeroMask = 0;
+ SmallVector<arm_sme::aarch64_sme_zero, 16> zeroOpsToMerge;
+ auto replaceMergedZeroOps = [&] {
+ auto cleanup = llvm::make_scope_exit([&] {
+ mergedZeroMask = 0;
+ zeroOpsToMerge.clear();
+ });
+ if (zeroOpsToMerge.size() <= 1)
+ return;
+ IRRewriter rewriter(zeroOpsToMerge.front());
+ rewriter.setInsertionPoint(zeroOpsToMerge.front());
+ rewriter.create<arm_sme::aarch64_sme_zero>(
+ zeroOpsToMerge.front().getLoc(),
+ rewriter.getI32IntegerAttr(mergedZeroMask));
+ for (auto zeroOp : zeroOpsToMerge)
+ rewriter.eraseOp(zeroOp);
+ };
+ for (Operation &op : *block) {
+ if (auto zeroOp = dyn_cast<arm_sme::aarch64_sme_zero>(op)) {
+ mergedZeroMask |= zeroOp.getTileMask();
+ zeroOpsToMerge.push_back(zeroOp);
+ } else {
+ replaceMergedZeroOps();
+ }
+ }
+ replaceMergedZeroOps();
+}
+
} // namespace
namespace {
@@ -879,6 +914,8 @@ struct ConvertArmSMEToLLVMPass
if (failed(applyPartialConversion(function, target, std::move(patterns))))
signalPassFailure();
+ function->walk(mergeConsecutiveTileZerosInBlock);
+
// Walk the function and fail if there are unexpected operations on SME
// tile types after conversion.
function->walk([&](Operation *op) {
diff --git a/mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir b/mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir
index 1dced0fcd18c7a..2a183cb4d056a9 100644
--- a/mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir
+++ b/mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir
@@ -53,6 +53,7 @@
/// These are obviously redundant, but there's no checks to avoid this.
func.func @use_too_many_tiles() {
%0 = arm_sme.zero : vector<[4]x[4]xi32>
+ "test.prevent_zero_merge"() : () -> ()
%1 = arm_sme.zero : vector<[4]x[4]xi32>
// expected-warning @below {{failed to allocate SME virtual tile to operation, tile value will go through memory, expect degraded performance}}
%2 = arm_sme.zero : vector<[8]x[8]xi16>
diff --git a/mlir/test/Dialect/ArmSME/tile-zero-masks.mlir b/mlir/test/Dialect/ArmSME/tile-zero-masks.mlir
index ca339be5fb56f1..6e229b4a7de53a 100644
--- a/mlir/test/Dialect/ArmSME/tile-zero-masks.mlir
+++ b/mlir/test/Dialect/ArmSME/tile-zero-masks.mlir
@@ -19,6 +19,7 @@ func.func @zero_za_b() {
func.func @zero_za_h() {
// CHECK: "arm_sme.intr.zero"() <{tile_mask = 85 : i32}> : () -> ()
%zero_za0h = arm_sme.zero : vector<[8]x[8]xi16>
+ "test.prevent_zero_merge"() : () -> ()
// CHECK: "arm_sme.intr.zero"() <{tile_mask = 170 : i32}> : () -> ()
%zero_za1h = arm_sme.zero : vector<[8]x[8]xf16>
"test.some_use"(%zero_za0h) : (vector<[8]x[8]xi16>) -> ()
@@ -32,10 +33,13 @@ func.func @zero_za_h() {
func.func @zero_za_s() {
// CHECK: arm_sme.intr.zero"() <{tile_mask = 17 : i32}> : () -> ()
%zero_za0s = arm_sme.zero : vector<[4]x[4]xi32>
+ "test.prevent_zero_merge"() : () -> ()
// CHECK: arm_sme.intr.zero"() <{tile_mask = 34 : i32}> : () -> ()
%zero_za1s = arm_sme.zero : vector<[4]x[4]xi32>
+ "test.prevent_zero_merge"() : () -> ()
// CHECK: arm_sme.intr.zero"() <{tile_mask = 68 : i32}> : () -> ()
%zero_za2s = arm_sme.zero : vector<[4]x[4]xi32>
+ "test.prevent_zero_merge"() : () -> ()
// CHECK: arm_sme.intr.zero"() <{tile_mask = 136 : i32}> : () -> ()
%zero_za3s = arm_sme.zero : vector<[4]x[4]xf32>
"test.some_use"(%zero_za0s) : (vector<[4]x[4]xi32>) -> ()
@@ -51,18 +55,25 @@ func.func @zero_za_s() {
func.func @zero_za_d() {
// CHECK: "arm_sme.intr.zero"() <{tile_mask = 1 : i32}> : () -> ()
%zero_za0d = arm_sme.zero : vector<[2]x[2]xi64>
+ "test.prevent_zero_merge"() : () -> ()
// CHECK: "arm_sme.intr.zero"() <{tile_mask = 2 : i32}> : () -> ()
%zero_za1d = arm_sme.zero : vector<[2]x[2]xi64>
+ "test.prevent_zero_merge"() : () -> ()
// CHECK: "arm_sme.intr.zero"() <{tile_mask = 4 : i32}> : () -> ()
%zero_za2d = arm_sme.zero : vector<[2]x[2]xi64>
+ "test.prevent_zero_merge"() : () -> ()
// CHECK: "arm_sme.intr.zero"() <{tile_mask = 8 : i32}> : () -> ()
%zero_za3d = arm_sme.zero : vector<[2]x[2]xi64>
+ "test.prevent_zero_merge"() : () -> ()
// CHECK: "arm_sme.intr.zero"() <{tile_mask = 16 : i32}> : () -> ()
%zero_za4d = arm_sme.zero : vector<[2]x[2]xi64>
+ "test.prevent_zero_merge"() : () -> ()
// CHECK: "arm_sme.intr.zero"() <{tile_mask = 32 : i32}> : () -> ()
%zero_za5d = arm_sme.zero : vector<[2]x[2]xi64>
+ "test.prevent_zero_merge"() : () -> ()
// CHECK: "arm_sme.intr.zero"() <{tile_mask = 64 : i32}> : () -> ()
%zero_za6d = arm_sme.zero : vector<[2]x[2]xi64>
+ "test.prevent_zero_merge"() : () -> ()
// CHECK: "arm_sme.intr.zero"() <{tile_mask = 128 : i32}> : () -> ()
%zero_za7d = arm_sme.zero : vector<[2]x[2]xf64>
"test.some_use"(%zero_za0d) : (vector<[2]x[2]xi64>) -> ()
@@ -75,3 +86,40 @@ func.func @zero_za_d() {
"test.some_use"(%zero_za7d) : (vector<[2]x[2]xf64>) -> ()
return
}
+
+// -----
+
+// CHECK-LABEL: merge_consecutive_tile_zero_ops
+func.func @merge_consecutive_tile_zero_ops() {
+ // CHECK: "arm_sme.intr.zero"() <{tile_mask = 255 : i32}> : () -> ()
+ %zero_za0s = arm_sme.zero : vector<[4]x[4]xi32>
+ %zero_za1s = arm_sme.zero : vector<[4]x[4]xi32>
+ %zero_za2s = arm_sme.zero : vector<[4]x[4]xi32>
+ %zero_za3s = arm_sme.zero : vector<[4]x[4]xf32>
+ "test.some_use"(%zero_za0s) : (vector<[4]x[4]xi32>) -> ()
+ "test.some_use"(%zero_za1s) : (vector<[4]x[4]xi32>) -> ()
+ "test.some_use"(%zero_za2s) : (vector<[4]x[4]xi32>) -> ()
+ "test.some_use"(%zero_za3s) : (vector<[4]x[4]xf32>) -> ()
+ return
+}
+
+// -----
+
+/// arm_sme.intr.zero intrinsics are not merged when there is an op other than
+/// arm_sme.intr.zero between them.
+
+// CHECK-LABEL: merge_consecutive_tile_zero_ops_with_barrier
+func.func @merge_consecutive_tile_zero_ops_with_barrier() {
+ // CHECK: "arm_sme.intr.zero"() <{tile_mask = 51 : i32}> : () -> ()
+ %zero_za0s = arm_sme.zero : vector<[4]x[4]xi32>
+ %zero_za1s = arm_sme.zero : vector<[4]x[4]xi32>
+ "test.prevent_zero_merge"() : () -> ()
+ // CHECK: "arm_sme.intr.zero"() <{tile_mask = 204 : i32}> : () -> ()
+ %zero_za2s = arm_sme.zero : vector<[4]x[4]xi32>
+ %zero_za3s = arm_sme.zero : vector<[4]x[4]xf32>
+ "test.some_use"(%zero_za0s) : (vector<[4]x[4]xi32>) -> ()
+ "test.some_use"(%zero_za1s) : (vector<[4]x[4]xi32>) -> ()
+ "test.some_use"(%zero_za2s) : (vector<[4]x[4]xi32>) -> ()
+ "test.some_use"(%zero_za3s) : (vector<[4]x[4]xf32>) -> ()
+ return
+}
Couple of minor comments, but otherwise LGTM, cheers.
I do agree this looks like something the backend could handle; we can evaluate its effectiveness here in MLIR, given how easy it is to do, and feed back to the relevant folks on the backend side.