Skip to content

Commit ddeb55a

Browse files
committed
[mlir] add option to multi-buffering
Allow user to apply multi-buffering transformation for cases where proving that there is no loop carried dependency is not trivial. In this case user needs to ensure that the data are written and read in the same iteration otherwise the result is incorrect. Differential Revision: https://reviews.llvm.org/D144227
1 parent 4ecc6af commit ddeb55a

File tree

5 files changed

+73
-17
lines changed

5 files changed

+73
-17
lines changed

mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@ def MemRefMultiBufferOp : Op<Transform_Dialect, "memref.multibuffer",
2828
iterations. This transform expands the size of an allocation by
2929
a given multiplicative factor and fixes up any users of the
3030
multibuffered allocation.
31+
If skip analysis is not set the transformation will only apply
32+
if it can prove that there is no data being carried across loop
33+
iterations.
3134

3235
#### Return modes
3336

@@ -37,7 +40,8 @@ def MemRefMultiBufferOp : Op<Transform_Dialect, "memref.multibuffer",
3740

3841
let arguments =
3942
(ins Transform_MemRefAllocOp:$target,
40-
ConfinedAttr<I64Attr, [IntPositive]>:$factor);
43+
ConfinedAttr<I64Attr, [IntPositive]>:$factor,
44+
UnitAttr:$skip_analysis);
4145

4246
let results = (outs PDL_Operation:$transformed);
4347

