Skip to content

Commit e37d6d2

Browse files
authored
[mlir][ArmSME] Merge consecutive arm_sme.intr.zero ops (#106215)
This merges consecutive SME zero intrinsics within a basic block, which avoids the backend eventually emitting multiple zero instructions when it could just use one. Note: This kind of peephole optimization could be implemented in the backend too.
1 parent c9b6e01 commit e37d6d2

File tree

3 files changed

+90
-0
lines changed

3 files changed

+90
-0
lines changed

mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#include "mlir/Dialect/Vector/IR/VectorOps.h"
2626
#include "mlir/Pass/Pass.h"
2727
#include "mlir/Transforms/DialectConversion.h"
28+
#include "llvm/ADT/ScopeExit.h"
2829

2930
namespace mlir {
3031
#define GEN_PASS_DEF_CONVERTARMSMETOLLVM
@@ -481,6 +482,9 @@ struct ZeroOpConversion : public ConvertArmSMEOpToLLVMPattern<arm_sme::ZeroOp> {
481482
loc, rewriter.getI32IntegerAttr(zeroMask));
482483

483484
// Create a placeholder op to preserve dataflow.
485+
// Note: Place the `get_tile` op at the start of the block. This ensures
486+
// that if there are multiple `zero` ops the intrinsics will be consecutive.
487+
rewriter.setInsertionPointToStart(zero->getBlock());
484488
rewriter.replaceOpWithNewOp<arm_sme::GetTileOp>(zero, zero.getVectorType());
485489

486490
return success();
@@ -855,6 +859,36 @@ struct StreamingVLOpConversion
855859
}
856860
};
857861

862+
/// Merges consecutive `arm_sme.intr.zero` operations in a block by bitwise
863+
/// or-ing the zero masks. Note: In future the backend _should_ handle this.
864+
static void mergeConsecutiveTileZerosInBlock(Block *block) {
865+
uint32_t mergedZeroMask = 0;
866+
SmallVector<arm_sme::aarch64_sme_zero, 16> zeroOpsToMerge;
867+
auto replaceMergedZeroOps = [&] {
868+
auto cleanup = llvm::make_scope_exit([&] {
869+
mergedZeroMask = 0;
870+
zeroOpsToMerge.clear();
871+
});
872+
if (zeroOpsToMerge.size() <= 1)
873+
return;
874+
IRRewriter rewriter(zeroOpsToMerge.front());
875+
rewriter.create<arm_sme::aarch64_sme_zero>(
876+
zeroOpsToMerge.front().getLoc(),
877+
rewriter.getI32IntegerAttr(mergedZeroMask));
878+
for (auto zeroOp : zeroOpsToMerge)
879+
rewriter.eraseOp(zeroOp);
880+
};
881+
for (Operation &op : *block) {
882+
if (auto zeroOp = dyn_cast<arm_sme::aarch64_sme_zero>(op)) {
883+
mergedZeroMask |= zeroOp.getTileMask();
884+
zeroOpsToMerge.push_back(zeroOp);
885+
} else {
886+
replaceMergedZeroOps();
887+
}
888+
}
889+
replaceMergedZeroOps();
890+
}
891+
858892
} // namespace
859893

