[MLIR][Affine] Fix sibling fusion - missing check #126626

Merged
bondhugula merged 1 commit into llvm:main from uday/fix_sibling_fusion
Feb 13, 2025

bondhugula
Contributor

@bondhugula bondhugula commented Feb 11, 2025

Fix sibling fusion for slice maximality check. Producer-consumer fusion
had this check but not sibling fusion. Sibling fusion shouldn't be
performed if the slice isn't "maximal" (i.e., if it isn't the whole of
the source).

Fixes: #48703
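
To illustrate why maximality matters, here is a minimal, self-contained sketch. It does not use any MLIR API: `Range`, `isMaximalSlice`, `runSource`, and `simulate` are hypothetical names, and the loop body is invented for illustration. The sketch models a source loop over 32 elements whose fused slice covers only the first 16, and shows that erasing the source afterwards would lose the remaining writes.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for a 1-d iteration range [lb, ub); not an MLIR type.
struct Range {
  std::size_t lb;
  std::size_t ub;
};

// Toy notion of maximality: the slice covers every iteration of the source.
bool isMaximalSlice(Range slice, Range src) {
  return slice.lb <= src.lb && slice.ub >= src.ub;
}

// Runs the toy source loop body `out[i] = 2 * in[i]` over the range `r`.
void runSource(const std::vector<float> &in, std::vector<float> &out,
               Range r) {
  assert(r.lb <= r.ub && r.ub <= in.size() && in.size() == out.size());
  for (std::size_t i = r.lb; i < r.ub; ++i)
    out[i] = 2.0f * in[i];
}

// Simulates the state after fusion: the slice of the source runs inside the
// destination nest. The original source nest still runs unless `eraseSource`
// is true.
std::vector<float> simulate(const std::vector<float> &in, Range src,
                            Range slice, bool eraseSource) {
  std::vector<float> out(in.size(), 0.0f);
  if (!eraseSource)
    runSource(in, out, src);
  runSource(in, out, slice); // Fused copy of the source computation.
  return out;
}
```

In this toy model, calling `simulate` with `src = {0, 32}`, `slice = {0, 16}`, and `eraseSource = true` leaves elements 16 through 31 unwritten, whereas a maximal slice (`slice == src`) reproduces the unfused result even with the source erased.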

@llvmbot
Member

llvmbot commented Feb 11, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-affine

Author: Uday Bondhugula (bondhugula)

Changes

Fix sibling fusion for slice maximality check. Producer-consumer fusion
had this check but not sibling fusion. The sibling source node can't be
simply deleted post fusion if the slice isn't maximal.


Full diff: https://github.com/llvm/llvm-project/pull/126626.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp (+23-12)
  • (modified) mlir/test/Dialect/Affine/loop-fusion-4.mlir (+15-6)
diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
index b38dd8effe669df..f70cf8bcc9d3f95 100644
--- a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
@@ -1150,6 +1150,8 @@ struct GreedyFusion {
         continue;
 
       unsigned bestDstLoopDepth = maxLegalFusionDepth;
+      const ComputationSliceState &bestSlice =
+          depthSliceUnions[bestDstLoopDepth - 1];
       if (!maximalFusion) {
         // Check if fusion would be profitable. For sibling fusion, the sibling
         // load op is treated as the src "store" op for fusion profitability
@@ -1162,24 +1164,40 @@ struct GreedyFusion {
       }
 
       assert(bestDstLoopDepth > 0 && "Unexpected loop fusion depth");
-      assert(!depthSliceUnions[bestDstLoopDepth - 1].isEmpty() &&
+      assert(!bestSlice.isEmpty() &&
              "Fusion depth has no computed slice union");
       // Check if source loop is being inserted in the innermost
       // destination loop. Based on this, the fused loop may be optimized
       // further inside `fuseLoops`.
       bool isInnermostInsertion = (bestDstLoopDepth == dstLoopDepthTest);
       // Fuse computation slice of 'sibLoopNest' into 'dstLoopNest'.
-      affine::fuseLoops(sibAffineForOp, dstAffineForOp,
-                        depthSliceUnions[bestDstLoopDepth - 1],
+      affine::fuseLoops(sibAffineForOp, dstAffineForOp, bestSlice,
                         isInnermostInsertion);
 
       auto dstForInst = cast<AffineForOp>(dstNode->op);
       // Update operation position of fused loop nest (if needed).
-      if (insertPointInst != dstForInst) {
+      if (insertPointInst != dstForInst)
         dstForInst->moveBefore(insertPointInst);
-      }
+
       // Update data dependence graph state post fusion.
       updateStateAfterSiblingFusion(sibNode, dstNode);
+
+      // Remove old sibling loop nest if it no longer has outgoing dependence
+      // edges, and the slice is maximal.
+      bool removeSrcNode = [&]() {
+        if (mdg->getOutEdgeCount(sibNode->id) > 0)
+          return false;
+        auto isMaximal = bestSlice.isMaximal();
+        return isMaximal && *isMaximal;
+      }();
+      LLVM_DEBUG(llvm::dbgs() << "Can remove source node after fusion: "
+                              << removeSrcNode << '\n');
+      if (removeSrcNode) {
+        // Get op before we invalidate the MDG node.
+        Operation *op = sibNode->op;
+        mdg->removeNode(sibNode->id);
+        op->erase();
+      }
     }
   }
 
@@ -1321,13 +1339,6 @@ struct GreedyFusion {
     mdg->addToNode(dstNode->id, dstLoopCollector.loadOpInsts,
                    dstLoopCollector.storeOpInsts, dstLoopCollector.memrefLoads,
                    dstLoopCollector.memrefStores, dstLoopCollector.memrefFrees);
-    // Remove old sibling loop nest if it no longer has outgoing dependence
-    // edges, and it does not write to a memref which escapes the block.
-    if (mdg->getOutEdgeCount(sibNode->id) == 0) {
-      Operation *op = sibNode->op;
-      mdg->removeNode(sibNode->id);
-      op->erase();
-    }
   }
 
   // Clean up any allocs with no users.
diff --git a/mlir/test/Dialect/Affine/loop-fusion-4.mlir b/mlir/test/Dialect/Affine/loop-fusion-4.mlir
index 2830235431c7646..e2d71021793a572 100644
--- a/mlir/test/Dialect/Affine/loop-fusion-4.mlir
+++ b/mlir/test/Dialect/Affine/loop-fusion-4.mlir
@@ -1,5 +1,7 @@
 // RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion{mode=producer}))' -split-input-file | FileCheck %s --check-prefix=PRODUCER-CONSUMER
 // RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion{fusion-maximal mode=sibling}))' -split-input-file | FileCheck %s --check-prefix=SIBLING-MAXIMAL
+// All fusion: producer-consumer and sibling.
+// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion{fusion-maximal mode=sibling}))' -split-input-file
 // RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(spirv.func(affine-loop-fusion{mode=producer}))' -split-input-file | FileCheck %s --check-prefix=SPIRV
 
 // Part I of fusion tests in  mlir/test/Transforms/loop-fusion.mlir.
@@ -108,6 +110,7 @@ func.func @check_src_dst_step(%m : memref<100xf32>,
 func.func @reduce_add_non_maximal_f32_f32(%arg0: memref<64x64xf32, 1>, %arg1 : memref<1x64xf32, 1>, %arg2 : memref<1x64xf32, 1>) {
     %cst_0 = arith.constant 0.000000e+00 : f32
     %cst_1 = arith.constant 1.000000e+00 : f32
+    // This nest writes to %arg1 but can be eliminated post sibling fusion.
     affine.for %arg3 = 0 to 1 {
       affine.for %arg4 = 0 to 64 {
         %accum = affine.for %arg5 = 0 to 64 iter_args (%prevAccum = %cst_0) -> f32 {
@@ -137,11 +140,11 @@ func.func @reduce_add_non_maximal_f32_f32(%arg0: memref<64x64xf32, 1>, %arg1 : m
 // since the destination loop and source loop trip counts do not
 // match.
 // SIBLING-MAXIMAL:        %[[cst_0:.*]] = arith.constant 0.000000e+00 : f32
-// SIBLING-MAXIMAL-NEXT:        %[[cst_1:.*]] = arith.constant 1.000000e+00 : f32
-// SIBLING-MAXIMAL-NEXT:           affine.for %[[idx_0:.*]]= 0 to 1 {
-// SIBLING-MAXIMAL-NEXT:             affine.for %[[idx_1:.*]] = 0 to 64 {
-// SIBLING-MAXIMAL-NEXT:               %[[result_1:.*]] = affine.for %[[idx_2:.*]] = 0 to 32 iter_args(%[[iter_0:.*]] = %[[cst_1]]) -> (f32) {
-// SIBLING-MAXIMAL-NEXT:                 %[[result_0:.*]] = affine.for %[[idx_3:.*]] = 0 to 64 iter_args(%[[iter_1:.*]] = %[[cst_0]]) -> (f32) {
+// SIBLING-MAXIMAL-NEXT:   %[[cst_1:.*]] = arith.constant 1.000000e+00 : f32
+// SIBLING-MAXIMAL-NEXT:   affine.for %{{.*}} = 0 to 1 {
+// SIBLING-MAXIMAL-NEXT:     affine.for %{{.*}} = 0 to 64 {
+// SIBLING-MAXIMAL-NEXT:       affine.for %{{.*}} = 0 to 32 iter_args(%{{.*}} = %[[cst_1]]) -> (f32) {
+// SIBLING-MAXIMAL-NEXT:       affine.for %{{.*}} = 0 to 64 iter_args(%{{.*}} = %[[cst_0]]) -> (f32) {
 
 // -----
 
@@ -316,10 +319,13 @@ func.func @same_memref_load_store(%producer : memref<32xf32>, %consumer: memref<
 }
 
 // PRODUCER-CONSUMER-LABEL: func @same_memref_load_multiple_stores
+// ALL-LABEL: func @same_memref_load_multiple_stores
 func.func @same_memref_load_multiple_stores(%producer : memref<32xf32>, %producer_2 : memref<32xf32>, %consumer: memref<16xf32>){
   %cst = arith.constant 2.000000e+00 : f32
-  // Source isn't removed.
+  // Ensure that source isn't removed during both producer-consumer fusion and
+  // sibling fusion.
   // PRODUCER-CONSUMER: affine.for %{{.*}} = 0 to 32
+  // ALL: affine.for %{{.*}} = 0 to 32
   affine.for %arg3 = 0 to 32 {
     %0 = affine.load %producer[%arg3] : memref<32xf32>
     %2 = arith.mulf %0, %cst : f32
@@ -343,5 +349,8 @@ func.func @same_memref_load_multiple_stores(%producer : memref<32xf32>, %produce
   // PRODUCER-CONSUMER-NEXT:   arith.addf
   // PRODUCER-CONSUMER-NEXT:   affine.store
   // PRODUCER-CONSUMER-NEXT: }
+  // ALL:     affine.for %{{.*}} = 0 to 16
+  // ALL:       mulf
+  // ALL:       addf
   return
 }
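
The removal logic in the LoopFusion.cpp hunk of this diff combines two conditions: the sibling node must have no outgoing dependence edges, and the slice must be known to be maximal. The result of `isMaximal()` is dereferenced only when it holds a value, so an unknown answer is treated the same as "not maximal". The stand-alone sketch below mirrors that decision under stated assumptions: `ToyOp`, `ToyNode`, `ToyGraph`, `canRemoveSource`, and `maybeRemoveSource` are hypothetical stand-ins for the pass's own graph and operation types, and the maximality result is assumed to behave like `std::optional<bool>`.

```cpp
#include <cassert>
#include <map>
#include <optional>
#include <string>

// Hypothetical stand-ins; the real pass uses its own graph and op types.
struct ToyOp {
  std::string name;
  bool erased = false;
};

struct ToyNode {
  unsigned id;
  ToyOp *op;
};

struct ToyGraph {
  std::map<unsigned, ToyNode> nodes;
  std::map<unsigned, unsigned> outEdgeCount;

  unsigned getOutEdgeCount(unsigned id) const {
    auto it = outEdgeCount.find(id);
    return it == outEdgeCount.end() ? 0 : it->second;
  }

  void removeNode(unsigned id) {
    nodes.erase(id);
    outEdgeCount.erase(id);
  }
};

// Unknown maximality (std::nullopt) is conservatively treated as "not
// maximal", which matches the `isMaximal && *isMaximal` idiom in the diff.
bool canRemoveSource(const ToyGraph &g, unsigned srcId,
                     std::optional<bool> isMaximal) {
  if (g.getOutEdgeCount(srcId) > 0)
    return false;
  return isMaximal.value_or(false);
}

// Removes the source node and marks its op erased when that is safe.
// Returns true if the source was removed.
bool maybeRemoveSource(ToyGraph &g, unsigned srcId,
                       std::optional<bool> isMaximal) {
  if (!canRemoveSource(g, srcId, isMaximal))
    return false;
  auto it = g.nodes.find(srcId);
  assert(it != g.nodes.end() && "unknown source node");
  // Copy the op pointer first: removing the node invalidates `it`.
  ToyOp *op = it->second.op;
  g.removeNode(srcId);
  op->erased = true;
  return true;
}
```

Fetching the op pointer before `removeNode` corresponds to the "Get op before we invalidate the MDG node" comment in the patch; in the sketch, the node entry is destroyed by the removal, so reading `op` from it afterwards would be a use-after-free.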

@bondhugula
Contributor Author

Changes planned.

@bondhugula bondhugula force-pushed the uday/fix_sibling_fusion branch from b2619a1 to 45e041f Compare February 11, 2025 03:47
@bondhugula bondhugula marked this pull request as ready for review February 11, 2025 03:47
@bondhugula
Contributor Author

> Changes planned.

Done.

@bondhugula bondhugula force-pushed the uday/fix_sibling_fusion branch 3 times, most recently from b82df46 to 09cf7d4 Compare February 12, 2025 00:57
Contributor

@patel-vimal patel-vimal left a comment

LGTM.

@bondhugula bondhugula force-pushed the uday/fix_sibling_fusion branch from 09cf7d4 to 7a03b9f Compare February 12, 2025 23:57
@bondhugula bondhugula merged commit 8421ad7 into llvm:main Feb 13, 2025
8 checks passed
flovent pushed a commit to flovent/llvm-project that referenced this pull request Feb 13, 2025
joaosaffran pushed a commit to joaosaffran/llvm-project that referenced this pull request Feb 14, 2025
sivan-shani pushed a commit to sivan-shani/llvm-project that referenced this pull request Feb 24, 2025
Development

Successfully merging this pull request may close these issues.

Incorrect loop fusion in case of a producer with self-dependence
3 participants