Skip to content

Commit 525d60b

Browse files
Théo Degioannigysit
authored andcommitted
[mlir][mem2reg] Add support for mem2reg in MemRef.
This patch implements the mem2reg interfaces for MemRef types. This only supports scalar memrefs of a small list of types. It would be beneficial to create more interfaces for default values before expanding support to more types. Additionally, I am working on an upcoming revision to bring SROA to MLIR that should help with non-scalar memrefs. Reviewed By: gysit, Mogball Differential Revision: https://reviews.llvm.org/D149441
1 parent 9b21cb2 commit 525d60b

File tree

7 files changed

+296
-3
lines changed

7 files changed

+296
-3
lines changed

mlir/include/mlir/Dialect/MemRef/IR/MemRef.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "mlir/Interfaces/ControlFlowInterfaces.h"
1717
#include "mlir/Interfaces/CopyOpInterface.h"
1818
#include "mlir/Interfaces/InferTypeOpInterface.h"
19+
#include "mlir/Interfaces/Mem2RegInterfaces.h"
1920
#include "mlir/Interfaces/ShapedOpInterfaces.h"
2021
#include "mlir/Interfaces/SideEffectInterfaces.h"
2122
#include "mlir/Interfaces/ViewLikeInterface.h"

mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ include "mlir/Interfaces/CastInterfaces.td"
1515
include "mlir/Interfaces/ControlFlowInterfaces.td"
1616
include "mlir/Interfaces/CopyOpInterface.td"
1717
include "mlir/Interfaces/InferTypeOpInterface.td"
18+
include "mlir/Interfaces/Mem2RegInterfaces.td"
1819
include "mlir/Interfaces/ShapedOpInterfaces.td"
1920
include "mlir/Interfaces/SideEffectInterfaces.td"
2021
include "mlir/Interfaces/ViewLikeInterface.td"
@@ -309,7 +310,8 @@ def MemRef_ReallocOp : MemRef_Op<"realloc"> {
309310
//===----------------------------------------------------------------------===//
310311

311312
def MemRef_AllocaOp : AllocLikeOp<"alloca", AutomaticAllocationScopeResource,[
312-
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>]> {
313+
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
314+
DeclareOpInterfaceMethods<PromotableAllocationOpInterface>]> {
313315
let summary = "stack memory allocation operation";
314316
let description = [{
315317
The `alloca` operation allocates memory on the stack, to be automatically
@@ -1159,7 +1161,8 @@ def LoadOp : MemRef_Op<"load",
11591161
[TypesMatchWith<"result type matches element type of 'memref'",
11601162
"memref", "result",
11611163
"$_self.cast<MemRefType>().getElementType()">,
1162-
MemRefsNormalizable]> {
1164+
MemRefsNormalizable,
1165+
DeclareOpInterfaceMethods<PromotableMemOpInterface>]> {
11631166
let summary = "load operation";
11641167
let description = [{
11651168
The `load` op reads an element from a memref specified by an index list. The
@@ -1748,7 +1751,8 @@ def MemRef_StoreOp : MemRef_Op<"store",
17481751
[TypesMatchWith<"type of 'value' matches element type of 'memref'",
17491752
"memref", "value",
17501753
"$_self.cast<MemRefType>().getElementType()">,
1751-
MemRefsNormalizable]> {
1754+
MemRefsNormalizable,
1755+
DeclareOpInterfaceMethods<PromotableMemOpInterface>]> {
17521756
let summary = "store operation";
17531757
let description = [{
17541758
Store a value to a memref location given by indices. The value stored should

mlir/lib/Dialect/MemRef/IR/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
add_mlir_dialect_library(MLIRMemRefDialect
22
MemRefDialect.cpp
3+
MemRefMem2Reg.cpp
34
MemRefOps.cpp
45
ValueBoundsOpInterfaceImpl.cpp
56

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
//===- MemRefMem2Reg.cpp - Mem2Reg Interfaces -------------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements Mem2Reg-related interfaces for MemRef dialect
10+
// operations.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#include "mlir/Dialect/Arith/IR/Arith.h"
15+
#include "mlir/Dialect/Complex/IR/Complex.h"
16+
#include "mlir/Dialect/MemRef/IR/MemRef.h"
17+
#include "mlir/Interfaces/Mem2RegInterfaces.h"
18+
#include "llvm/ADT/TypeSwitch.h"
19+
20+
using namespace mlir;
21+
22+
//===----------------------------------------------------------------------===//
23+
// AllocaOp interfaces
24+
//===----------------------------------------------------------------------===//
25+
26+
static bool isSupportedElementType(Type type) {
27+
return type.isa<MemRefType>() ||
28+
OpBuilder(type.getContext()).getZeroAttr(type);
29+
}
30+
31+
SmallVector<MemorySlot> memref::AllocaOp::getPromotableSlots() {
32+
MemRefType type = getType();
33+
if (!isSupportedElementType(type.getElementType()))
34+
return {};
35+
if (!type.hasStaticShape())
36+
return {};
37+
// Make sure the memref contains only a single element.
38+
if (any_of(type.getShape(), [](uint64_t dim) { return dim != 1; }))
39+
return {};
40+
41+
return {MemorySlot{getResult(), type.getElementType()}};
42+
}
43+
44+
Value memref::AllocaOp::getDefaultValue(const MemorySlot &slot,
45+
OpBuilder &builder) {
46+
assert(isSupportedElementType(slot.elemType));
47+
// TODO: support more types.
48+
return TypeSwitch<Type, Value>(slot.elemType)
49+
.Case([&](MemRefType t) {
50+
return builder.create<memref::AllocaOp>(getLoc(), t);
51+
})
52+
.Default([&](Type t) {
53+
return builder.create<arith::ConstantOp>(getLoc(), t,
54+
builder.getZeroAttr(t));
55+
});
56+
}
57+
58+
void memref::AllocaOp::handlePromotionComplete(const MemorySlot &slot,
59+
Value defaultValue) {
60+
if (defaultValue.use_empty())
61+
defaultValue.getDefiningOp()->erase();
62+
erase();
63+
}
64+
65+
void memref::AllocaOp::handleBlockArgument(const MemorySlot &slot,
66+
BlockArgument argument,
67+
OpBuilder &builder) {}
68+
69+
//===----------------------------------------------------------------------===//
70+
// LoadOp/StoreOp interfaces
71+
//===----------------------------------------------------------------------===//
72+
73+
bool memref::LoadOp::loadsFrom(const MemorySlot &slot) {
74+
return getMemRef() == slot.ptr;
75+
}
76+
77+
Value memref::LoadOp::getStored(const MemorySlot &slot) { return {}; }
78+
79+
bool memref::LoadOp::canUsesBeRemoved(
80+
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
81+
SmallVectorImpl<OpOperand *> &newBlockingUses) {
82+
if (blockingUses.size() != 1)
83+
return false;
84+
Value blockingUse = (*blockingUses.begin())->get();
85+
return blockingUse == slot.ptr && getMemRef() == slot.ptr &&
86+
getResult().getType() == slot.elemType;
87+
}
88+
89+
DeletionKind memref::LoadOp::removeBlockingUses(
90+
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
91+
OpBuilder &builder, Value reachingDefinition) {
92+
// `canUsesBeRemoved` checked this blocking use must be the loaded slot
93+
// pointer.
94+
getResult().replaceAllUsesWith(reachingDefinition);
95+
return DeletionKind::Delete;
96+
}
97+
98+
bool memref::StoreOp::loadsFrom(const MemorySlot &slot) { return false; }
99+
100+
Value memref::StoreOp::getStored(const MemorySlot &slot) {
101+
if (getMemRef() != slot.ptr)
102+
return {};
103+
return getValue();
104+
}
105+
106+
bool memref::StoreOp::canUsesBeRemoved(
107+
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
108+
SmallVectorImpl<OpOperand *> &newBlockingUses) {
109+
if (blockingUses.size() != 1)
110+
return false;
111+
Value blockingUse = (*blockingUses.begin())->get();
112+
return blockingUse == slot.ptr && getMemRef() == slot.ptr &&
113+
getValue() != slot.ptr && getValue().getType() == slot.elemType;
114+
}
115+
116+
DeletionKind memref::StoreOp::removeBlockingUses(
117+
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
118+
OpBuilder &builder, Value reachingDefinition) {
119+
return DeletionKind::Delete;
120+
}

mlir/test/Dialect/MemRef/mem2reg.mlir

Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
// RUN: mlir-opt %s --pass-pipeline='builtin.module(func.func(mem2reg))' --split-input-file | FileCheck %s
2+
3+
// CHECK-LABEL: func.func @basic
4+
func.func @basic() -> i32 {
5+
// CHECK-NOT: = memref.alloca
6+
// CHECK: %[[RES:.*]] = arith.constant 5 : i32
7+
// CHECK-NOT: = memref.alloca
8+
%0 = arith.constant 5 : i32
9+
%1 = memref.alloca() : memref<i32>
10+
memref.store %0, %1[] : memref<i32>
11+
%2 = memref.load %1[] : memref<i32>
12+
// CHECK: return %[[RES]] : i32
13+
return %2 : i32
14+
}
15+
16+
// -----
17+
18+
// CHECK-LABEL: func.func @basic_default
19+
func.func @basic_default() -> i32 {
20+
// CHECK-NOT: = memref.alloca
21+
// CHECK: %[[RES:.*]] = arith.constant 0 : i32
22+
// CHECK-NOT: = memref.alloca
23+
%0 = arith.constant 5 : i32
24+
%1 = memref.alloca() : memref<i32>
25+
%2 = memref.load %1[] : memref<i32>
26+
// CHECK-NOT: memref.store
27+
memref.store %0, %1[] : memref<i32>
28+
// CHECK: return %[[RES]] : i32
29+
return %2 : i32
30+
}
31+
32+
// -----
33+
34+
// CHECK-LABEL: func.func @basic_float
35+
func.func @basic_float() -> f32 {
36+
// CHECK-NOT: = memref.alloca
37+
// CHECK: %[[RES:.*]] = arith.constant {{.*}} : f32
38+
%0 = arith.constant 5.2 : f32
39+
// CHECK-NOT: = memref.alloca
40+
%1 = memref.alloca() : memref<f32>
41+
memref.store %0, %1[] : memref<f32>
42+
%2 = memref.load %1[] : memref<f32>
43+
// CHECK: return %[[RES]] : f32
44+
return %2 : f32
45+
}
46+
47+
// -----
48+
49+
// CHECK-LABEL: func.func @basic_ranked
50+
func.func @basic_ranked() -> i32 {
51+
// CHECK-NOT: = memref.alloca
52+
// CHECK: %[[RES:.*]] = arith.constant 5 : i32
53+
// CHECK-NOT: = memref.alloca
54+
%0 = arith.constant 0 : index
55+
%1 = arith.constant 5 : i32
56+
%2 = memref.alloca() : memref<1x1xi32>
57+
memref.store %1, %2[%0, %0] : memref<1x1xi32>
58+
%3 = memref.load %2[%0, %0] : memref<1x1xi32>
59+
// CHECK: return %[[RES]] : i32
60+
return %3 : i32
61+
}
62+
63+
// -----
64+
65+
// CHECK-LABEL: func.func @reject_multiple_elements
66+
func.func @reject_multiple_elements() -> i32 {
67+
// CHECK: %[[INDEX:.*]] = arith.constant 0 : index
68+
%0 = arith.constant 0 : index
69+
// CHECK: %[[STORED:.*]] = arith.constant 5 : i32
70+
%1 = arith.constant 5 : i32
71+
// CHECK: %[[ALLOCA:.*]] = memref.alloca()
72+
%2 = memref.alloca() : memref<1x2xi32>
73+
// CHECK: memref.store %[[STORED]], %[[ALLOCA]][%[[INDEX]], %[[INDEX]]]
74+
memref.store %1, %2[%0, %0] : memref<1x2xi32>
75+
// CHECK: %[[RES:.*]] = memref.load %[[ALLOCA]][%[[INDEX]], %[[INDEX]]]
76+
%3 = memref.load %2[%0, %0] : memref<1x2xi32>
77+
// CHECK: return %[[RES]] : i32
78+
return %3 : i32
79+
}
80+
81+
// -----
82+
83+
// CHECK-LABEL: func.func @cycle
84+
// CHECK-SAME: (%[[ARG0:.*]]: i64, %[[ARG1:.*]]: i1, %[[ARG2:.*]]: i64)
85+
func.func @cycle(%arg0: i64, %arg1: i1, %arg2: i64) {
86+
// CHECK-NOT: = memref.alloca
87+
%alloca = memref.alloca() : memref<i64>
88+
memref.store %arg2, %alloca[] : memref<i64>
89+
// CHECK: cf.cond_br %[[ARG1:.*]], ^[[BB1:.*]](%[[ARG2]] : i64), ^[[BB2:.*]](%[[ARG2]] : i64)
90+
cf.cond_br %arg1, ^bb1, ^bb2
91+
// CHECK: ^[[BB1]](%[[USE:.*]]: i64):
92+
^bb1:
93+
%use = memref.load %alloca[] : memref<i64>
94+
// CHECK: call @use(%[[USE]])
95+
func.call @use(%use) : (i64) -> ()
96+
memref.store %arg0, %alloca[] : memref<i64>
97+
// CHECK: cf.br ^[[BB2]](%[[ARG0]] : i64)
98+
cf.br ^bb2
99+
// CHECK: ^[[BB2]](%[[FWD:.*]]: i64):
100+
^bb2:
101+
// CHECK: cf.br ^[[BB1]](%[[FWD]] : i64)
102+
cf.br ^bb1
103+
}
104+
105+
func.func @use(%arg: i64) { return }
106+
107+
// -----
108+
109+
// CHECK-LABEL: func.func @recursive
110+
// CHECK-SAME: (%[[ARG:.*]]: i64)
111+
func.func @recursive(%arg: i64) -> i64 {
112+
// CHECK-NOT: = memref.alloca()
113+
%alloca0 = memref.alloca() : memref<memref<memref<i64>>>
114+
%alloca1 = memref.alloca() : memref<memref<i64>>
115+
%alloca2 = memref.alloca() : memref<i64>
116+
memref.store %arg, %alloca2[] : memref<i64>
117+
memref.store %alloca2, %alloca1[] : memref<memref<i64>>
118+
memref.store %alloca1, %alloca0[] : memref<memref<memref<i64>>>
119+
%load0 = memref.load %alloca0[] : memref<memref<memref<i64>>>
120+
%load1 = memref.load %load0[] : memref<memref<i64>>
121+
%load2 = memref.load %load1[] : memref<i64>
122+
// CHECK: return %[[ARG]] : i64
123+
return %load2 : i64
124+
}
125+
126+
// -----
127+
128+
// CHECK-LABEL: func.func @deny_store_of_alloca
129+
// CHECK-SAME: (%[[ARG:.*]]: memref<memref<i32>>)
130+
func.func @deny_store_of_alloca(%arg: memref<memref<i32>>) -> i32 {
131+
// CHECK: %[[VALUE:.*]] = arith.constant 5 : i32
132+
%0 = arith.constant 5 : i32
133+
// CHECK: %[[ALLOCA:.*]] = memref.alloca
134+
%1 = memref.alloca() : memref<i32>
135+
// Storing into the memref is allowed.
136+
// CHECK: memref.store %[[VALUE]], %[[ALLOCA]][]
137+
memref.store %0, %1[] : memref<i32>
138+
// Storing the memref itself is NOT allowed.
139+
// CHECK: memref.store %[[ALLOCA]], %[[ARG]][]
140+
memref.store %1, %arg[] : memref<memref<i32>>
141+
// CHECK: %[[RES:.*]] = memref.load %[[ALLOCA]][]
142+
%2 = memref.load %1[] : memref<i32>
143+
// CHECK: return %[[RES]] : i32
144+
return %2 : i32
145+
}
146+
147+
// -----
148+
149+
// CHECK-LABEL: func.func @promotable_nonpromotable_intertwined
150+
func.func @promotable_nonpromotable_intertwined() -> i32 {
151+
// CHECK: %[[VAL:.*]] = arith.constant 5 : i32
152+
%0 = arith.constant 5 : i32
153+
// CHECK: %[[NON_PROMOTED:.*]] = memref.alloca() : memref<i32>
154+
%1 = memref.alloca() : memref<i32>
155+
// CHECK-NOT: = memref.alloca() : memref<memref<i32>>
156+
%2 = memref.alloca() : memref<memref<i32>>
157+
memref.store %1, %2[] : memref<memref<i32>>
158+
%3 = memref.load %2[] : memref<memref<i32>>
159+
// CHECK: call @use(%[[NON_PROMOTED]])
160+
call @use(%1) : (memref<i32>) -> ()
161+
// CHECK: %[[RES:.*]] = memref.load %[[NON_PROMOTED]][]
162+
%4 = memref.load %1[] : memref<i32>
163+
// CHECK: return %[[RES]] : i32
164+
return %4 : i32
165+
}
166+
167+
func.func @use(%arg: memref<i32>) { return }

0 commit comments

Comments
 (0)