860894
namespace {
@@ -879,6 +913,8 @@ struct ConvertArmSMEToLLVMPass
879913
if (failed(applyPartialConversion(function, target, std::move(patterns))))
880914
signalPassFailure();
881915

916+
function->walk(mergeConsecutiveTileZerosInBlock);
917+
882918
// Walk the function and fail if there are unexpected operations on SME
883919
// tile types after conversion.
884920
function->walk([&](Operation *op) {

mlir/test/Conversion/ArmSMEToLLVM/tile-spills-and-fills.mlir

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
/// These are obviously redundant, but there's no checks to avoid this.
5454
func.func @use_too_many_tiles() {
5555
%0 = arm_sme.zero : vector<[4]x[4]xi32>
56+
"test.prevent_zero_merge"() : () -> ()
5657
%1 = arm_sme.zero : vector<[4]x[4]xi32>
5758
// expected-warning @below {{failed to allocate SME virtual tile to operation, tile value will go through memory, expect degraded performance}}
5859
%2 = arm_sme.zero : vector<[8]x[8]xi16>

mlir/test/Dialect/ArmSME/tile-zero-masks.mlir

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ func.func @zero_za_b() {
1919
func.func @zero_za_h() {
2020
// CHECK: "arm_sme.intr.zero"() <{tile_mask = 85 : i32}> : () -> ()
2121
%zero_za0h = arm_sme.zero : vector<[8]x[8]xi16>
22+
"test.prevent_zero_merge"() : () -> ()
2223
// CHECK: "arm_sme.intr.zero"() <{tile_mask = 170 : i32}> : () -> ()
2324
%zero_za1h = arm_sme.zero : vector<[8]x[8]xf16>
2425
"test.some_use"(%zero_za0h) : (vector<[8]x[8]xi16>) -> ()
@@ -32,10 +33,13 @@ func.func @zero_za_h() {
3233
func.func @zero_za_s() {
3334
// CHECK: "arm_sme.intr.zero"() <{tile_mask = 17 : i32}> : () -> ()
3435
%zero_za0s = arm_sme.zero : vector<[4]x[4]xi32>
36+
"test.prevent_zero_merge"() : () -> ()
3537
// CHECK: "arm_sme.intr.zero"() <{tile_mask = 34 : i32}> : () -> ()
3638
%zero_za1s = arm_sme.zero : vector<[4]x[4]xi32>
39+
"test.prevent_zero_merge"() : () -> ()
3740
// CHECK: "arm_sme.intr.zero"() <{tile_mask = 68 : i32}> : () -> ()
3841
%zero_za2s = arm_sme.zero : vector<[4]x[4]xi32>
42+
"test.prevent_zero_merge"() : () -> ()
3943
// CHECK: "arm_sme.intr.zero"() <{tile_mask = 136 : i32}> : () -> ()
4044
%zero_za3s = arm_sme.zero : vector<[4]x[4]xf32>
4145
"test.some_use"(%zero_za0s) : (vector<[4]x[4]xi32>) -> ()
@@ -51,18 +55,25 @@ func.func @zero_za_s() {
5155
func.func @zero_za_d() {
5256
// CHECK: "arm_sme.intr.zero"() <{tile_mask = 1 : i32}> : () -> ()
5357
%zero_za0d = arm_sme.zero : vector<[2]x[2]xi64>
58+
"test.prevent_zero_merge"() : () -> ()
5459
// CHECK: "arm_sme.intr.zero"() <{tile_mask = 2 : i32}> : () -> ()
5560
%zero_za1d = arm_sme.zero : vector<[2]x[2]xi64>
61+
"test.prevent_zero_merge"() : () -> ()
5662
// CHECK: "arm_sme.intr.zero"() <{tile_mask = 4 : i32}> : () -> ()
5763
%zero_za2d = arm_sme.zero : vector<[2]x[2]xi64>
64+
"test.prevent_zero_merge"() : () -> ()
5865
// CHECK: "arm_sme.intr.zero"() <{tile_mask = 8 : i32}> : () -> ()
5966
%zero_za3d = arm_sme.zero : vector<[2]x[2]xi64>
67+
"test.prevent_zero_merge"() : () -> ()
6068
// CHECK: "arm_sme.intr.zero"() <{tile_mask = 16 : i32}> : () -> ()
6169
%zero_za4d = arm_sme.zero : vector<[2]x[2]xi64>
70+
"test.prevent_zero_merge"() : () -> ()
6271
// CHECK: "arm_sme.intr.zero"() <{tile_mask = 32 : i32}> : () -> ()
6372
%zero_za5d = arm_sme.zero : vector<[2]x[2]xi64>
73+
"test.prevent_zero_merge"() : () -> ()
6474
// CHECK: "arm_sme.intr.zero"() <{tile_mask = 64 : i32}> : () -> ()
6575
%zero_za6d = arm_sme.zero : vector<[2]x[2]xi64>
76+
"test.prevent_zero_merge"() : () -> ()
6677
// CHECK: "arm_sme.intr.zero"() <{tile_mask = 128 : i32}> : () -> ()
6778
%zero_za7d = arm_sme.zero : vector<[2]x[2]xf64>
6879
"test.some_use"(%zero_za0d) : (vector<[2]x[2]xi64>) -> ()
@@ -75,3 +86,45 @@ func.func @zero_za_d() {
7586
"test.some_use"(%zero_za7d) : (vector<[2]x[2]xf64>) -> ()
7687
return
7788
}
89+
90+
// -----
91+
92+
// CHECK-LABEL: merge_consecutive_tile_zero_ops
93+
/// Four back-to-back 32-bit tile zeroes with no intervening op: their
/// individual masks (17, 34, 68, 136 — see @zero_za_s) or together into a
/// single intrinsic with tile_mask 17|34|68|136 = 255.
func.func @merge_consecutive_tile_zero_ops() {
94+
// CHECK-NOT: arm_sme.intr.zero
95+
// CHECK: "arm_sme.intr.zero"() <{tile_mask = 255 : i32}> : () -> ()
96+
// CHECK-NOT: arm_sme.intr.zero
97+
%zero_za0s = arm_sme.zero : vector<[4]x[4]xi32>
98+
%zero_za1s = arm_sme.zero : vector<[4]x[4]xi32>
99+
%zero_za2s = arm_sme.zero : vector<[4]x[4]xi32>
100+
%zero_za3s = arm_sme.zero : vector<[4]x[4]xf32>
101+
"test.some_use"(%zero_za0s) : (vector<[4]x[4]xi32>) -> ()
102+
"test.some_use"(%zero_za1s) : (vector<[4]x[4]xi32>) -> ()
103+
"test.some_use"(%zero_za2s) : (vector<[4]x[4]xi32>) -> ()
104+
"test.some_use"(%zero_za3s) : (vector<[4]x[4]xf32>) -> ()
105+
return
106+
}
107+
108+
// -----
109+
110+
/// arm_sme.intr.zero intrinsics are not merged when there is an op other than
111+
/// arm_sme.intr.zero between them.
112+
113+
// CHECK-LABEL: merge_consecutive_tile_zero_ops_with_barrier
114+
/// The `test.prevent_zero_merge` op acts as a barrier, so the first pair
/// merges to 17|34 = 51 and the second pair to 68|136 = 204 — two intrinsics
/// remain instead of one.
func.func @merge_consecutive_tile_zero_ops_with_barrier() {
115+
// CHECK-NOT: arm_sme.intr.zero
116+
// CHECK: "arm_sme.intr.zero"() <{tile_mask = 51 : i32}> : () -> ()
117+
// CHECK-NOT: arm_sme.intr.zero
118+
%zero_za0s = arm_sme.zero : vector<[4]x[4]xi32>
119+
%zero_za1s = arm_sme.zero : vector<[4]x[4]xi32>
120+
"test.prevent_zero_merge"() : () -> ()
121+
// CHECK: "arm_sme.intr.zero"() <{tile_mask = 204 : i32}> : () -> ()
122+
// CHECK-NOT: arm_sme.intr.zero
123+
%zero_za2s = arm_sme.zero : vector<[4]x[4]xi32>
124+
%zero_za3s = arm_sme.zero : vector<[4]x[4]xf32>
125+
"test.some_use"(%zero_za0s) : (vector<[4]x[4]xi32>) -> ()
126+
"test.some_use"(%zero_za1s) : (vector<[4]x[4]xi32>) -> ()
127+
"test.some_use"(%zero_za2s) : (vector<[4]x[4]xi32>) -> ()
128+
"test.some_use"(%zero_za3s) : (vector<[4]x[4]xf32>) -> ()
129+
return
130+
}

0 commit comments

Comments
 (0)