[MLIR][Affine] Fix private memref creation bug in affine fusion #126028

Merged

@bondhugula bondhugula commented Feb 6, 2025

Fix a private memref creation bug in affine fusion, exposed when the
producer nest both loads from and stores to the same memref. Make the
private memref replacement sound.

Shorten the affine fusion debug type string from `affine-loop-fusion` to
`affine-fusion`.

Fixes: #48703
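
To make "sound replacement" concrete, here is a small plain C++ sketch of the fused loop body only (not the whole function, and not MLIR code). It assumes, based on my reading of the diff, that the problem case is the producer's own load being redirected to the private buffer along with the store. The names `fusedLoadFromOriginal` and `fusedLoadFromPrivate` are hypothetical, and the private buffer is zero-initialised here only to keep the sketch deterministic; a real `alloc` is uninitialised.

```cpp
#include <array>
#include <cstddef>

constexpr std::size_t kN = 16;
using Buf = std::array<float, kN>;

// Fused body with a sound replacement: only accesses that come after the
// producer store go through the private one-element buffer.
Buf fusedLoadFromOriginal(const Buf &producer) {
  Buf consumer{};
  for (std::size_t i = 0; i < kN; ++i) {
    float priv[1] = {0.0f};        // stand-in for the memref<1xf32> private buffer
    float v = producer[i];         // producer load: still reads the original memref
    priv[0] = v * 2.0f;            // producer store: redirected to the private buffer
    consumer[i] = priv[0] + 2.0f;  // consumer load: redirected to the private buffer
  }
  return consumer;
}

// Fused body with an unsound replacement: the producer load is redirected
// too, so it reads the private buffer before anything was stored to it.
Buf fusedLoadFromPrivate(const Buf &producer) {
  (void)producer;  // the original data is never read in this variant
  Buf consumer{};
  for (std::size_t i = 0; i < kN; ++i) {
    float priv[1] = {0.0f};        // zero here; uninitialised in a real alloc
    float v = priv[0];             // producer load: wrongly rewritten
    priv[0] = v * 2.0f;
    consumer[i] = priv[0] + 2.0f;
  }
  return consumer;
}
```

Calling both functions on the same input shows the difference: the first computes `producer[i] * 2 + 2`, while the second ignores the input entirely.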

llvmbot commented Feb 6, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-affine

Author: Uday Bondhugula (bondhugula)

Full diff: https://github.com/llvm/llvm-project/pull/126028.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Affine/Analysis/Utils.h (+8)
  • (modified) mlir/lib/Dialect/Affine/Analysis/Utils.cpp (+38)
  • (modified) mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp (+59-12)
  • (modified) mlir/test/Dialect/Affine/loop-fusion-4.mlir (+29)
diff --git a/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h b/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h
index b1fbf4477428ca2..cff4983af17fc90 100644
--- a/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h
+++ b/mlir/include/mlir/Dialect/Affine/Analysis/Utils.h
@@ -610,6 +610,14 @@ FailureOr<AffineValueMap>
 simplifyConstrainedMinMaxOp(Operation *op,
                             FlatAffineValueConstraints constraints);
 
+/// Find the innermost common `Block` of `a` and `b` in the affine scope
+/// that `a` and `b` are part of. Return nullptr if they belong to different
+/// affine scopes. Also, return nullptr if they do not have a common `Block`
+/// ancestor (e.g., when they are part of the `then` and `else` regions
+/// of an op that itself starts an affine scope).
+mlir::Block *findInnermostCommonBlockInScope(mlir::Operation *a,
+                                             mlir::Operation *b);
+
 } // namespace affine
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
index 9c0b5dbf52d299b..d6c62cdd613643e 100644
--- a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
@@ -12,6 +12,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Affine/Analysis/Utils.h"
+
 #include "mlir/Analysis/Presburger/PresburgerRelation.h"
 #include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
 #include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
@@ -2297,3 +2298,40 @@ FailureOr<AffineValueMap> mlir::affine::simplifyConstrainedMinMaxOp(
   affine::canonicalizeMapAndOperands(&newMap, &newOperands);
   return AffineValueMap(newMap, newOperands);
 }
+
+Block *mlir::affine::findInnermostCommonBlockInScope(Operation *a,
+                                                     Operation *b) {
+  Region *aScope = mlir::affine::getAffineScope(a);
+  Region *bScope = mlir::affine::getAffineScope(b);
+  if (aScope != bScope)
+    return nullptr;
+
+  // Get the block ancestry of `op` while stopping at the affine scope.
+  auto getBlockAncestry = [&](Operation *op,
+                              SmallVectorImpl<Block *> &ancestry) {
+    Operation *curOp = op;
+    do {
+      ancestry.push_back(curOp->getBlock());
+      if (curOp->getParentRegion() == aScope)
+        break;
+      curOp = curOp->getParentOp();
+    } while (curOp);
+    assert(curOp && "can't reach root op without passing through affine scope");
+    std::reverse(ancestry.begin(), ancestry.end());
+  };
+
+  SmallVector<Block *, 4> aAncestors, bAncestors;
+  getBlockAncestry(a, aAncestors);
+  getBlockAncestry(b, bAncestors);
+  assert(!aAncestors.empty() && !bAncestors.empty() &&
+         "at least one Block ancestor expected");
+
+  Block *innermostCommonBlock = nullptr;
+  for (unsigned a = 0, b = 0, e = aAncestors.size(), f = bAncestors.size();
+       a < e && b < f; ++a, ++b) {
+    if (aAncestors[a] != bAncestors[b])
+      break;
+    innermostCommonBlock = aAncestors[a];
+  }
+  return innermostCommonBlock;
+}
diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
index c22ec213be95c84..0ea27df704d0694 100644
--- a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
@@ -41,7 +41,7 @@ namespace affine {
 } // namespace affine
 } // namespace mlir
 
-#define DEBUG_TYPE "affine-loop-fusion"
+#define DEBUG_TYPE "affine-fusion"
 
 using namespace mlir;
 using namespace mlir::affine;
@@ -237,29 +237,71 @@ static void sinkSequentialLoops(MemRefDependenceGraph::Node *node) {
   node->op = newRootForOp;
 }
 
+/// Get the operation that should act as a dominance filter while replacing
+/// memref uses with a private memref for which `producerStores` and
+/// `sliceInsertionBlock` are provided. This effectively determines in what
+/// part of the IR we should be performing the replacement.
+static Operation *
+getDominanceFilterForPrivateMemRefRepl(Block *sliceInsertionBlock,
+                                       ArrayRef<Operation *> producerStores) {
+  assert(!producerStores.empty() && "expected producer store");
+
+  // We first find the common block that contains the producer stores and
+  // the slice computation. The first ancestor among the ancestors of the
+  // producer stores in that common block is the dominance filter to use for
+  // replacement.
+  Block *commonBlock = nullptr;
+  // Find the common block of all relevant operations.
+  for (Operation *store : producerStores) {
+    if (!commonBlock)
+      commonBlock = findInnermostCommonBlockInScope(
+          store, &*sliceInsertionBlock->begin());
+    else
+      commonBlock =
+          findInnermostCommonBlockInScope(store, &*commonBlock->begin());
+  }
+  assert(commonBlock &&
+         "common block of producer stores and slice should exist");
+
+  // Find the first ancestor among the ancestors of `producerStores` in
+  // `commonBlock`.
+  Operation *firstAncestor = nullptr;
+  for (Operation *store : producerStores) {
+    Operation *ancestor = commonBlock->findAncestorOpInBlock(*store);
+    assert(ancestor && "producer store should be contained in common block");
+    firstAncestor = !firstAncestor || ancestor->isBeforeInBlock(firstAncestor)
+                        ? ancestor
+                        : firstAncestor;
+  }
+  return firstAncestor;
+}
+
 // Creates and returns a private (single-user) memref for fused loop rooted
 // at 'forOp', with (potentially reduced) memref size based on the
 // MemRefRegion written to by 'srcStoreOpInst' at depth 'dstLoopDepth'.
 // TODO: consider refactoring the common code from generateDma and
 // this one.
-static Value createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst,
+static Value createPrivateMemRef(AffineForOp forOp,
+                                 ArrayRef<Operation *> storeOps,
                                  unsigned dstLoopDepth,
                                  std::optional<unsigned> fastMemorySpace,
+                                 Block *sliceInsertionBlock,
                                  uint64_t localBufSizeThreshold) {
-  Operation *forInst = forOp.getOperation();
+  assert(!storeOps.empty() && "no source stores supplied");
+  Operation *srcStoreOp = storeOps[0];
 
   // Create builder to insert alloc op just before 'forOp'.
-  OpBuilder b(forInst);
+  OpBuilder b(forOp);
   // Builder to create constants at the top level.
-  OpBuilder top(forInst->getParentRegion());
+  OpBuilder top(forOp->getParentRegion());
   // Create new memref type based on slice bounds.
-  auto oldMemRef = cast<AffineWriteOpInterface>(srcStoreOpInst).getMemRef();
+  auto oldMemRef = cast<AffineWriteOpInterface>(srcStoreOp).getMemRef();
   auto oldMemRefType = cast<MemRefType>(oldMemRef.getType());
   unsigned rank = oldMemRefType.getRank();
 
   // Compute MemRefRegion for 'srcStoreOpInst' at depth 'dstLoopDepth'.
-  MemRefRegion region(srcStoreOpInst->getLoc());
-  bool validRegion = succeeded(region.compute(srcStoreOpInst, dstLoopDepth));
+  MemRefRegion region(srcStoreOp->getLoc());
+  bool validRegion = succeeded(region.compute(srcStoreOp, dstLoopDepth));
   (void)validRegion;
   assert(validRegion && "unexpected memref region failure");
   SmallVector<int64_t, 4> newShape;
@@ -332,11 +374,12 @@ static Value createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst,
       AffineMap::get(outerIVs.size() + rank, 0, remapExprs, forOp.getContext());
 
   // Replace all users of 'oldMemRef' with 'newMemRef'.
+  Operation *domFilter =
+      getDominanceFilterForPrivateMemRefRepl(sliceInsertionBlock, storeOps);
   LogicalResult res =
       replaceAllMemRefUsesWith(oldMemRef, newMemRef, {}, indexRemap,
                                /*extraOperands=*/outerIVs,
-                               /*symbolOperands=*/{},
-                               /*domOpFilter=*/&*forOp.getBody()->begin());
+                               /*symbolOperands=*/{}, domFilter);
   assert(succeeded(res) &&
          "replaceAllMemrefUsesWith should always succeed here");
   (void)res;
@@ -944,6 +987,10 @@ struct GreedyFusion {
 
         // Create private memrefs.
         if (!privateMemrefs.empty()) {
+          // Note the block into which fusion was performed. This can be used to
+          // place `alloc`s that create private memrefs.
+          Block *sliceInsertionBlock = bestSlice.insertPoint->getBlock();
+
           // Gather stores for all the private-to-be memrefs.
           DenseMap<Value, SmallVector<Operation *, 4>> privateMemRefToStores;
           dstAffineForOp.walk([&](AffineWriteOpInterface storeOp) {
@@ -962,8 +1009,8 @@ struct GreedyFusion {
             SmallVector<Operation *, 4> &storesForMemref =
                 memrefToStoresPair.second;
             Value newMemRef = createPrivateMemRef(
-                dstAffineForOp, storesForMemref[0], bestDstLoopDepth,
-                fastMemorySpace, localBufSizeThreshold);
+                dstAffineForOp, storesForMemref, bestDstLoopDepth,
+                fastMemorySpace, sliceInsertionBlock, localBufSizeThreshold);
             // Create new node in dependence graph for 'newMemRef' alloc op.
             unsigned newMemRefNodeId = mdg->addNode(newMemRef.getDefiningOp());
             // Add edge from 'newMemRef' node to dstNode.
diff --git a/mlir/test/Dialect/Affine/loop-fusion-4.mlir b/mlir/test/Dialect/Affine/loop-fusion-4.mlir
index ea144f73bb21c6d..1241a46fb389419 100644
--- a/mlir/test/Dialect/Affine/loop-fusion-4.mlir
+++ b/mlir/test/Dialect/Affine/loop-fusion-4.mlir
@@ -285,3 +285,32 @@ module {
     spirv.ReturnValue %3 : !spirv.array<8192 x f32>
   }
 }
+
+// -----
+
+// PRODUCER-CONSUMER-LABEL: func @same_memref_load_store
+func.func @same_memref_load_store(%producer : memref<32xf32>, %consumer: memref<16xf32>){
+  %cst = arith.constant 2.000000e+00 : f32
+  // Source isn't removed.
+  // PRODUCER-CONSUMER: affine.for %{{.*}} = 0 to 32
+  affine.for %arg3 = 0 to 32 {
+    %0 = affine.load %producer[%arg3] : memref<32xf32>
+    %2 = arith.mulf %0, %cst : f32
+    affine.store %2, %producer[%arg3] : memref<32xf32>
+  }
+  affine.for %arg3 = 0 to 16 {
+    %0 = affine.load %producer[%arg3] : memref<32xf32>
+    %2 = arith.addf %0, %cst : f32
+    affine.store %2, %consumer[%arg3] : memref<16xf32>
+  }
+  // Fused nest.
+  // PRODUCER-CONSUMER:      affine.for %{{.*}} = 0 to 16
+  // PRODUCER-CONSUMER-NEXT:   affine.load %{{.*}}[%{{.*}}] : memref<32xf32>
+  // PRODUCER-CONSUMER-NEXT:   arith.mulf
+  // PRODUCER-CONSUMER-NEXT:   affine.store %{{.*}}, %{{.*}}[0] : memref<1xf32>
+  // PRODUCER-CONSUMER-NEXT:   affine.load %{{.*}}[0] : memref<1xf32>
+  // PRODUCER-CONSUMER-NEXT:   arith.addf
+  // PRODUCER-CONSUMER-NEXT:   affine.store
+  // PRODUCER-CONSUMER-NEXT: }
+  return
+}
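
For readers who want to try the two helper ideas from the diff outside of MLIR, the following is a toy C++ sketch. `ToyBlock`, `ToyOp`, and all function names are hypothetical stand-ins, not MLIR classes or API. It assumes a single affine scope (so the scope check from the patch is omitted) and models op order within a block with a plain integer position.

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Hypothetical stand-ins for mlir::Block and mlir::Operation.
struct ToyOp;
struct ToyBlock {
  ToyOp *parentOp = nullptr;  // op whose region holds this block; null at the root
};
struct ToyOp {
  ToyBlock *block = nullptr;  // block that directly contains this op (must be set)
  int position = 0;           // order of this op within its block
};

// Chain of blocks enclosing `op`, outermost first.
std::vector<ToyBlock *> blockAncestry(const ToyOp *op) {
  std::vector<ToyBlock *> chain;
  for (const ToyOp *cur = op; cur; cur = cur->block->parentOp)
    chain.push_back(cur->block);
  std::reverse(chain.begin(), chain.end());
  return chain;
}

// Last block shared by both outermost-first chains; nullptr if none is shared.
ToyBlock *innermostCommonBlock(const ToyOp *a, const ToyOp *b) {
  std::vector<ToyBlock *> aChain = blockAncestry(a);
  std::vector<ToyBlock *> bChain = blockAncestry(b);
  ToyBlock *common = nullptr;
  for (std::size_t i = 0; i < aChain.size() && i < bChain.size(); ++i) {
    if (aChain[i] != bChain[i])
      break;
    common = aChain[i];
  }
  return common;
}

// `op` itself or its ancestor that lies directly in `block`; nullptr if
// `op` is not nested under `block`.
const ToyOp *ancestorInBlock(const ToyBlock *block, const ToyOp *op) {
  for (const ToyOp *cur = op; cur; cur = cur->block->parentOp)
    if (cur->block == block)
      return cur;
  return nullptr;
}

// Among the in-block ancestors of `stores`, pick the one that comes first.
const ToyOp *earliestAncestor(const ToyBlock *block,
                              const std::vector<const ToyOp *> &stores) {
  const ToyOp *first = nullptr;
  for (const ToyOp *store : stores) {
    const ToyOp *anc = ancestorInBlock(block, store);
    if (anc && (!first || anc->position < first->position))
      first = anc;
  }
  return first;
}
```

In this model, a store and a load in the same loop body share that body as their innermost common block, and the store is its own ancestor there. That matches the intent described in the patch comments: the filter marks where replacement should start, so a load placed before the store falls outside it.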

@bondhugula bondhugula force-pushed the uday/fix_fusion_private_memref_creation branch from f275eda to 1e12875 Compare February 7, 2025 06:09

@arnab-polymage arnab-polymage left a comment

LGTM

@bondhugula bondhugula force-pushed the uday/fix_fusion_private_memref_creation branch from 1e12875 to 8610428 Compare February 8, 2025 02:36
@bondhugula bondhugula force-pushed the uday/fix_fusion_private_memref_creation branch from 8610428 to c7cd49d Compare February 8, 2025 02:41
@bondhugula bondhugula merged commit b850ce4 into llvm:main Feb 8, 2025
8 checks passed
Icohedron pushed a commit to Icohedron/llvm-project that referenced this pull request Feb 11, 2025
Successfully merging this pull request may close these issues.

Incorrect loop fusion in case of a producer with self-dependence