Skip to content

Commit b2619a1

Browse files
committed
[MLIR][Affine] Fix sibling fusion - missing check
Add the missing slice maximality check to sibling fusion. Producer-consumer fusion already performed this check, but sibling fusion did not. After sibling fusion, the sibling source node cannot simply be deleted unless the fused slice is maximal. Fixes: #48703
1 parent 001ba42 commit b2619a1

File tree

2 files changed

+38
-18
lines changed

2 files changed

+38
-18
lines changed

mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1150,6 +1150,8 @@ struct GreedyFusion {
11501150
continue;
11511151

11521152
unsigned bestDstLoopDepth = maxLegalFusionDepth;
1153+
const ComputationSliceState &bestSlice =
1154+
depthSliceUnions[bestDstLoopDepth - 1];
11531155
if (!maximalFusion) {
11541156
// Check if fusion would be profitable. For sibling fusion, the sibling
11551157
// load op is treated as the src "store" op for fusion profitability
@@ -1162,24 +1164,40 @@ struct GreedyFusion {
11621164
}
11631165

11641166
assert(bestDstLoopDepth > 0 && "Unexpected loop fusion depth");
1165-
assert(!depthSliceUnions[bestDstLoopDepth - 1].isEmpty() &&
1167+
assert(!bestSlice.isEmpty() &&
11661168
"Fusion depth has no computed slice union");
11671169
// Check if source loop is being inserted in the innermost
11681170
// destination loop. Based on this, the fused loop may be optimized
11691171
// further inside `fuseLoops`.
11701172
bool isInnermostInsertion = (bestDstLoopDepth == dstLoopDepthTest);
11711173
// Fuse computation slice of 'sibLoopNest' into 'dstLoopNest'.
1172-
affine::fuseLoops(sibAffineForOp, dstAffineForOp,
1173-
depthSliceUnions[bestDstLoopDepth - 1],
1174+
affine::fuseLoops(sibAffineForOp, dstAffineForOp, bestSlice,
11741175
isInnermostInsertion);
11751176

11761177
auto dstForInst = cast<AffineForOp>(dstNode->op);
11771178
// Update operation position of fused loop nest (if needed).
1178-
if (insertPointInst != dstForInst) {
1179+
if (insertPointInst != dstForInst)
11791180
dstForInst->moveBefore(insertPointInst);
1180-
}
1181+
11811182
// Update data dependence graph state post fusion.
11821183
updateStateAfterSiblingFusion(sibNode, dstNode);
1184+
1185+
// Remove old sibling loop nest if it no longer has outgoing dependence
1186+
// edges, and the slice is maximal.
1187+
bool removeSrcNode = [&]() {
1188+
if (mdg->getOutEdgeCount(sibNode->id) > 0)
1189+
return false;
1190+
auto isMaximal = bestSlice.isMaximal();
1191+
return isMaximal && *isMaximal;
1192+
}();
1193+
LLVM_DEBUG(llvm::dbgs() << "Can remove source node after fusion: "
1194+
<< removeSrcNode << '\n');
1195+
if (removeSrcNode) {
1196+
// Get op before we invalidate the MDG node.
1197+
Operation *op = sibNode->op;
1198+
mdg->removeNode(sibNode->id);
1199+
op->erase();
1200+
}
11831201
}
11841202
}
11851203

@@ -1321,13 +1339,6 @@ struct GreedyFusion {
13211339
mdg->addToNode(dstNode->id, dstLoopCollector.loadOpInsts,
13221340
dstLoopCollector.storeOpInsts, dstLoopCollector.memrefLoads,
13231341
dstLoopCollector.memrefStores, dstLoopCollector.memrefFrees);
1324-
// Remove old sibling loop nest if it no longer has outgoing dependence
1325-
// edges, and it does not write to a memref which escapes the block.
1326-
if (mdg->getOutEdgeCount(sibNode->id) == 0) {
1327-
Operation *op = sibNode->op;
1328-
mdg->removeNode(sibNode->id);
1329-
op->erase();
1330-
}
13311342
}
13321343

13331344
// Clean up any allocs with no users.

mlir/test/Dialect/Affine/loop-fusion-4.mlir

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion{mode=producer}))' -split-input-file | FileCheck %s --check-prefix=PRODUCER-CONSUMER
22
// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion{fusion-maximal mode=sibling}))' -split-input-file | FileCheck %s --check-prefix=SIBLING-MAXIMAL
3+
// All fusion: producer-consumer and sibling.
4+
// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion))' -split-input-file | FileCheck %s --check-prefix=ALL
35
// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(spirv.func(affine-loop-fusion{mode=producer}))' -split-input-file | FileCheck %s --check-prefix=SPIRV
46

