Skip to content

Commit 3e2992f

Browse files
authored
[MLIR][Mem2Reg] Replace pattern based approach with a bulk one (#85426)
This commit changes MLIR's Mem2Reg implementation back from being pattern based into a full pass. Using Mem2Reg as a pattern is wasteful, as each application can invalidate the dominance info. Applying changes in bulk allows for reuse of the same dominance info. Unfortunately, this requires some test changes, due to the `IRBuilder` not simplifying IR.
1 parent a2fe410 commit 3e2992f

File tree

3 files changed

+43
-61
lines changed

3 files changed

+43
-61
lines changed

mlir/include/mlir/Transforms/Mem2Reg.h

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,6 @@
99
#ifndef MLIR_TRANSFORMS_MEM2REG_H
1010
#define MLIR_TRANSFORMS_MEM2REG_H
1111

12-
#include "mlir/IR/Dominance.h"
13-
#include "mlir/IR/OpDefinition.h"
1412
#include "mlir/IR/PatternMatch.h"
1513
#include "mlir/Interfaces/MemorySlotInterfaces.h"
1614
#include "llvm/ADT/Statistic.h"
@@ -25,24 +23,6 @@ struct Mem2RegStatistics {
2523
llvm::Statistic *newBlockArgumentAmount = nullptr;
2624
};
2725

28-
/// Pattern applying mem2reg to the regions of the operations on which it
29-
/// matches.
30-
class Mem2RegPattern
31-
: public OpInterfaceRewritePattern<PromotableAllocationOpInterface> {
32-
public:
33-
using OpInterfaceRewritePattern::OpInterfaceRewritePattern;
34-
35-
Mem2RegPattern(MLIRContext *context, Mem2RegStatistics statistics = {},
36-
PatternBenefit benefit = 1)
37-
: OpInterfaceRewritePattern(context, benefit), statistics(statistics) {}
38-
39-
LogicalResult matchAndRewrite(PromotableAllocationOpInterface allocator,
40-
PatternRewriter &rewriter) const override;
41-
42-
private:
43-
Mem2RegStatistics statistics;
44-
};
45-
4626
/// Attempts to promote the memory slots of the provided allocators. Succeeds if
4727
/// at least one memory slot was promoted.
4828
LogicalResult

mlir/lib/Transforms/Mem2Reg.cpp

Lines changed: 27 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,8 @@
1414
#include "mlir/IR/Value.h"
1515
#include "mlir/Interfaces/ControlFlowInterfaces.h"
1616
#include "mlir/Interfaces/MemorySlotInterfaces.h"
17-
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
1817
#include "mlir/Transforms/Passes.h"
1918
#include "mlir/Transforms/RegionUtils.h"
20-
#include "llvm/ADT/PostOrderIterator.h"
2119
#include "llvm/ADT/STLExtras.h"
2220
#include "llvm/Support/Casting.h"
2321
#include "llvm/Support/GenericIteratedDominanceFrontier.h"
@@ -635,13 +633,6 @@ LogicalResult mlir::tryToPromoteMemorySlots(
635633
return success(promotedAny);
636634
}
637635

638-
LogicalResult
639-
Mem2RegPattern::matchAndRewrite(PromotableAllocationOpInterface allocator,
640-
PatternRewriter &rewriter) const {
641-
hasBoundedRewriteRecursion();
642-
return tryToPromoteMemorySlots({allocator}, rewriter, statistics);
643-
}
644-
645636
namespace {
646637

647638
struct Mem2Reg : impl::Mem2RegBase<Mem2Reg> {
@@ -650,17 +641,36 @@ struct Mem2Reg : impl::Mem2RegBase<Mem2Reg> {
650641
void runOnOperation() override {
651642
Operation *scopeOp = getOperation();
652643

653-
Mem2RegStatistics statictics{&promotedAmount, &newBlockArgumentAmount};
644+
Mem2RegStatistics statistics{&promotedAmount, &newBlockArgumentAmount};
645+
646+
bool changed = false;
647+
648+
for (Region &region : scopeOp->getRegions()) {
649+
if (region.getBlocks().empty())
650+
continue;
654651

655-
GreedyRewriteConfig config;
656-
config.enableRegionSimplification = enableRegionSimplification;
652+
OpBuilder builder(&region.front(), region.front().begin());
653+
IRRewriter rewriter(builder);
657654

658-
RewritePatternSet rewritePatterns(&getContext());
659-
rewritePatterns.add<Mem2RegPattern>(&getContext(), statictics);
660-
FrozenRewritePatternSet frozen(std::move(rewritePatterns));
655+
// Promoting a slot can allow for further promotion of other slots,
656+
// promotion is tried until no promotion succeeds.
657+
while (true) {
658+
SmallVector<PromotableAllocationOpInterface> allocators;
659+
// Build a list of allocators to attempt to promote the slots of.
660+
region.walk([&](PromotableAllocationOpInterface allocator) {
661+
allocators.emplace_back(allocator);
662+
});
661663

662-
if (failed(applyPatternsAndFoldGreedily(scopeOp, frozen, config)))
663-
signalPassFailure();
664+
// Attempt promoting until no promotion succeeds.
665+
if (failed(tryToPromoteMemorySlots(allocators, rewriter, statistics)))
666+
break;
667+
668+
changed = true;
669+
getAnalysisManager().invalidate({});
670+
}
671+
}
672+
if (!changed)
673+
markAllAnalysesPreserved();
664674
}
665675
};
666676

mlir/test/Dialect/LLVMIR/mem2reg-intrinsics.mlir

Lines changed: 16 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,18 @@
1-
// RUN: mlir-opt %s --pass-pipeline="builtin.module(llvm.func(mem2reg{region-simplify=false}))" --split-input-file | FileCheck %s
1+
// RUN: mlir-opt %s --pass-pipeline="builtin.module(llvm.func(mem2reg))" --split-input-file | FileCheck %s
22

33
// CHECK-LABEL: llvm.func @basic_memset
44
// CHECK-SAME: (%[[MEMSET_VALUE:.*]]: i8)
55
llvm.func @basic_memset(%memset_value: i8) -> i32 {
66
%0 = llvm.mlir.constant(1 : i32) : i32
77
%1 = llvm.alloca %0 x i32 {alignment = 4 : i64} : (i32) -> !llvm.ptr
88
%memset_len = llvm.mlir.constant(4 : i32) : i32
9-
// CHECK-DAG: %[[C8:.*]] = llvm.mlir.constant(8 : i32) : i32
10-
// CHECK-DAG: %[[C16:.*]] = llvm.mlir.constant(16 : i32) : i32
119
"llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
1210
// CHECK-NOT: "llvm.intr.memset"
1311
// CHECK: %[[VALUE_8:.*]] = llvm.zext %[[MEMSET_VALUE]] : i8 to i32
12+
// CHECK: %[[C8:.*]] = llvm.mlir.constant(8 : i32) : i32
1413
// CHECK: %[[SHIFTED_8:.*]] = llvm.shl %[[VALUE_8]], %[[C8]]
1514
// CHECK: %[[VALUE_16:.*]] = llvm.or %[[VALUE_8]], %[[SHIFTED_8]]
15+
// CHECK: %[[C16:.*]] = llvm.mlir.constant(16 : i32) : i32
1616
// CHECK: %[[SHIFTED_16:.*]] = llvm.shl %[[VALUE_16]], %[[C16]]
1717
// CHECK: %[[VALUE_32:.*]] = llvm.or %[[VALUE_16]], %[[SHIFTED_16]]
1818
// CHECK-NOT: "llvm.intr.memset"
@@ -31,7 +31,14 @@ llvm.func @basic_memset_constant() -> i32 {
3131
%memset_len = llvm.mlir.constant(4 : i32) : i32
3232
"llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
3333
%2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i32
34-
// CHECK: %[[RES:.*]] = llvm.mlir.constant(707406378 : i32) : i32
34+
// CHECK: %[[C42:.*]] = llvm.mlir.constant(42 : i8) : i8
35+
// CHECK: %[[VALUE_42:.*]] = llvm.zext %[[C42]] : i8 to i32
36+
// CHECK: %[[C8:.*]] = llvm.mlir.constant(8 : i32) : i32
37+
// CHECK: %[[SHIFTED_42:.*]] = llvm.shl %[[VALUE_42]], %[[C8]] : i32
38+
// CHECK: %[[OR0:.*]] = llvm.or %[[VALUE_42]], %[[SHIFTED_42]] : i32
39+
// CHECK: %[[C16:.*]] = llvm.mlir.constant(16 : i32) : i32
40+
// CHECK: %[[SHIFTED:.*]] = llvm.shl %[[OR0]], %[[C16]] : i32
41+
// CHECK: %[[RES:..*]] = llvm.or %[[OR0]], %[[SHIFTED]] : i32
3542
// CHECK: llvm.return %[[RES]] : i32
3643
llvm.return %2 : i32
3744
}
@@ -44,16 +51,16 @@ llvm.func @exotic_target_memset(%memset_value: i8) -> i40 {
4451
%0 = llvm.mlir.constant(1 : i32) : i32
4552
%1 = llvm.alloca %0 x i40 {alignment = 4 : i64} : (i32) -> !llvm.ptr
4653
%memset_len = llvm.mlir.constant(5 : i32) : i32
47-
// CHECK-DAG: %[[C8:.*]] = llvm.mlir.constant(8 : i40) : i40
48-
// CHECK-DAG: %[[C16:.*]] = llvm.mlir.constant(16 : i40) : i40
49-
// CHECK-DAG: %[[C32:.*]] = llvm.mlir.constant(32 : i40) : i40
5054
"llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
5155
// CHECK-NOT: "llvm.intr.memset"
5256
// CHECK: %[[VALUE_8:.*]] = llvm.zext %[[MEMSET_VALUE]] : i8 to i40
57+
// CHECK: %[[C8:.*]] = llvm.mlir.constant(8 : i40) : i40
5358
// CHECK: %[[SHIFTED_8:.*]] = llvm.shl %[[VALUE_8]], %[[C8]]
5459
// CHECK: %[[VALUE_16:.*]] = llvm.or %[[VALUE_8]], %[[SHIFTED_8]]
60+
// CHECK: %[[C16:.*]] = llvm.mlir.constant(16 : i40) : i40
5561
// CHECK: %[[SHIFTED_16:.*]] = llvm.shl %[[VALUE_16]], %[[C16]]
5662
// CHECK: %[[VALUE_32:.*]] = llvm.or %[[VALUE_16]], %[[SHIFTED_16]]
63+
// CHECK: %[[C32:.*]] = llvm.mlir.constant(32 : i40) : i40
5764
// CHECK: %[[SHIFTED_COMPL:.*]] = llvm.shl %[[VALUE_32]], %[[C32]]
5865
// CHECK: %[[VALUE_COMPL:.*]] = llvm.or %[[VALUE_32]], %[[SHIFTED_COMPL]]
5966
// CHECK-NOT: "llvm.intr.memset"
@@ -64,21 +71,6 @@ llvm.func @exotic_target_memset(%memset_value: i8) -> i40 {
6471

6572
// -----
6673

67-
// CHECK-LABEL: llvm.func @exotic_target_memset_constant
68-
llvm.func @exotic_target_memset_constant() -> i40 {
69-
%0 = llvm.mlir.constant(1 : i32) : i32
70-
%1 = llvm.alloca %0 x i40 {alignment = 4 : i64} : (i32) -> !llvm.ptr
71-
%memset_value = llvm.mlir.constant(42 : i8) : i8
72-
%memset_len = llvm.mlir.constant(5 : i32) : i32
73-
"llvm.intr.memset"(%1, %memset_value, %memset_len) <{isVolatile = false}> : (!llvm.ptr, i8, i32) -> ()
74-
%2 = llvm.load %1 {alignment = 4 : i64} : !llvm.ptr -> i40
75-
// CHECK: %[[RES:.*]] = llvm.mlir.constant(181096032810 : i40) : i40
76-
// CHECK: llvm.return %[[RES]] : i40
77-
llvm.return %2 : i40
78-
}
79-
80-
// -----
81-
8274
// CHECK-LABEL: llvm.func @no_volatile_memset
8375
llvm.func @no_volatile_memset() -> i32 {
8476
// CHECK-DAG: %[[ALLOCA_LEN:.*]] = llvm.mlir.constant(1 : i32) : i32
@@ -195,7 +187,7 @@ llvm.func @basic_memcpy_dest(%destination: !llvm.ptr) -> i32 {
195187
// CHECK-LABEL: llvm.func @double_memcpy
196188
llvm.func @double_memcpy() -> i32 {
197189
%0 = llvm.mlir.constant(1 : i32) : i32
198-
// CHECK-NEXT: %[[DATA:.*]] = llvm.mlir.constant(42 : i32) : i32
190+
// CHECK: %[[DATA:.*]] = llvm.mlir.constant(42 : i32) : i32
199191
%data = llvm.mlir.constant(42 : i32) : i32
200192
%is_volatile = llvm.mlir.constant(false) : i1
201193
%memcpy_len = llvm.mlir.constant(4 : i32) : i32
@@ -206,7 +198,7 @@ llvm.func @double_memcpy() -> i32 {
206198
"llvm.intr.memcpy"(%2, %1, %memcpy_len) <{isVolatile = false}> : (!llvm.ptr, !llvm.ptr, i32) -> ()
207199

208200
%res = llvm.load %2 : !llvm.ptr -> i32
209-
// CHECK-NEXT: llvm.return %[[DATA]] : i32
201+
// CHECK: llvm.return %[[DATA]] : i32
210202
llvm.return %res : i32
211203
}
212204

0 commit comments

Comments
 (0)