[MLIR][Affine] Fix private memref creation bug in affine fusion #126028
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-affine

Author: Uday Bondhugula (bondhugula)

Changes

Fix private memref creation bug in affine fusion exposed in the case of the same memref being loaded from/stored to in the producer nest. Make the private memref replacement sound.

Change the affine fusion debug string to affine-fusion, which is more compact.

Full diff: https://github.com/llvm/llvm-project/pull/126028.diff

4 Files Affected:
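Before the file-by-file diff, for context: a minimal sketch of the pattern that exposes the bug (it mirrors the @same_memref_load_store test added below; the function, memref, and SSA names are illustrative, not taken from the patch). The producer nest both loads from and stores to the same memref, so replacing every use of that memref inside the fused nest with the private buffer would also rewrite the producer's own load and make it read an uninitialized value.

```mlir
// Producer nest reads and writes the same memref %A; the consumer reads %A again.
func.func @producer_consumer(%A: memref<32xf32>, %B: memref<16xf32>) {
  %cst = arith.constant 2.0 : f32
  affine.for %i = 0 to 32 {
    %v = affine.load %A[%i] : memref<32xf32>
    %w = arith.mulf %v, %cst : f32
    affine.store %w, %A[%i] : memref<32xf32>   // same memref is read and written
  }
  affine.for %i = 0 to 16 {
    %v = affine.load %A[%i] : memref<32xf32>
    %w = arith.addf %v, %cst : f32
    affine.store %w, %B[%i] : memref<16xf32>
  }
  return
}
```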
diff --git a/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h b/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h
index b1fbf4477428ca2..cff4983af17fc90 100644
--- a/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h
+++ b/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h
@@ -610,6 +610,14 @@ FailureOr<AffineValueMap>
simplifyConstrainedMinMaxOp(Operation *op,
FlatAffineValueConstraints constraints);
+/// Find the innermost common `Block` of `a` and `b` in the affine scope
+/// that `a` and `b` are part of. Return nullptr if they belong to different
+/// affine scopes. Also, return nullptr if they do not have a common `Block`
+/// ancestor (e.g., when they are part of the `then` and `else` regions
+/// of an op that itself starts an affine scope).
+mlir::Block *findInnermostCommonBlockInScope(mlir::Operation *a,
+ mlir::Operation *b);
+
} // namespace affine
} // namespace mlir
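As an illustration of the new helper declared above (not part of the patch; names are illustrative, and the usual affine scope started by func.func is assumed): for the two ops marked `a` and `b` below, findInnermostCommonBlockInScope(a, b) returns the entry block of @example, the innermost block containing both; if the two ops belonged to different affine scopes, it would return nullptr.

```mlir
func.func @example(%A: memref<10xf32>, %v: f32) {
  affine.for %i = 0 to 10 {
    affine.store %v, %A[%i] : memref<10xf32>  // op `a`
  }
  affine.for %j = 0 to 10 {
    %x = affine.load %A[%j] : memref<10xf32>  // op `b`
  }
  return
}
```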
diff --git a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
index 9c0b5dbf52d299b..d6c62cdd613643e 100644
--- a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
@@ -12,6 +12,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Affine/Analysis/Utils.h"
+
#include "mlir/Analysis/Presburger/PresburgerRelation.h"
#include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
@@ -2297,3 +2298,40 @@ FailureOr<AffineValueMap> mlir::affine::simplifyConstrainedMinMaxOp(
affine::canonicalizeMapAndOperands(&newMap, &newOperands);
return AffineValueMap(newMap, newOperands);
}
+
+Block *mlir::affine::findInnermostCommonBlockInScope(Operation *a,
+ Operation *b) {
+ Region *aScope = mlir::affine::getAffineScope(a);
+ Region *bScope = mlir::affine::getAffineScope(b);
+ if (aScope != bScope)
+ return nullptr;
+
+ // Get the block ancestry of `op` while stopping at the affine scope.
+ auto getBlockAncestry = [&](Operation *op,
+ SmallVectorImpl<Block *> &ancestry) {
+ Operation *curOp = op;
+ do {
+ ancestry.push_back(curOp->getBlock());
+ if (curOp->getParentRegion() == aScope)
+ break;
+ curOp = curOp->getParentOp();
+ } while (curOp);
+ assert(curOp && "can't reach root op without passing through affine scope");
+ std::reverse(ancestry.begin(), ancestry.end());
+ };
+
+ SmallVector<Block *, 4> aAncestors, bAncestors;
+ getBlockAncestry(a, aAncestors);
+ getBlockAncestry(b, bAncestors);
+ assert(!aAncestors.empty() && !bAncestors.empty() &&
+ "at least one Block ancestor expected");
+
+ Block *innermostCommonBlock = nullptr;
+ for (unsigned a = 0, b = 0, e = aAncestors.size(), f = bAncestors.size();
+ a < e && b < f; ++a, ++b) {
+ if (aAncestors[a] != bAncestors[b])
+ break;
+ innermostCommonBlock = aAncestors[a];
+ }
+ return innermostCommonBlock;
+}
diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
index c22ec213be95c84..0ea27df704d0694 100644
--- a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
@@ -41,7 +41,7 @@ namespace affine {
} // namespace affine
} // namespace mlir
-#define DEBUG_TYPE "affine-loop-fusion"
+#define DEBUG_TYPE "affine-fusion"
using namespace mlir;
using namespace mlir::affine;
@@ -237,29 +237,71 @@ static void sinkSequentialLoops(MemRefDependenceGraph::Node *node) {
node->op = newRootForOp;
}
+/// Get the operation that should act as a dominance filter while replacing
+/// memref uses with a private memref for which `producerStores` and
+/// `sliceInsertionBlock` are provided. This effectively determines in what
+/// part of the IR we should be performing the replacement.
+static Operation *
+getDominanceFilterForPrivateMemRefRepl(Block *sliceInsertionBlock,
+ ArrayRef<Operation *> producerStores) {
+ assert(!producerStores.empty() && "expected producer store");
+
+ // We first find the common block that contains the producer stores and
+ // the slice computation. The first ancestor among the ancestors of the
+ // producer stores in that common block is the dominance filter to use for
+ // replacement.
+ Block *commonBlock = nullptr;
+ // Find the common block of all relevant operations.
+ for (Operation *store : producerStores) {
+ if (!commonBlock)
+ commonBlock = findInnermostCommonBlockInScope(
+ store, &*sliceInsertionBlock->begin());
+ else
+ commonBlock =
+ findInnermostCommonBlockInScope(store, &*commonBlock->begin());
+ }
+ assert(commonBlock &&
+ "common block of producer stores and slice should exist");
+
+ // Find the first ancestor among the ancestors of `producerStores` in
+ // `commonBlock`.
+ Operation *firstAncestor = nullptr;
+ for (Operation *store : producerStores) {
+ Operation *ancestor = commonBlock->findAncestorOpInBlock(*store);
+ assert(ancestor && "producer store should be contained in common block");
+ firstAncestor = !firstAncestor || ancestor->isBeforeInBlock(firstAncestor)
+ ? ancestor
+ : firstAncestor;
+ }
+ return firstAncestor;
+}
+
// Creates and returns a private (single-user) memref for fused loop rooted
// at 'forOp', with (potentially reduced) memref size based on the
// MemRefRegion written to by 'srcStoreOpInst' at depth 'dstLoopDepth'.
// TODO: consider refactoring the common code from generateDma and
// this one.
-static Value createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst,
+static Value createPrivateMemRef(AffineForOp forOp,
+ ArrayRef<Operation *> storeOps,
unsigned dstLoopDepth,
std::optional<unsigned> fastMemorySpace,
+ Block *sliceInsertionBlock,
uint64_t localBufSizeThreshold) {
- Operation *forInst = forOp.getOperation();
+ assert(!storeOps.empty() && "no source stores supplied");
+ Operation *srcStoreOp = storeOps[0];
// Create builder to insert alloc op just before 'forOp'.
- OpBuilder b(forInst);
+ OpBuilder b(forOp);
// Builder to create constants at the top level.
- OpBuilder top(forInst->getParentRegion());
+ OpBuilder top(forOp->getParentRegion());
// Create new memref type based on slice bounds.
- auto oldMemRef = cast<AffineWriteOpInterface>(srcStoreOpInst).getMemRef();
+ auto oldMemRef = cast<AffineWriteOpInterface>(srcStoreOp).getMemRef();
auto oldMemRefType = cast<MemRefType>(oldMemRef.getType());
unsigned rank = oldMemRefType.getRank();
// Compute MemRefRegion for 'srcStoreOpInst' at depth 'dstLoopDepth'.
- MemRefRegion region(srcStoreOpInst->getLoc());
- bool validRegion = succeeded(region.compute(srcStoreOpInst, dstLoopDepth));
+ MemRefRegion region(srcStoreOp->getLoc());
+ bool validRegion = succeeded(region.compute(srcStoreOp, dstLoopDepth));
(void)validRegion;
assert(validRegion && "unexpected memref region failure");
SmallVector<int64_t, 4> newShape;
@@ -332,11 +374,12 @@ static Value createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst,
AffineMap::get(outerIVs.size() + rank, 0, remapExprs, forOp.getContext());
// Replace all users of 'oldMemRef' with 'newMemRef'.
+ Operation *domFilter =
+ getDominanceFilterForPrivateMemRefRepl(sliceInsertionBlock, storeOps);
LogicalResult res =
replaceAllMemRefUsesWith(oldMemRef, newMemRef, {}, indexRemap,
/*extraOperands=*/outerIVs,
- /*symbolOperands=*/{},
- /*domOpFilter=*/&*forOp.getBody()->begin());
+ /*symbolOperands=*/{}, domFilter);
assert(succeeded(res) &&
"replaceAllMemrefUsesWith should always succeed here");
(void)res;
@@ -944,6 +987,10 @@ struct GreedyFusion {
// Create private memrefs.
if (!privateMemrefs.empty()) {
+ // Note the block into which fusion was performed. This can be used to
+ // place `alloc`s that create private memrefs.
+ Block *sliceInsertionBlock = bestSlice.insertPoint->getBlock();
+
// Gather stores for all the private-to-be memrefs.
DenseMap<Value, SmallVector<Operation *, 4>> privateMemRefToStores;
dstAffineForOp.walk([&](AffineWriteOpInterface storeOp) {
@@ -962,8 +1009,8 @@ struct GreedyFusion {
SmallVector<Operation *, 4> &storesForMemref =
memrefToStoresPair.second;
Value newMemRef = createPrivateMemRef(
- dstAffineForOp, storesForMemref[0], bestDstLoopDepth,
- fastMemorySpace, localBufSizeThreshold);
+ dstAffineForOp, storesForMemref, bestDstLoopDepth,
+ fastMemorySpace, sliceInsertionBlock, localBufSizeThreshold);
// Create new node in dependence graph for 'newMemRef' alloc op.
unsigned newMemRefNodeId = mdg->addNode(newMemRef.getDefiningOp());
// Add edge from 'newMemRef' node to dstNode.
diff --git a/mlir/test/Dialect/Affine/loop-fusion-4.mlir b/mlir/test/Dialect/Affine/loop-fusion-4.mlir
index ea144f73bb21c6d..1241a46fb389419 100644
--- a/mlir/test/Dialect/Affine/loop-fusion-4.mlir
+++ b/mlir/test/Dialect/Affine/loop-fusion-4.mlir
@@ -285,3 +285,32 @@ module {
spirv.ReturnValue %3 : !spirv.array<8192 x f32>
}
}
+
+// -----
+
+// PRODUCER-CONSUMER-LABEL: func @same_memref_load_store
+func.func @same_memref_load_store(%producer : memref<32xf32>, %consumer: memref<16xf32>){
+ %cst = arith.constant 2.000000e+00 : f32
+ // Source isn't removed.
+ // PRODUCER-CONSUMER: affine.for %{{.*}} = 0 to 32
+ affine.for %arg3 = 0 to 32 {
+ %0 = affine.load %producer[%arg3] : memref<32xf32>
+ %2 = arith.mulf %0, %cst : f32
+ affine.store %2, %producer[%arg3] : memref<32xf32>
+ }
+ affine.for %arg3 = 0 to 16 {
+ %0 = affine.load %producer[%arg3] : memref<32xf32>
+ %2 = arith.addf %0, %cst : f32
+ affine.store %2, %consumer[%arg3] : memref<16xf32>
+ }
+ // Fused nest.
+ // PRODUCER-CONSUMER: affine.for %{{.*}} = 0 to 16
+ // PRODUCER-CONSUMER-NEXT: affine.load %{{.*}}[%{{.*}}] : memref<32xf32>
+ // PRODUCER-CONSUMER-NEXT: arith.mulf
+ // PRODUCER-CONSUMER-NEXT: affine.store %{{.*}}, %{{.*}}[0] : memref<1xf32>
+ // PRODUCER-CONSUMER-NEXT: affine.load %{{.*}}[0] : memref<1xf32>
+ // PRODUCER-CONSUMER-NEXT: arith.addf
+ // PRODUCER-CONSUMER-NEXT: affine.store
+ // PRODUCER-CONSUMER-NEXT: }
+ return
+}
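To make the CHECK lines in the new test concrete, this is roughly the fused nest they describe (a sketch with illustrative names, not captured pass output; the preserved original producer loop is omitted): the producer slice still loads from the original memref<32xf32>, while its store and the consumer's subsequent load go through the single-element private buffer.

```mlir
func.func @fused(%producer: memref<32xf32>, %consumer: memref<16xf32>) {
  %cst = arith.constant 2.0 : f32
  %priv = memref.alloc() : memref<1xf32>             // private buffer placed before the fused loop
  affine.for %i = 0 to 16 {
    %0 = affine.load %producer[%i] : memref<32xf32>  // still reads the original memref
    %1 = arith.mulf %0, %cst : f32
    affine.store %1, %priv[0] : memref<1xf32>        // producer store redirected to the private buffer
    %2 = affine.load %priv[0] : memref<1xf32>        // consumer load redirected as well
    %3 = arith.addf %2, %cst : f32
    affine.store %3, %consumer[%i] : memref<16xf32>
  }
  return
}
```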
LGTM
Fix private memref creation bug in affine fusion exposed in the case of
the same memref being loaded from/stored to in the producer nest. Make the
private memref replacement sound.
Change the affine fusion debug string to affine-fusion, which is more compact.
Fixes: #48703