-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[MLIR][Affine] Fix sibling fusion - missing check #126626
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-affine Author: Uday Bondhugula (bondhugula) ChangesFix sibling fusion for slice maximality check. Producer-consumer fusion Full diff: https://github.com/llvm/llvm-project/pull/126626.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
index b38dd8effe669df..f70cf8bcc9d3f95 100644
--- a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
@@ -1150,6 +1150,8 @@ struct GreedyFusion {
continue;
unsigned bestDstLoopDepth = maxLegalFusionDepth;
+ const ComputationSliceState &bestSlice =
+ depthSliceUnions[bestDstLoopDepth - 1];
if (!maximalFusion) {
// Check if fusion would be profitable. For sibling fusion, the sibling
// load op is treated as the src "store" op for fusion profitability
@@ -1162,24 +1164,40 @@ struct GreedyFusion {
}
assert(bestDstLoopDepth > 0 && "Unexpected loop fusion depth");
- assert(!depthSliceUnions[bestDstLoopDepth - 1].isEmpty() &&
+ assert(!bestSlice.isEmpty() &&
"Fusion depth has no computed slice union");
// Check if source loop is being inserted in the innermost
// destination loop. Based on this, the fused loop may be optimized
// further inside `fuseLoops`.
bool isInnermostInsertion = (bestDstLoopDepth == dstLoopDepthTest);
// Fuse computation slice of 'sibLoopNest' into 'dstLoopNest'.
- affine::fuseLoops(sibAffineForOp, dstAffineForOp,
- depthSliceUnions[bestDstLoopDepth - 1],
+ affine::fuseLoops(sibAffineForOp, dstAffineForOp, bestSlice,
isInnermostInsertion);
auto dstForInst = cast<AffineForOp>(dstNode->op);
// Update operation position of fused loop nest (if needed).
- if (insertPointInst != dstForInst) {
+ if (insertPointInst != dstForInst)
dstForInst->moveBefore(insertPointInst);
- }
+
// Update data dependence graph state post fusion.
updateStateAfterSiblingFusion(sibNode, dstNode);
+
+ // Remove old sibling loop nest if it no longer has outgoing dependence
+ // edges, and the slice is maximal.
+ bool removeSrcNode = [&]() {
+ if (mdg->getOutEdgeCount(sibNode->id) > 0)
+ return false;
+ auto isMaximal = bestSlice.isMaximal();
+ return isMaximal && *isMaximal;
+ }();
+ LLVM_DEBUG(llvm::dbgs() << "Can remove source node after fusion: "
+ << removeSrcNode << '\n');
+ if (removeSrcNode) {
+ // Get op before we invalidate the MDG node.
+ Operation *op = sibNode->op;
+ mdg->removeNode(sibNode->id);
+ op->erase();
+ }
}
}
@@ -1321,13 +1339,6 @@ struct GreedyFusion {
mdg->addToNode(dstNode->id, dstLoopCollector.loadOpInsts,
dstLoopCollector.storeOpInsts, dstLoopCollector.memrefLoads,
dstLoopCollector.memrefStores, dstLoopCollector.memrefFrees);
- // Remove old sibling loop nest if it no longer has outgoing dependence
- // edges, and it does not write to a memref which escapes the block.
- if (mdg->getOutEdgeCount(sibNode->id) == 0) {
- Operation *op = sibNode->op;
- mdg->removeNode(sibNode->id);
- op->erase();
- }
}
// Clean up any allocs with no users.
diff --git a/mlir/test/Dialect/Affine/loop-fusion-4.mlir b/mlir/test/Dialect/Affine/loop-fusion-4.mlir
index 2830235431c7646..e2d71021793a572 100644
--- a/mlir/test/Dialect/Affine/loop-fusion-4.mlir
+++ b/mlir/test/Dialect/Affine/loop-fusion-4.mlir
@@ -1,5 +1,7 @@
// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion{mode=producer}))' -split-input-file | FileCheck %s --check-prefix=PRODUCER-CONSUMER
// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion{fusion-maximal mode=sibling}))' -split-input-file | FileCheck %s --check-prefix=SIBLING-MAXIMAL
+// All fusion: producer-consumer and sibling.
+// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion{fusion-maximal mode=sibling}))' -split-input-file
// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(spirv.func(affine-loop-fusion{mode=producer}))' -split-input-file | FileCheck %s --check-prefix=SPIRV
// Part I of fusion tests in mlir/test/Transforms/loop-fusion.mlir.
@@ -108,6 +110,7 @@ func.func @check_src_dst_step(%m : memref<100xf32>,
func.func @reduce_add_non_maximal_f32_f32(%arg0: memref<64x64xf32, 1>, %arg1 : memref<1x64xf32, 1>, %arg2 : memref<1x64xf32, 1>) {
%cst_0 = arith.constant 0.000000e+00 : f32
%cst_1 = arith.constant 1.000000e+00 : f32
+ // This nest writes to %arg1 but can be eliminated post sibling fusion.
affine.for %arg3 = 0 to 1 {
affine.for %arg4 = 0 to 64 {
%accum = affine.for %arg5 = 0 to 64 iter_args (%prevAccum = %cst_0) -> f32 {
@@ -137,11 +140,11 @@ func.func @reduce_add_non_maximal_f32_f32(%arg0: memref<64x64xf32, 1>, %arg1 : m
// since the destination loop and source loop trip counts do not
// match.
// SIBLING-MAXIMAL: %[[cst_0:.*]] = arith.constant 0.000000e+00 : f32
-// SIBLING-MAXIMAL-NEXT: %[[cst_1:.*]] = arith.constant 1.000000e+00 : f32
-// SIBLING-MAXIMAL-NEXT: affine.for %[[idx_0:.*]]= 0 to 1 {
-// SIBLING-MAXIMAL-NEXT: affine.for %[[idx_1:.*]] = 0 to 64 {
-// SIBLING-MAXIMAL-NEXT: %[[result_1:.*]] = affine.for %[[idx_2:.*]] = 0 to 32 iter_args(%[[iter_0:.*]] = %[[cst_1]]) -> (f32) {
-// SIBLING-MAXIMAL-NEXT: %[[result_0:.*]] = affine.for %[[idx_3:.*]] = 0 to 64 iter_args(%[[iter_1:.*]] = %[[cst_0]]) -> (f32) {
+// SIBLING-MAXIMAL-NEXT: %[[cst_1:.*]] = arith.constant 1.000000e+00 : f32
+// SIBLING-MAXIMAL-NEXT: affine.for %{{.*}} = 0 to 1 {
+// SIBLING-MAXIMAL-NEXT: affine.for %{{.*}} = 0 to 64 {
+// SIBLING-MAXIMAL-NEXT: affine.for %{{.*}} = 0 to 32 iter_args(%{{.*}} = %[[cst_1]]) -> (f32) {
+// SIBLING-MAXIMAL-NEXT: affine.for %{{.*}} = 0 to 64 iter_args(%{{.*}} = %[[cst_0]]) -> (f32) {
// -----
@@ -316,10 +319,13 @@ func.func @same_memref_load_store(%producer : memref<32xf32>, %consumer: memref<
}
// PRODUCER-CONSUMER-LABEL: func @same_memref_load_multiple_stores
+// ALL-LABEL: func @same_memref_load_multiple_stores
func.func @same_memref_load_multiple_stores(%producer : memref<32xf32>, %producer_2 : memref<32xf32>, %consumer: memref<16xf32>){
%cst = arith.constant 2.000000e+00 : f32
- // Source isn't removed.
+ // Ensure that source isn't removed during both producer-consumer fusion and
+ // sibling fusion.
// PRODUCER-CONSUMER: affine.for %{{.*}} = 0 to 32
+ // ALL: affine.for %{{.*}} = 0 to 32
affine.for %arg3 = 0 to 32 {
%0 = affine.load %producer[%arg3] : memref<32xf32>
%2 = arith.mulf %0, %cst : f32
@@ -343,5 +349,8 @@ func.func @same_memref_load_multiple_stores(%producer : memref<32xf32>, %produce
// PRODUCER-CONSUMER-NEXT: arith.addf
// PRODUCER-CONSUMER-NEXT: affine.store
// PRODUCER-CONSUMER-NEXT: }
+ // ALL: affine.for %{{.*}} = 0 to 16
+ // ALL: mulf
+ // ALL: addf
return
}
|
366d78d
to
b2619a1
Compare
Changes planned. |
b2619a1
to
45e041f
Compare
Done. |
b82df46
to
09cf7d4
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM.
Fix sibling fusion for slice maximality check. Producer-consumer fusion had this check but not sibling fusion. Sibling fusion shouldn't be performed if the slice isn't "maximal" (i.e., if it isn't the whole of the source). Fixes: llvm#48703
09cf7d4
to
7a03b9f
Compare
Fix sibling fusion for slice maximality check. Producer-consumer fusion had this check but not sibling fusion. Sibling fusion shouldn't be performed if the slice isn't "maximal" (i.e., if it isn't the whole of the source). Fixes: llvm#48703
Fix sibling fusion for slice maximality check. Producer-consumer fusion had this check but not sibling fusion. Sibling fusion shouldn't be performed if the slice isn't "maximal" (i.e., if it isn't the whole of the source). Fixes: llvm#48703
Fix sibling fusion for slice maximality check. Producer-consumer fusion had this check but not sibling fusion. Sibling fusion shouldn't be performed if the slice isn't "maximal" (i.e., if it isn't the whole of the source). Fixes: llvm#48703
Fix sibling fusion for slice maximality check. Producer-consumer fusion
had this check but not sibling fusion. Sibling fusion shouldn't be
performed if the slice isn't "maximal" (i.e., if it isn't the whole of
the source).
Fixes: #48703