mlir/include/mlir/Dialect/MemRef/Transforms/Passes.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,10 @@ void populateMemRefWideIntEmulationConversions(
7878
/// on the temporary allocation between consecutive loop iterations.
7979
/// It returns the new allocation if the original allocation was multi-buffered
8080
/// and returns failure() otherwise.
81-
/// Example:
81+
/// When `skipOverrideAnalysis`, the pass will apply the transformation
82+
/// without checking thwt the buffer is overrided at the beginning of each
83+
/// iteration. This implies that user knows that there is no data carried across
84+
/// loop iterations. Example:
8285
/// ```
8386
/// %0 = memref.alloc() : memref<4x128xf32>
8487
/// scf.for %iv = %c1 to %c1024 step %c3 {
@@ -100,7 +103,8 @@ void populateMemRefWideIntEmulationConversions(
100103
/// }
101104
/// ```
102105
FailureOr<memref::AllocOp> multiBuffer(memref::AllocOp allocOp,
103-
unsigned multiplier);
106+
unsigned multiplier,
107+
bool skipOverrideAnalysis = false);
104108

105109
//===----------------------------------------------------------------------===//
106110
// Passes

mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,8 @@ DiagnosedSilenceableFailure transform::MemRefMultiBufferOp::apply(
4141
if (!canApplyMultiBuffer)
4242
continue;
4343

44-
auto newBuffer = memref::multiBuffer(target, getFactor());
44+
auto newBuffer =
45+
memref::multiBuffer(target, getFactor(), getSkipAnalysis());
4546
if (failed(newBuffer))
4647
return emitSilenceableFailure(target->getLoc())
4748
<< "op failed to multibuffer";

mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,9 @@ static Value getOrCreateValue(OpFoldResult res, OpBuilder &builder,
8282
// Returns success if the transformation happened and failure otherwise.
8383
// This is not a pattern as it requires propagating the new memref type to its
8484
// uses and requires updating subview ops.
85-
FailureOr<memref::AllocOp> mlir::memref::multiBuffer(memref::AllocOp allocOp,
86-
unsigned multiplier) {
85+
FailureOr<memref::AllocOp>
86+
mlir::memref::multiBuffer(memref::AllocOp allocOp, unsigned multiplier,
87+
bool skipOverrideAnalysis) {
8788
LLVM_DEBUG(DBGS() << "Try multibuffer: " << allocOp << "\n");
8889
DominanceInfo dom(allocOp->getParentOp());
8990
LoopLikeOpInterface candidateLoop;
@@ -93,17 +94,29 @@ FailureOr<memref::AllocOp> mlir::memref::multiBuffer(memref::AllocOp allocOp,
9394
LLVM_DEBUG(DBGS() << "Skip user: no parent loop\n");
9495
return failure();
9596
}
96-
/// Make sure there is no loop-carried dependency on the allocation.
97-
if (!overrideBuffer(user, allocOp.getResult())) {
98-
LLVM_DEBUG(DBGS() << "Skip user: found loop-carried dependence\n");
99-
continue;
100-
}
101-
// If this user doesn't dominate all the other users keep looking.
102-
if (llvm::any_of(allocOp->getUsers(), [&](Operation *otherUser) {
103-
return !dom.dominates(user, otherUser);
104-
})) {
105-
LLVM_DEBUG(DBGS() << "Skip user: does not dominate all other users\n");
106-
continue;
97+
if (!skipOverrideAnalysis) {
98+
/// Make sure there is no loop-carried dependency on the allocation.
99+
if (!overrideBuffer(user, allocOp.getResult())) {
100+
LLVM_DEBUG(DBGS() << "Skip user: found loop-carried dependence\n");
101+
continue;
102+
}
103+
// If this user doesn't dominate all the other users keep looking.
104+
if (llvm::any_of(allocOp->getUsers(), [&](Operation *otherUser) {
105+
return !dom.dominates(user, otherUser);
106+
})) {
107+
LLVM_DEBUG(DBGS() << "Skip user: does not dominate all other users\n");
108+
continue;
109+
}
110+
} else {
111+
if (llvm::any_of(allocOp->getUsers(), [&](Operation *otherUser) {
112+
return !isa<memref::DeallocOp>(otherUser) &&
113+
!parentLoop->isProperAncestor(otherUser);
114+
})) {
115+
LLVM_DEBUG(
116+
DBGS()
117+
<< "Skip user: not all other users are in the parent loop\n");
118+
continue;
119+
}
107120
}
108121
candidateLoop = parentLoop;
109122
break;

mlir/test/Dialect/MemRef/transform-ops.mlir

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,3 +185,37 @@ transform.sequence failures(propagate) {
185185
// Verify that the returned handle is usable.
186186
transform.test_print_remark_at_operand %1, "transformed" : !pdl.operation
187187
}
188+
189+
// -----
190+
191+
192+
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0) -> ((d0 floordiv 4) mod 2)>
193+
194+
// CHECK-LABEL: func @multi_buffer
195+
func.func @multi_buffer_no_analysis(%in: memref<16xf32>) {
196+
// CHECK: %[[A:.*]] = memref.alloc() : memref<2x4xf32>
197+
// expected-remark @below {{transformed}}
198+
%tmp = memref.alloc() : memref<4xf32>
199+
200+
// CHECK: %[[C0:.*]] = arith.constant 0 : index
201+
// CHECK: %[[C4:.*]] = arith.constant 4 : index
202+
%c0 = arith.constant 0 : index
203+
%c4 = arith.constant 4 : index
204+
%c16 = arith.constant 16 : index
205+
206+
// CHECK: scf.for %[[IV:.*]] = %[[C0]]
207+
scf.for %i0 = %c0 to %c16 step %c4 {
208+
// CHECK: %[[I:.*]] = affine.apply #[[$MAP0]](%[[IV]])
209+
// CHECK: %[[SV:.*]] = memref.subview %[[A]][%[[I]], 0] [1, 4] [1, 1] : memref<2x4xf32> to memref<4xf32, strided<[1], offset: ?>>
210+
"some_write_read"(%tmp) : (memref<4xf32>) ->()
211+
}
212+
return
213+
}
214+
215+
transform.sequence failures(propagate) {
216+
^bb1(%arg1: !pdl.operation):
217+
%0 = transform.structured.match ops{["memref.alloc"]} in %arg1 : (!pdl.operation) -> !transform.op<"memref.alloc">
218+
%1 = transform.memref.multibuffer %0 {factor = 2 : i64, skip_analysis} : (!transform.op<"memref.alloc">) -> !pdl.operation
219+
// Verify that the returned handle is usable.
220+
transform.test_print_remark_at_operand %1, "transformed" : !pdl.operation
221+
}

0 commit comments

Comments
 (0)