Skip to content

Commit a14a280

Browse files
AlexandreEichenbergerbondhugula
authored andcommitted
[MLIR] MemRef Normalization for Dialects
When dealing with dialects that will results in function calls to external libraries, it is important to be able to handle maps as some dialects may require mapped data. Before this patch, the detection of whether normalization can apply or not, operations are compared to an explicit list of operations (`alloc`, `dealloc`, `return`) or to the presence of specific operation interfaces (`AffineReadOpInterface`, `AffineWriteOpInterface`, `AffineDMAStartOp`, or `AffineDMAWaitOp`). This patch add a trait, `MemRefsNormalizable` to determine if an operation can have its `memrefs` normalized. This trait can be used in turn by dialects to assert that such operations are compatible with normalization of `memrefs` with nontrivial memory layout specification. An example is given in the literal tests. Differential Revision: https://reviews.llvm.org/D86236
1 parent b5924a8 commit a14a280

File tree

10 files changed

+114
-21
lines changed

10 files changed

+114
-21
lines changed

mlir/docs/Traits.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,18 @@ foo.region_op {
247247
This trait is an important structural property of the IR, and enables operations
248248
to have [passes](PassManagement.md) scheduled under them.
249249

250+
### MemRefsNormalizable
251+
252+
* `OpTrait::MemRefsNormalizable` -- `MemRefsNormalizable`
253+
254+
This trait is used to flag operations that can accommodate `MemRefs` with
255+
non-identity memory-layout specifications. This trait indicates that the
256+
normalization of memory layout can be performed for such operations.
257+
`MemRefs` normalization consists of replacing an original memory reference
258+
with layout specifications to an equivalent memory reference where
259+
the specified memory layout is applied by rewritting accesses and types
260+
associated with that memory reference.
261+
250262
### Single Block with Implicit Terminator
251263

252264
* `OpTrait::SingleBlockImplicitTerminator<typename TerminatorOpType>` :

mlir/include/mlir/Dialect/Affine/IR/AffineOps.h

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,9 @@ bool isTopLevelValue(Value value);
8080
// multiple stride levels (possibly using AffineMaps to specify multiple levels
8181
// of striding).
8282
// TODO: Consider replacing src/dst memref indices with view memrefs.
83-
class AffineDmaStartOp : public Op<AffineDmaStartOp, OpTrait::VariadicOperands,
84-
OpTrait::ZeroResult> {
83+
class AffineDmaStartOp
84+
: public Op<AffineDmaStartOp, OpTrait::MemRefsNormalizable,
85+
OpTrait::VariadicOperands, OpTrait::ZeroResult> {
8586
public:
8687
using Op::Op;
8788

@@ -268,8 +269,9 @@ class AffineDmaStartOp : public Op<AffineDmaStartOp, OpTrait::VariadicOperands,
268269
// ...
269270
// affine.dma_wait %tag[%index], %num_elements : memref<1xi32, 2>
270271
//
271-
class AffineDmaWaitOp : public Op<AffineDmaWaitOp, OpTrait::VariadicOperands,
272-
OpTrait::ZeroResult> {
272+
class AffineDmaWaitOp
273+
: public Op<AffineDmaWaitOp, OpTrait::MemRefsNormalizable,
274+
OpTrait::VariadicOperands, OpTrait::ZeroResult> {
273275
public:
274276
using Op::Op;
275277

mlir/include/mlir/Dialect/Affine/IR/AffineOps.td

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -405,7 +405,8 @@ def AffineIfOp : Affine_Op<"if",
405405

406406
class AffineLoadOpBase<string mnemonic, list<OpTrait> traits = []> :
407407
Affine_Op<mnemonic, !listconcat(traits,
408-
[DeclareOpInterfaceMethods<AffineReadOpInterface>])> {
408+
[DeclareOpInterfaceMethods<AffineReadOpInterface>,
409+
MemRefsNormalizable])> {
409410
let arguments = (ins Arg<AnyMemRef, "the reference to load from",
410411
[MemRead]>:$memref,
411412
Variadic<Index>:$indices);
@@ -732,7 +733,8 @@ def AffinePrefetchOp : Affine_Op<"prefetch"> {
732733

733734
class AffineStoreOpBase<string mnemonic, list<OpTrait> traits = []> :
734735
Affine_Op<mnemonic, !listconcat(traits,
735-
[DeclareOpInterfaceMethods<AffineWriteOpInterface>])> {
736+
[DeclareOpInterfaceMethods<AffineWriteOpInterface>,
737+
MemRefsNormalizable])> {
736738
code extraClassDeclarationBase = [{
737739
/// Returns the operand index of the value to be stored.
738740
unsigned getStoredValOperandIndex() { return 0; }

mlir/include/mlir/Dialect/StandardOps/IR/Ops.td

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -658,7 +658,7 @@ def BranchOp : Std_Op<"br",
658658
// CallOp
659659
//===----------------------------------------------------------------------===//
660660

661-
def CallOp : Std_Op<"call", [CallOpInterface]> {
661+
def CallOp : Std_Op<"call", [CallOpInterface, MemRefsNormalizable]> {
662662
let summary = "call operation";
663663
let description = [{
664664
The `call` operation represents a direct call to a function that is within
@@ -1388,7 +1388,8 @@ def SinOp : FloatUnaryOp<"sin"> {
13881388
// DeallocOp
13891389
//===----------------------------------------------------------------------===//
13901390

1391-
def DeallocOp : Std_Op<"dealloc", [MemoryEffects<[MemFree]>]> {
1391+
def DeallocOp : Std_Op<"dealloc",
1392+
[MemoryEffects<[MemFree]>, MemRefsNormalizable]> {
13921393
let summary = "memory deallocation operation";
13931394
let description = [{
13941395
The `dealloc` operation frees the region of memory referenced by a memref
@@ -2144,8 +2145,8 @@ def RemFOp : FloatArithmeticOp<"remf"> {
21442145
// ReturnOp
21452146
//===----------------------------------------------------------------------===//
21462147

2147-
def ReturnOp : Std_Op<"return", [NoSideEffect, HasParent<"FuncOp">, ReturnLike,
2148-
Terminator]> {
2148+
def ReturnOp : Std_Op<"return", [NoSideEffect, HasParent<"FuncOp">,
2149+
MemRefsNormalizable, ReturnLike, Terminator]> {
21492150
let summary = "return operation";
21502151
let description = [{
21512152
The `return` operation represents a return operation within a function.

mlir/include/mlir/IR/OpBase.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1698,6 +1698,9 @@ def SameOperandsAndResultElementType :
16981698
NativeOpTrait<"SameOperandsAndResultElementType">;
16991699
// Op is a terminator.
17001700
def Terminator : NativeOpTrait<"IsTerminator">;
1701+
// Op can be safely normalized in the presence of MemRefs with
1702+
// non-identity maps.
1703+
def MemRefsNormalizable : NativeOpTrait<"MemRefsNormalizable">;
17011704

17021705
// Op's regions have a single block with the specified terminator.
17031706
class SingleBlockImplicitTerminator<string op>

mlir/include/mlir/IR/OpDefinition.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1212,6 +1212,20 @@ struct NoRegionArguments : public TraitBase<ConcrentType, NoRegionArguments> {
12121212
}
12131213
};
12141214

1215+
/// This trait is used to flag operations that can accommodate MemRefs with
1216+
/// non-identity memory-layout specifications. This trait indicates that the
1217+
/// normalization of memory layout can be performed for such operations.
1218+
/// MemRefs normalization consists of replacing an original memory reference
1219+
/// with layout specifications to an equivalent memory reference where the
1220+
/// specified memory layout is applied by rewritting accesses and types
1221+
/// associated with that memory reference.
1222+
// TODO: Right now, the operands of an operation are either all normalizable,
1223+
// or not. In the future, we may want to allow some of the operands to be
1224+
// normalizable.
1225+
template <typename ConcrentType>
1226+
struct MemRefsNormalizable
1227+
: public TraitBase<ConcrentType, MemRefsNormalizable> {};
1228+
12151229
} // end namespace OpTrait
12161230

12171231
//===----------------------------------------------------------------------===//

mlir/lib/Transforms/NormalizeMemRefs.cpp

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -106,23 +106,15 @@ void NormalizeMemRefs::runOnOperation() {
106106
normalizeFuncOpMemRefs(funcOp, moduleOp);
107107
}
108108

109-
/// Return true if this operation dereferences one or more memref's.
110-
/// TODO: Temporary utility, will be replaced when this is modeled through
111-
/// side-effects/op traits.
112-
static bool isMemRefDereferencingOp(Operation &op) {
113-
return isa<AffineReadOpInterface, AffineWriteOpInterface, AffineDmaStartOp,
114-
AffineDmaWaitOp>(op);
115-
}
116-
117109
/// Check whether all the uses of oldMemRef are either dereferencing uses or the
118110
/// op is of type : DeallocOp, CallOp or ReturnOp. Only if these constraints
119111
/// are satisfied will the value become a candidate for replacement.
120112
/// TODO: Extend this for DimOps.
121113
static bool isMemRefNormalizable(Value::user_range opUsers) {
122114
if (llvm::any_of(opUsers, [](Operation *op) {
123-
if (isMemRefDereferencingOp(*op))
115+
if (op->hasTrait<OpTrait::MemRefsNormalizable>())
124116
return false;
125-
return !isa<DeallocOp, CallOp, ReturnOp>(*op);
117+
return true;
126118
}))
127119
return false;
128120
return true;

mlir/lib/Transforms/Utils/Utils.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -279,7 +279,7 @@ LogicalResult mlir::replaceAllMemRefUsesWith(
279279
// Currently we support the following non-dereferencing ops to be a
280280
// candidate for replacement: Dealloc, CallOp and ReturnOp.
281281
// TODO: Add support for other kinds of ops.
282-
if (!isa<DeallocOp, CallOp, ReturnOp>(*op))
282+
if (!op->hasTrait<OpTrait::MemRefsNormalizable>())
283283
return failure();
284284
}
285285

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
// RUN: mlir-opt -normalize-memrefs %s | FileCheck %s
2+
3+
// For all these cases, we test if MemRefs Normalization works with the test
4+
// operations.
5+
// * test.op_norm: this operation has the MemRefsNormalizable attribute. The tests
6+
// that include this operation are constructed so that the normalization should
7+
// happen.
8+
// * test_op_nonnorm: this operation does not have the MemRefsNormalization
9+
// attribute. The tests that include this operation are contructed so that the
10+
// normalization should not happen.
11+
12+
#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 floordiv 32, d3 floordiv 64, d2 mod 32, d3 mod 64)>
13+
14+
// Test with op_norm and maps in arguments and in the operations in the function.
15+
16+
// CHECK-LABEL: test_norm
17+
// CHECK-SAME: (%[[ARG0:[a-z0-9]*]]: memref<1x16x1x1x32x64xf32>)
18+
func @test_norm(%arg0 : memref<1x16x14x14xf32, #map0>) -> () {
19+
%0 = alloc() : memref<1x16x14x14xf32, #map0>
20+
"test.op_norm"(%arg0, %0) : (memref<1x16x14x14xf32, #map0>, memref<1x16x14x14xf32, #map0>) -> ()
21+
dealloc %0 : memref<1x16x14x14xf32, #map0>
22+
23+
// CHECK: %[[v0:[a-z0-9]*]] = alloc() : memref<1x16x1x1x32x64xf32>
24+
// CHECK: "test.op_norm"(%[[ARG0]], %[[v0]]) : (memref<1x16x1x1x32x64xf32>, memref<1x16x1x1x32x64xf32>) -> ()
25+
// CHECK: dealloc %[[v0]] : memref<1x16x1x1x32x64xf32>
26+
return
27+
}
28+
29+
// Same test with op_nonnorm, with maps in the argmentets and the operations in the function.
30+
31+
// CHECK-LABEL: test_nonnorm
32+
// CHECK-SAME: (%[[ARG0:[a-z0-9]*]]: memref<1x16x14x14xf32, #map0>)
33+
func @test_nonnorm(%arg0 : memref<1x16x14x14xf32, #map0>) -> () {
34+
%0 = alloc() : memref<1x16x14x14xf32, #map0>
35+
"test.op_nonnorm"(%arg0, %0) : (memref<1x16x14x14xf32, #map0>, memref<1x16x14x14xf32, #map0>) -> ()
36+
dealloc %0 : memref<1x16x14x14xf32, #map0>
37+
38+
// CHECK: %[[v0:[a-z0-9]*]] = alloc() : memref<1x16x14x14xf32, #map0>
39+
// CHECK: "test.op_nonnorm"(%[[ARG0]], %[[v0]]) : (memref<1x16x14x14xf32, #map0>, memref<1x16x14x14xf32, #map0>) -> ()
40+
// CHECK: dealloc %[[v0]] : memref<1x16x14x14xf32, #map0>
41+
return
42+
}
43+
44+
// Test with op_norm, with maps in the operations in the function.
45+
46+
// CHECK-LABEL: test_norm_mix
47+
// CHECK-SAME: (%[[ARG0:[a-z0-9]*]]: memref<1x16x1x1x32x64xf32>
48+
func @test_norm_mix(%arg0 : memref<1x16x1x1x32x64xf32>) -> () {
49+
%0 = alloc() : memref<1x16x14x14xf32, #map0>
50+
"test.op_norm"(%arg0, %0) : (memref<1x16x1x1x32x64xf32>, memref<1x16x14x14xf32, #map0>) -> ()
51+
dealloc %0 : memref<1x16x14x14xf32, #map0>
52+
53+
// CHECK: %[[v0:[a-z0-9]*]] = alloc() : memref<1x16x1x1x32x64xf32>
54+
// CHECK: "test.op_norm"(%[[ARG0]], %[[v0]]) : (memref<1x16x1x1x32x64xf32>, memref<1x16x1x1x32x64xf32>) -> ()
55+
// CHECK: dealloc %[[v0]] : memref<1x16x1x1x32x64xf32>
56+
return
57+
}

mlir/test/lib/Dialect/Test/TestOps.td

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -618,6 +618,16 @@ def OpM : TEST_Op<"op_m"> {
618618
let arguments = (ins I32, OptionalAttr<I32Attr>:$optional_attr);
619619
let results = (outs I32);
620620
}
621+
622+
// Test for memrefs normalization of an op with normalizable memrefs.
623+
def OpNorm : TEST_Op<"op_norm", [MemRefsNormalizable]> {
624+
let arguments = (ins AnyMemRef:$X, AnyMemRef:$Y);
625+
}
626+
// Test for memrefs normalization of an op without normalizable memrefs.
627+
def OpNonNorm : TEST_Op<"op_nonnorm"> {
628+
let arguments = (ins AnyMemRef:$X, AnyMemRef:$Y);
629+
}
630+
621631
// Pattern add the argument plus a increasing static number hidden in
622632
// OpMTest function. That value is set into the optional argument.
623633
// That way, we will know if operations is called once or twice.

0 commit comments

Comments
 (0)