57
// Part I of fusion tests in mlir/test/Transforms/loop-fusion.mlir.
@@ -108,6 +110,7 @@ func.func @check_src_dst_step(%m : memref<100xf32>,
108110
func.func @reduce_add_non_maximal_f32_f32(%arg0: memref<64x64xf32, 1>, %arg1 : memref<1x64xf32, 1>, %arg2 : memref<1x64xf32, 1>) {
109111
%cst_0 = arith.constant 0.000000e+00 : f32
110112
%cst_1 = arith.constant 1.000000e+00 : f32
113+
// This nest writes to %arg1 but can be eliminated post sibling fusion.
111114
affine.for %arg3 = 0 to 1 {
112115
affine.for %arg4 = 0 to 64 {
113116
%accum = affine.for %arg5 = 0 to 64 iter_args (%prevAccum = %cst_0) -> f32 {
@@ -137,11 +140,11 @@ func.func @reduce_add_non_maximal_f32_f32(%arg0: memref<64x64xf32, 1>, %arg1 : m
137140
// since the destination loop and source loop trip counts do not
138141
// match.
139142
// SIBLING-MAXIMAL: %[[cst_0:.*]] = arith.constant 0.000000e+00 : f32
140-
// SIBLING-MAXIMAL-NEXT: %[[cst_1:.*]] = arith.constant 1.000000e+00 : f32
141-
// SIBLING-MAXIMAL-NEXT: affine.for %[[idx_0:.*]]= 0 to 1 {
142-
// SIBLING-MAXIMAL-NEXT: affine.for %[[idx_1:.*]] = 0 to 64 {
143-
// SIBLING-MAXIMAL-NEXT: %[[result_1:.*]] = affine.for %[[idx_2:.*]] = 0 to 32 iter_args(%[[iter_0:.*]] = %[[cst_1]]) -> (f32) {
144-
// SIBLING-MAXIMAL-NEXT: %[[result_0:.*]] = affine.for %[[idx_3:.*]] = 0 to 64 iter_args(%[[iter_1:.*]] = %[[cst_0]]) -> (f32) {
143+
// SIBLING-MAXIMAL-NEXT: %[[cst_1:.*]] = arith.constant 1.000000e+00 : f32
144+
// SIBLING-MAXIMAL-NEXT: affine.for %{{.*}} = 0 to 1 {
145+
// SIBLING-MAXIMAL-NEXT: affine.for %{{.*}} = 0 to 64 {
146+
// SIBLING-MAXIMAL-NEXT: affine.for %{{.*}} = 0 to 32 iter_args(%{{.*}} = %[[cst_1]]) -> (f32) {
147+
// SIBLING-MAXIMAL-NEXT: affine.for %{{.*}} = 0 to 64 iter_args(%{{.*}} = %[[cst_0]]) -> (f32) {
145148

146149
// -----
147150

@@ -316,10 +319,13 @@ func.func @same_memref_load_store(%producer : memref<32xf32>, %consumer: memref<
316319
}
317320

318321
// PRODUCER-CONSUMER-LABEL: func @same_memref_load_multiple_stores
322+
// ALL-LABEL: func @same_memref_load_multiple_stores
319323
func.func @same_memref_load_multiple_stores(%producer : memref<32xf32>, %producer_2 : memref<32xf32>, %consumer: memref<16xf32>){
320324
%cst = arith.constant 2.000000e+00 : f32
321-
// Source isn't removed.
325+
// Ensure that source isn't removed during either producer-consumer fusion or
326+
// sibling fusion.
322327
// PRODUCER-CONSUMER: affine.for %{{.*}} = 0 to 32
328+
// ALL: affine.for %{{.*}} = 0 to 32
323329
affine.for %arg3 = 0 to 32 {
324330
%0 = affine.load %producer[%arg3] : memref<32xf32>
325331
%2 = arith.mulf %0, %cst : f32
@@ -343,5 +349,8 @@ func.func @same_memref_load_multiple_stores(%producer : memref<32xf32>, %produce
343349
// PRODUCER-CONSUMER-NEXT: arith.addf
344350
// PRODUCER-CONSUMER-NEXT: affine.store
345351
// PRODUCER-CONSUMER-NEXT: }
352+
// ALL: affine.for %{{.*}} = 0 to 16
353+
// ALL: mulf
354+
// ALL: addf
346355
return
347356
}

0 commit comments

Comments
 (0)