Commit 47df8c5
[MLIR] Updates around MemRef Normalization
The documentation for the NormalizeMemRefs pass and the associated MemRefsNormalizable trait was confusing and not on the website. This update clarifies the language around the difference between a MemRef type and an operation that accesses a value of MemRef type, and better documents the limitations of the current implementation. This patch also adds some basic debugging output to the pass, so people have a chance of figuring out why it doesn't work on their code.

Differential Revision: https://reviews.llvm.org/D88532
1 parent b8ac19c commit 47df8c5

4 files changed (+130, -42 lines)

mlir/docs/Traits.md

Lines changed: 9 additions & 7 deletions
@@ -251,13 +251,15 @@ to have [passes](PassManagement.md) scheduled under them.
 
 * `OpTrait::MemRefsNormalizable` -- `MemRefsNormalizable`
 
-This trait is used to flag operations that can accommodate `MemRefs` with
-non-identity memory-layout specifications. This trait indicates that the
-normalization of memory layout can be performed for such operations.
-`MemRefs` normalization consists of replacing an original memory reference
-with layout specifications to an equivalent memory reference where
-the specified memory layout is applied by rewritting accesses and types
-associated with that memory reference.
+This trait is used to flag operations that consume or produce
+values of `MemRef` type where those references can be 'normalized'.
+In cases where an associated `MemRef` has a
+non-identity memory-layout specification, such normalizable operations can be
+modified so that the `MemRef` has an identity layout specification.
+This can be implemented by associating the operation with its own
+index expression that can express the equivalent of the memory-layout
+specification of the MemRef type. See the
+[-normalize-memrefs pass](https://mlir.llvm.org/docs/Passes/#-normalize-memrefs-normalize-memrefs).
 
 ### Single Block with Implicit Terminator

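For dialect authors, opting an operation into normalization means listing the trait when the op is defined. Below is a minimal C++ sketch of what that looks like; the op, its dialect name, and the other traits are hypothetical illustrations, not part of this patch (in ODS, the equivalent is adding `MemRefsNormalizable` to the op's trait list):

```c++
#include "mlir/IR/OpDefinition.h"

using namespace mlir;

// Hypothetical single-result, load-like op that consumes a value of MemRef
// type and declares, via the trait, that the MemRef it touches may be
// rewritten to an identity layout by -normalize-memrefs.
class MyLoadOp : public Op<MyLoadOp, OpTrait::OneOperand, OpTrait::OneResult,
                           OpTrait::MemRefsNormalizable> {
public:
  using Op::Op;
  static StringRef getOperationName() { return "mydialect.load"; }
};
```
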
mlir/include/mlir/IR/OpDefinition.h

Lines changed: 2 additions & 7 deletions
@@ -1212,13 +1212,8 @@ struct NoRegionArguments : public TraitBase<ConcrentType, NoRegionArguments> {
   }
 };
 
-/// This trait is used to flag operations that can accommodate MemRefs with
-/// non-identity memory-layout specifications. This trait indicates that the
-/// normalization of memory layout can be performed for such operations.
-/// MemRefs normalization consists of replacing an original memory reference
-/// with layout specifications to an equivalent memory reference where the
-/// specified memory layout is applied by rewritting accesses and types
-/// associated with that memory reference.
+// This trait is used to flag operations that consume or produce
+// values of `MemRef` type where those references can be 'normalized'.
 // TODO: Right now, the operands of an operation are either all normalizable,
 // or not. In the future, we may want to allow some of the operands to be
 // normalizable.

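Because this is an ordinary op trait, a pass can query it generically. A small sketch of such a check using `Operation::hasTrait`; the helper function itself is illustrative, not code from this patch:

```c++
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/Operation.h"

using namespace mlir;

// Illustrative predicate: true if `op` declares that the MemRef values it
// consumes or produces may be rewritten to an identity layout.
static bool allowsMemRefNormalization(Operation *op) {
  return op->hasTrait<OpTrait::MemRefsNormalizable>();
}
```
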
mlir/include/mlir/Transforms/Passes.td

Lines changed: 110 additions & 0 deletions
@@ -313,6 +313,116 @@ def MemRefDataFlowOpt : FunctionPass<"memref-dataflow-opt"> {
 
 def NormalizeMemRefs : Pass<"normalize-memrefs", "ModuleOp"> {
   let summary = "Normalize memrefs";
+  let description = [{
+    This pass transforms memref types with a non-trivial
+    [layout map](https://mlir.llvm.org/docs/LangRef/#layout-map) into
+    memref types with an identity layout map, e.g. (i, j) -> (i, j). This
+    pass is inter-procedural, in the sense that it can modify function
+    interfaces and call sites that pass memref types. In order to modify
+    memref types while preserving the original behavior, users of those
+    memref types are also modified to incorporate the resulting layout map.
+    For instance, an
+    [AffineLoadOp](https://mlir.llvm.org/docs/Dialects/Affine/#affineload-affineloadop)
+    will be updated to compose the layout map with the affine expression
+    contained in the op. Operations marked with the
+    [MemRefsNormalizable](https://mlir.llvm.org/docs/Traits/#memrefsnormalizable)
+    trait are expected to be normalizable. Supported operations include
+    affine operations, std.alloc, std.dealloc, and std.return.
+
+    Given an appropriate layout map specified in the code, this transformation
+    can express tiled or linearized access to multi-dimensional data
+    structures, but will not modify memref types without an explicit layout
+    map.
+
+    Currently this pass is limited to modifying only those functions where
+    all memref types can be normalized. If a function contains any operations
+    that are not MemRefsNormalizable, then neither the function nor any
+    function that calls or is called by it will be modified.
+
+    Input
+
+    ```mlir
+    #tile = affine_map<(i) -> (i floordiv 4, i mod 4)>
+    func @matmul(%A: memref<16xf64, #tile>,
+                 %B: index, %C: memref<16xf64>) -> (memref<16xf64, #tile>) {
+      affine.for %arg3 = 0 to 16 {
+        %a = affine.load %A[%arg3] : memref<16xf64, #tile>
+        %p = mulf %a, %a : f64
+        affine.store %p, %A[%arg3] : memref<16xf64, #tile>
+      }
+      %c = alloc() : memref<16xf64, #tile>
+      %d = affine.load %c[0] : memref<16xf64, #tile>
+      return %A: memref<16xf64, #tile>
+    }
+    ```
+
+    Output
+
+    ```mlir
+    func @matmul(%arg0: memref<4x4xf64>, %arg1: index, %arg2: memref<16xf64>)
+        -> memref<4x4xf64> {
+      affine.for %arg3 = 0 to 16 {
+        %3 = affine.load %arg0[%arg3 floordiv 4, %arg3 mod 4]: memref<4x4xf64>
+        %4 = mulf %3, %3 : f64
+        affine.store %4, %arg0[%arg3 floordiv 4, %arg3 mod 4]: memref<4x4xf64>
+      }
+      %0 = alloc() : memref<4x4xf64>
+      %1 = affine.apply #map1()
+      %2 = affine.load %0[0, 0] : memref<4x4xf64>
+      return %arg0 : memref<4x4xf64>
+    }
+    ```
+
+    Input
+
+    ```mlir
+    #linear8 = affine_map<(i, j) -> (i * 8 + j)>
+    func @linearize(%arg0: memref<8x8xi32, #linear8>,
+                    %arg1: memref<8x8xi32, #linear8>,
+                    %arg2: memref<8x8xi32, #linear8>) {
+      %c8 = constant 8 : index
+      %c0 = constant 0 : index
+      %c1 = constant 1 : index
+      affine.for %arg3 = %c0 to %c8 {
+        affine.for %arg4 = %c0 to %c8 {
+          affine.for %arg5 = %c0 to %c8 {
+            %0 = affine.load %arg0[%arg3, %arg5] : memref<8x8xi32, #linear8>
+            %1 = affine.load %arg1[%arg5, %arg4] : memref<8x8xi32, #linear8>
+            %2 = affine.load %arg2[%arg3, %arg4] : memref<8x8xi32, #linear8>
+            %3 = muli %0, %1 : i32
+            %4 = addi %2, %3 : i32
+            affine.store %4, %arg2[%arg3, %arg4] : memref<8x8xi32, #linear8>
+          }
+        }
+      }
+      return
+    }
+    ```
+
+    Output
+
+    ```mlir
+    func @linearize(%arg0: memref<64xi32>,
+                    %arg1: memref<64xi32>,
+                    %arg2: memref<64xi32>) {
+      %c8 = constant 8 : index
+      %c0 = constant 0 : index
+      affine.for %arg3 = %c0 to %c8 {
+        affine.for %arg4 = %c0 to %c8 {
+          affine.for %arg5 = %c0 to %c8 {
+            %0 = affine.load %arg0[%arg3 * 8 + %arg5] : memref<64xi32>
+            %1 = affine.load %arg1[%arg5 * 8 + %arg4] : memref<64xi32>
+            %2 = affine.load %arg2[%arg3 * 8 + %arg4] : memref<64xi32>
+            %3 = muli %0, %1 : i32
+            %4 = addi %2, %3 : i32
+            affine.store %4, %arg2[%arg3 * 8 + %arg4] : memref<64xi32>
+          }
+        }
+      }
+      return
+    }
+    ```
+  }];
   let constructor = "mlir::createNormalizeMemRefsPass()";
 }

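Besides running `mlir-opt -normalize-memrefs`, the pass can be scheduled programmatically through the constructor registered above. A minimal sketch follows; the surrounding pass-manager boilerplate is assumed standard MLIR usage rather than taken from this patch (only `createNormalizeMemRefsPass` comes from these files):

```c++
#include "mlir/IR/Module.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/Passes.h"

// Runs -normalize-memrefs over a module. Module-level scheduling matters
// because the pass may rewrite function signatures and call sites.
static mlir::LogicalResult runNormalization(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  pm.addPass(mlir::createNormalizeMemRefsPass());
  return pm.run(module);
}
```
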
mlir/lib/Transforms/NormalizeMemRefs.cpp

Lines changed: 9 additions & 28 deletions
@@ -29,34 +29,6 @@ namespace {
 /// such functions as normalizable. Also, if a normalizable function is known
 /// to call a non-normalizable function, we treat that function as
 /// non-normalizable as well. We assume external functions to be normalizable.
-///
-/// Input :-
-/// #tile = affine_map<(i) -> (i floordiv 4, i mod 4)>
-/// func @matmul(%A: memref<16xf64, #tile>, %B: index, %C: memref<16xf64>) ->
-///   (memref<16xf64, #tile>) {
-///   affine.for %arg3 = 0 to 16 {
-///     %a = affine.load %A[%arg3] : memref<16xf64, #tile>
-///     %p = mulf %a, %a : f64
-///     affine.store %p, %A[%arg3] : memref<16xf64, #tile>
-///   }
-///   %c = alloc() : memref<16xf64, #tile>
-///   %d = affine.load %c[0] : memref<16xf64, #tile>
-///   return %A: memref<16xf64, #tile>
-/// }
-///
-/// Output :-
-/// func @matmul(%arg0: memref<4x4xf64>, %arg1: index, %arg2: memref<16xf64>)
-///   -> memref<4x4xf64> {
-///   affine.for %arg3 = 0 to 16 {
-///     %2 = affine.load %arg0[%arg3 floordiv 4, %arg3 mod 4] :
-///     memref<4x4xf64> %3 = mulf %2, %2 : f64 affine.store %3, %arg0[%arg3
-///     floordiv 4, %arg3 mod 4] : memref<4x4xf64>
-///   }
-///   %0 = alloc() : memref<16xf64, #map0>
-///   %1 = affine.load %0[0] : memref<16xf64, #map0>
-///   return %arg0 : memref<4x4xf64>
-/// }
-///
 struct NormalizeMemRefs : public NormalizeMemRefsBase<NormalizeMemRefs> {
   void runOnOperation() override;
   void normalizeFuncOpMemRefs(FuncOp funcOp, ModuleOp moduleOp);

@@ -73,6 +45,7 @@ std::unique_ptr<OperationPass<ModuleOp>> mlir::createNormalizeMemRefsPass() {
 }
 
 void NormalizeMemRefs::runOnOperation() {
+  LLVM_DEBUG(llvm::dbgs() << "Normalizing Memrefs...\n");
   ModuleOp moduleOp = getOperation();
   // We maintain all normalizable FuncOps in a DenseSet. It is initialized
   // with all the functions within a module and then functions which are not

@@ -92,6 +65,9 @@ void NormalizeMemRefs::runOnOperation() {
   moduleOp.walk([&](FuncOp funcOp) {
     if (normalizableFuncs.contains(funcOp)) {
       if (!areMemRefsNormalizable(funcOp)) {
+        LLVM_DEBUG(llvm::dbgs()
+                   << "@" << funcOp.getName()
+                   << " contains ops that cannot normalize MemRefs\n");
         // Since this function is not normalizable, we set all the caller
         // functions and the callees of this function as not normalizable.
         // TODO: Drop this conservative assumption in the future.

@@ -101,6 +77,8 @@ void NormalizeMemRefs::runOnOperation() {
     }
   });
 
+  LLVM_DEBUG(llvm::dbgs() << "Normalizing " << normalizableFuncs.size()
+                          << " functions\n");
   // Those functions which can be normalized are subjected to normalization.
   for (FuncOp &funcOp : normalizableFuncs)
     normalizeFuncOpMemRefs(funcOp, moduleOp);

@@ -127,6 +105,9 @@ void NormalizeMemRefs::setCalleesAndCallersNonNormalizable(
   if (!normalizableFuncs.contains(funcOp))
     return;
 
+  LLVM_DEBUG(
+      llvm::dbgs() << "@" << funcOp.getName()
+                   << " calls or is called by non-normalizable function\n");
   normalizableFuncs.erase(funcOp);
   // Caller of the function.
   Optional<SymbolTable::UseRange> symbolUses = funcOp.getSymbolUses(moduleOp);
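
The new statements use LLVM's standard `LLVM_DEBUG` machinery, so they print only in assertion-enabled builds and only when the matching debug type is selected. A self-contained sketch of the idiom; the `DEBUG_TYPE` string is an assumption about this file, and the helper function is illustrative:

```c++
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"

// Assumed tag for this pass; if it matches, `mlir-opt -normalize-memrefs
// -debug-only=normalize-memrefs` would surface only these messages.
#define DEBUG_TYPE "normalize-memrefs"

// Illustrative tracing in the style added by this patch. LLVM_DEBUG expands
// to nothing in release builds without assertions.
static void traceNormalizableCount(unsigned numFuncs) {
  LLVM_DEBUG(llvm::dbgs() << "Normalizing " << numFuncs << " functions\n");
}
```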
