Skip to content

Commit b850ce4

Browse files
authored
[MLIR][Affine] Fix private memref creation bug in affine fusion (#126028)
Fix private memref creation bug in affine fusion exposed in the case of the same memref being loaded from/stored to in producer nest. Make the private memref replacement sound. Change affine fusion debug string to affine-fusion - more compact. Fixes: #48703
1 parent ff79d83 commit b850ce4

File tree

4 files changed

+170
-20
lines changed

4 files changed

+170
-20
lines changed

mlir/include/mlir/Dialect/Affine/Analysis/Utils.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -610,6 +610,14 @@ FailureOr<AffineValueMap>
610610
simplifyConstrainedMinMaxOp(Operation *op,
611611
FlatAffineValueConstraints constraints);
612612

613+
/// Find the innermost common `Block` of `a` and `b` in the affine scope
614+
/// that `a` and `b` are part of. Return nullptr if they belong to different
615+
/// affine scopes. Also, return nullptr if they do not have a common `Block`
616+
/// ancestor (e.g., when they are part of the `then` and `else` regions
617+
/// of an op that itself starts an affine scope).
618+
mlir::Block *findInnermostCommonBlockInScope(mlir::Operation *a,
619+
mlir::Operation *b);
620+
613621
} // namespace affine
614622
} // namespace mlir
615623

mlir/lib/Dialect/Affine/Analysis/Utils.cpp

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
//===----------------------------------------------------------------------===//
1313

1414
#include "mlir/Dialect/Affine/Analysis/Utils.h"
15+
1516
#include "mlir/Analysis/Presburger/PresburgerRelation.h"
1617
#include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
1718
#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
@@ -2297,3 +2298,41 @@ FailureOr<AffineValueMap> mlir::affine::simplifyConstrainedMinMaxOp(
22972298
affine::canonicalizeMapAndOperands(&newMap, &newOperands);
22982299
return AffineValueMap(newMap, newOperands);
22992300
}
2301+
2302+
/// Find the innermost common `Block` of `a` and `b` in the affine scope that
/// `a` and `b` are part of. Returns nullptr if they belong to different
/// affine scopes, or if they have no common `Block` ancestor (e.g., when they
/// are part of the `then` and `else` regions of an op that itself starts an
/// affine scope).
Block *mlir::affine::findInnermostCommonBlockInScope(Operation *a,
                                                     Operation *b) {
  Region *aScope = mlir::affine::getAffineScope(a);
  Region *bScope = mlir::affine::getAffineScope(b);
  // Ops in different affine scopes cannot share a block within a scope.
  if (aScope != bScope)
    return nullptr;

  // Get the block ancestry of `op` while stopping at the affine scope
  // `aScope` and store it in `ancestry`, outermost block first.
  auto getBlockAncestry = [&](Operation *op,
                              SmallVectorImpl<Block *> &ancestry) {
    Operation *curOp = op;
    do {
      ancestry.push_back(curOp->getBlock());
      if (curOp->getParentRegion() == aScope)
        break;
      curOp = curOp->getParentOp();
    } while (curOp);
    assert(curOp && "can't reach root op without passing through affine scope");
    // Collected innermost-first above; reverse so the common prefix of the
    // two ancestries can be scanned from the front.
    std::reverse(ancestry.begin(), ancestry.end());
  };

  SmallVector<Block *, 4> aAncestors, bAncestors;
  getBlockAncestry(a, aAncestors);
  getBlockAncestry(b, bAncestors);
  assert(!aAncestors.empty() && !bAncestors.empty() &&
         "at least one Block ancestor expected");

  // Walk both outermost-first ancestries in lockstep; the last position at
  // which they agree is the innermost common block. Use a single fresh index
  // (`i`) rather than the original pair of counters named `a`/`b`, which
  // shadowed the function parameters.
  Block *innermostCommonBlock = nullptr;
  size_t commonPrefixLen = std::min(aAncestors.size(), bAncestors.size());
  for (size_t i = 0; i < commonPrefixLen; ++i) {
    if (aAncestors[i] != bAncestors[i])
      break;
    innermostCommonBlock = aAncestors[i];
  }
  return innermostCommonBlock;
}

mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp

Lines changed: 63 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ namespace affine {
4141
} // namespace affine
4242
} // namespace mlir
4343

44-
#define DEBUG_TYPE "affine-loop-fusion"
44+
#define DEBUG_TYPE "affine-fusion"
4545

4646
using namespace mlir;
4747
using namespace mlir::affine;
@@ -237,29 +237,67 @@ static void sinkSequentialLoops(MemRefDependenceGraph::Node *node) {
237237
node->op = newRootForOp;
238238
}
239239

240-
// Creates and returns a private (single-user) memref for fused loop rooted
241-
// at 'forOp', with (potentially reduced) memref size based on the
242-
// MemRefRegion written to by 'srcStoreOpInst' at depth 'dstLoopDepth'.
243-
// TODO: consider refactoring the common code from generateDma and
244-
// this one.
245-
static Value createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst,
240+
/// Get the operation that should act as a dominance filter while replacing
241+
/// memref uses with a private memref for which `producerStores` and
242+
/// `sliceInsertionBlock` are provided. This effectively determines in what
243+
/// part of the IR we should be performing the replacement.
244+
/// Returns the operation that should act as a dominance filter while
/// replacing memref uses with a private memref, given the producer stores
/// `producerStores` and the block `sliceInsertionBlock` holding the slice
/// computation. This effectively determines in what part of the IR the
/// replacement should be performed.
static Operation *
getDominanceFilterForPrivateMemRefRepl(Block *sliceInsertionBlock,
                                       ArrayRef<Operation *> producerStores) {
  assert(!producerStores.empty() && "expected producer store");

  // Narrow down to the innermost block containing the slice computation and
  // every producer store: start from the slice insertion block and fold in
  // one store at a time.
  Block *enclosingBlock = nullptr;
  for (Operation *storeOp : producerStores) {
    Operation *anchor = enclosingBlock ? &*enclosingBlock->begin()
                                       : &*sliceInsertionBlock->begin();
    enclosingBlock = findInnermostCommonBlockInScope(storeOp, anchor);
  }
  assert(enclosingBlock &&
         "common block of producer stores and slice should exist");

  // Among the ancestors of the producer stores within that block, the one
  // appearing earliest is the dominance filter to use for the replacement.
  Operation *earliestAncestor = nullptr;
  for (Operation *storeOp : producerStores) {
    Operation *inBlockAncestor = enclosingBlock->findAncestorOpInBlock(*storeOp);
    assert(inBlockAncestor &&
           "producer store should be contained in common block");
    if (!earliestAncestor || inBlockAncestor->isBeforeInBlock(earliestAncestor))
      earliestAncestor = inBlockAncestor;
  }
  return earliestAncestor;
}
275+
276+
// Creates and returns a private (single-user) memref for fused loop rooted at
277+
// 'forOp', with (potentially reduced) memref size based on the memref region
278+
// written to by `storeOps` at depth 'dstLoopDepth'. 'sliceInsertionBlock'
279+
// specifies the block in which the slice was/will be inserted.
280+
static Value createPrivateMemRef(AffineForOp forOp,
281+
ArrayRef<Operation *> storeOps,
246282
unsigned dstLoopDepth,
247283
std::optional<unsigned> fastMemorySpace,
284+
Block *sliceInsertionBlock,
248285
uint64_t localBufSizeThreshold) {
249-
Operation *forInst = forOp.getOperation();
286+
assert(!storeOps.empty() && "no source stores supplied");
287+
Operation *srcStoreOp = storeOps[0];
250288

251289
// Create builder to insert alloc op just before 'forOp'.
252-
OpBuilder b(forInst);
290+
OpBuilder b(forOp);
253291
// Builder to create constants at the top level.
254-
OpBuilder top(forInst->getParentRegion());
292+
OpBuilder top(forOp->getParentRegion());
255293
// Create new memref type based on slice bounds.
256-
auto oldMemRef = cast<AffineWriteOpInterface>(srcStoreOpInst).getMemRef();
294+
auto oldMemRef = cast<AffineWriteOpInterface>(srcStoreOp).getMemRef();
257295
auto oldMemRefType = cast<MemRefType>(oldMemRef.getType());
258296
unsigned rank = oldMemRefType.getRank();
259297

260298
// Compute MemRefRegion for 'srcStoreOpInst' at depth 'dstLoopDepth'.
261-
MemRefRegion region(srcStoreOpInst->getLoc());
262-
bool validRegion = succeeded(region.compute(srcStoreOpInst, dstLoopDepth));
299+
MemRefRegion region(srcStoreOp->getLoc());
300+
bool validRegion = succeeded(region.compute(srcStoreOp, dstLoopDepth));
263301
(void)validRegion;
264302
assert(validRegion && "unexpected memref region failure");
265303
SmallVector<int64_t, 4> newShape;
@@ -332,11 +370,12 @@ static Value createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst,
332370
AffineMap::get(outerIVs.size() + rank, 0, remapExprs, forOp.getContext());
333371

334372
// Replace all users of 'oldMemRef' with 'newMemRef'.
335-
LogicalResult res =
336-
replaceAllMemRefUsesWith(oldMemRef, newMemRef, {}, indexRemap,
337-
/*extraOperands=*/outerIVs,
338-
/*symbolOperands=*/{},
339-
/*domOpFilter=*/&*forOp.getBody()->begin());
373+
Operation *domFilter =
374+
getDominanceFilterForPrivateMemRefRepl(sliceInsertionBlock, storeOps);
375+
LogicalResult res = replaceAllMemRefUsesWith(
376+
oldMemRef, newMemRef, /*extraIndices=*/{}, indexRemap,
377+
/*extraOperands=*/outerIVs,
378+
/*symbolOperands=*/{}, domFilter);
340379
assert(succeeded(res) &&
341380
"replaceAllMemrefUsesWith should always succeed here");
342381
(void)res;
@@ -944,6 +983,10 @@ struct GreedyFusion {
944983

945984
// Create private memrefs.
946985
if (!privateMemrefs.empty()) {
986+
// Note the block into which fusion was performed. This can be used to
987+
// place `alloc`s that create private memrefs.
988+
Block *sliceInsertionBlock = bestSlice.insertPoint->getBlock();
989+
947990
// Gather stores for all the private-to-be memrefs.
948991
DenseMap<Value, SmallVector<Operation *, 4>> privateMemRefToStores;
949992
dstAffineForOp.walk([&](AffineWriteOpInterface storeOp) {
@@ -962,8 +1005,8 @@ struct GreedyFusion {
9621005
SmallVector<Operation *, 4> &storesForMemref =
9631006
memrefToStoresPair.second;
9641007
Value newMemRef = createPrivateMemRef(
965-
dstAffineForOp, storesForMemref[0], bestDstLoopDepth,
966-
fastMemorySpace, localBufSizeThreshold);
1008+
dstAffineForOp, storesForMemref, bestDstLoopDepth,
1009+
fastMemorySpace, sliceInsertionBlock, localBufSizeThreshold);
9671010
// Create new node in dependence graph for 'newMemRef' alloc op.
9681011
unsigned newMemRefNodeId = mdg->addNode(newMemRef.getDefiningOp());
9691012
// Add edge from 'newMemRef' node to dstNode.

mlir/test/Dialect/Affine/loop-fusion-4.mlir

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,3 +285,63 @@ module {
285285
spirv.ReturnValue %3 : !spirv.array<8192 x f32>
286286
}
287287
}
288+
289+
// -----
290+
291+
// Tests producer/consumer fusion where the producer nest both loads from and
// stores to the same memref (%producer): per the commit, private memref
// replacement must stay sound and the source nest must not be removed
// (issue #48703).
// PRODUCER-CONSUMER-LABEL: func @same_memref_load_store
292+
func.func @same_memref_load_store(%producer : memref<32xf32>, %consumer: memref<16xf32>){
293+
%cst = arith.constant 2.000000e+00 : f32
294+
// Source isn't removed.
295+
// PRODUCER-CONSUMER: affine.for %{{.*}} = 0 to 32
296+
affine.for %arg3 = 0 to 32 {
297+
%0 = affine.load %producer[%arg3] : memref<32xf32>
298+
%2 = arith.mulf %0, %cst : f32
299+
affine.store %2, %producer[%arg3] : memref<32xf32>
300+
}
301+
affine.for %arg3 = 0 to 16 {
302+
%0 = affine.load %producer[%arg3] : memref<32xf32>
303+
%2 = arith.addf %0, %cst : f32
304+
affine.store %2, %consumer[%arg3] : memref<16xf32>
305+
}
306+
// Fused nest.
307+
// PRODUCER-CONSUMER: affine.for %{{.*}} = 0 to 16
308+
// PRODUCER-CONSUMER-NEXT: affine.load %{{.*}}[%{{.*}}] : memref<32xf32>
309+
// PRODUCER-CONSUMER-NEXT: arith.mulf
310+
// PRODUCER-CONSUMER-NEXT: affine.store %{{.*}}, %{{.*}}[0] : memref<1xf32>
311+
// PRODUCER-CONSUMER-NEXT: affine.load %{{.*}}[0] : memref<1xf32>
312+
// PRODUCER-CONSUMER-NEXT: arith.addf
313+
// PRODUCER-CONSUMER-NEXT: affine.store
314+
// PRODUCER-CONSUMER-NEXT: }
315+
return
316+
}
317+
318+
// Same scenario as above but with multiple producer stores (to %producer and
// %producer_2); both privatized memrefs must be handled and the source nest
// must not be removed.
// PRODUCER-CONSUMER-LABEL: func @same_memref_load_multiple_stores
319+
func.func @same_memref_load_multiple_stores(%producer : memref<32xf32>, %producer_2 : memref<32xf32>, %consumer: memref<16xf32>){
320+
%cst = arith.constant 2.000000e+00 : f32
321+
// Source isn't removed.
322+
// PRODUCER-CONSUMER: affine.for %{{.*}} = 0 to 32
323+
affine.for %arg3 = 0 to 32 {
324+
%0 = affine.load %producer[%arg3] : memref<32xf32>
325+
%2 = arith.mulf %0, %cst : f32
326+
affine.store %2, %producer[%arg3] : memref<32xf32>
327+
affine.store %2, %producer_2[%arg3] : memref<32xf32>
328+
}
329+
affine.for %arg3 = 0 to 16 {
330+
%0 = affine.load %producer[%arg3] : memref<32xf32>
331+
%1 = affine.load %producer_2[%arg3] : memref<32xf32>
332+
%2 = arith.addf %0, %1 : f32
333+
affine.store %2, %consumer[%arg3] : memref<16xf32>
334+
}
335+
// Fused nest.
336+
// PRODUCER-CONSUMER: affine.for %{{.*}} = 0 to 16
337+
// PRODUCER-CONSUMER-NEXT: affine.load %{{.*}}[%{{.*}}] : memref<32xf32>
338+
// PRODUCER-CONSUMER-NEXT: arith.mulf
339+
// PRODUCER-CONSUMER-NEXT: affine.store %{{.*}}, %{{.*}}[0] : memref<1xf32>
340+
// PRODUCER-CONSUMER-NEXT: affine.store %{{.*}}, %{{.*}}[0] : memref<1xf32>
341+
// PRODUCER-CONSUMER-NEXT: affine.load %{{.*}}[0] : memref<1xf32>
342+
// PRODUCER-CONSUMER-NEXT: affine.load %{{.*}}[0] : memref<1xf32>
343+
// PRODUCER-CONSUMER-NEXT: arith.addf
344+
// PRODUCER-CONSUMER-NEXT: affine.store
345+
// PRODUCER-CONSUMER-NEXT: }
346+
return
347+
}

0 commit comments

Comments
 (0)