Skip to content

Commit b82df46

Browse files
committed
[MLIR][Affine] Fix sibling fusion - missing check
Fix sibling fusion for slice maximality check. Producer-consumer fusion had this check but not sibling fusion. Sibling fusion shouldn't be performed if the slice isn't "maximal" (i.e., if it isn't the whole of the source). Fixes: #48703
1 parent 001ba42 commit b82df46

File tree

4 files changed

+51
-18
lines changed

4 files changed

+51
-18
lines changed

mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1162,24 +1162,48 @@ struct GreedyFusion {
11621162
}
11631163

11641164
assert(bestDstLoopDepth > 0 && "Unexpected loop fusion depth");
1165-
assert(!depthSliceUnions[bestDstLoopDepth - 1].isEmpty() &&
1165+
1166+
const ComputationSliceState &bestSlice =
1167+
depthSliceUnions[bestDstLoopDepth - 1];
1168+
assert(!bestSlice.isEmpty() &&
11661169
"Fusion depth has no computed slice union");
1170+
1171+
// Do not perform sibling fusion if it isn't maximal. We always remove the
1172+
// sibling node and as such fusion shouldn't be performed if a part of the
1173+
// slice is used in the destination.
1174+
auto isMaximal = bestSlice.isMaximal();
1175+
if (!isMaximal || !*isMaximal) {
1176+
LLVM_DEBUG(llvm::dbgs()
1177+
<< "Slice isn't maximal; not performing sibling fusion.\n");
1178+
continue;
1179+
}
1180+
11671181
// Check if source loop is being inserted in the innermost
11681182
// destination loop. Based on this, the fused loop may be optimized
11691183
// further inside `fuseLoops`.
11701184
bool isInnermostInsertion = (bestDstLoopDepth == dstLoopDepthTest);
11711185
// Fuse computation slice of 'sibLoopNest' into 'dstLoopNest'.
1172-
affine::fuseLoops(sibAffineForOp, dstAffineForOp,
1173-
depthSliceUnions[bestDstLoopDepth - 1],
1186+
affine::fuseLoops(sibAffineForOp, dstAffineForOp, bestSlice,
11741187
isInnermostInsertion);
11751188

11761189
auto dstForInst = cast<AffineForOp>(dstNode->op);
11771190
// Update operation position of fused loop nest (if needed).
1178-
if (insertPointInst != dstForInst) {
1191+
if (insertPointInst != dstForInst)
11791192
dstForInst->moveBefore(insertPointInst);
1180-
}
1193+
1194+
LLVM_DEBUG(llvm::dbgs()
1195+
<< "Fused sibling nest " << sibId << " into destination nest "
1196+
<< dstNode->id << " at depth " << bestDstLoopDepth << ":\n"
1197+
<< dstAffineForOp << "\n");
1198+
11811199
// Update data dependence graph state post fusion.
11821200
updateStateAfterSiblingFusion(sibNode, dstNode);
1201+
1202+
// Remove old sibling loop nest.
1203+
// Get op before we invalidate the MDG node.
1204+
Operation *op = sibNode->op;
1205+
mdg->removeNode(sibNode->id);
1206+
op->erase();
11831207
}
11841208
}
11851209

@@ -1321,13 +1345,6 @@ struct GreedyFusion {
13211345
mdg->addToNode(dstNode->id, dstLoopCollector.loadOpInsts,
13221346
dstLoopCollector.storeOpInsts, dstLoopCollector.memrefLoads,
13231347
dstLoopCollector.memrefStores, dstLoopCollector.memrefFrees);
1324-
// Remove old sibling loop nest if it no longer has outgoing dependence
1325-
// edges, and it does not write to a memref which escapes the block.
1326-
if (mdg->getOutEdgeCount(sibNode->id) == 0) {
1327-
Operation *op = sibNode->op;
1328-
mdg->removeNode(sibNode->id);
1329-
op->erase();
1330-
}
13311348
}
13321349

13331350
// Clean up any allocs with no users.

mlir/test/Dialect/Affine/loop-fusion-2.mlir

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -389,6 +389,8 @@ func.func @should_fuse_init_loops_siblings_then_shared_producer(%arg0: memref<10
389389

390390
// -----
391391

392+
// Test sibling fusion of two matrix-vector products sharing the input matrix.
393+
392394
func.func @two_matrix_vector_products() {
393395
%in_matrix = memref.alloc() : memref<10x10xf32>
394396
%in_vec0 = memref.alloc() : memref<10xf32>

mlir/test/Dialect/Affine/loop-fusion-4.mlir

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion{mode=producer}))' -split-input-file | FileCheck %s --check-prefix=PRODUCER-CONSUMER
22
// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion{fusion-maximal mode=sibling}))' -split-input-file | FileCheck %s --check-prefix=SIBLING-MAXIMAL
3+
// All fusion: producer-consumer and sibling.
4+
// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion))' -split-input-file
35
// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(spirv.func(affine-loop-fusion{mode=producer}))' -split-input-file | FileCheck %s --check-prefix=SPIRV
46

57
// Part I of fusion tests in mlir/test/Transforms/loop-fusion.mlir.
@@ -108,6 +110,7 @@ func.func @check_src_dst_step(%m : memref<100xf32>,
108110
func.func @reduce_add_non_maximal_f32_f32(%arg0: memref<64x64xf32, 1>, %arg1 : memref<1x64xf32, 1>, %arg2 : memref<1x64xf32, 1>) {
109111
%cst_0 = arith.constant 0.000000e+00 : f32
110112
%cst_1 = arith.constant 1.000000e+00 : f32
113+
// This nest writes to %arg1 but can be eliminated post sibling fusion.
111114
affine.for %arg3 = 0 to 1 {
112115
affine.for %arg4 = 0 to 64 {
113116
%accum = affine.for %arg5 = 0 to 64 iter_args (%prevAccum = %cst_0) -> f32 {
@@ -137,11 +140,11 @@ func.func @reduce_add_non_maximal_f32_f32(%arg0: memref<64x64xf32, 1>, %arg1 : m
137140
// since the destination loop and source loop trip counts do not
138141
// match.
139142
// SIBLING-MAXIMAL: %[[cst_0:.*]] = arith.constant 0.000000e+00 : f32
140-
// SIBLING-MAXIMAL-NEXT: %[[cst_1:.*]] = arith.constant 1.000000e+00 : f32
141-
// SIBLING-MAXIMAL-NEXT: affine.for %[[idx_0:.*]]= 0 to 1 {
142-
// SIBLING-MAXIMAL-NEXT: affine.for %[[idx_1:.*]] = 0 to 64 {
143-
// SIBLING-MAXIMAL-NEXT: %[[result_1:.*]] = affine.for %[[idx_2:.*]] = 0 to 32 iter_args(%[[iter_0:.*]] = %[[cst_1]]) -> (f32) {
144-
// SIBLING-MAXIMAL-NEXT: %[[result_0:.*]] = affine.for %[[idx_3:.*]] = 0 to 64 iter_args(%[[iter_1:.*]] = %[[cst_0]]) -> (f32) {
143+
// SIBLING-MAXIMAL-NEXT: %[[cst_1:.*]] = arith.constant 1.000000e+00 : f32
144+
// SIBLING-MAXIMAL-NEXT: affine.for %{{.*}} = 0 to 1 {
145+
// SIBLING-MAXIMAL-NEXT: affine.for %{{.*}} = 0 to 64 {
146+
// SIBLING-MAXIMAL-NEXT: affine.for %{{.*}} = 0 to 32 iter_args(%{{.*}} = %[[cst_1]]) -> (f32) {
147+
// SIBLING-MAXIMAL-NEXT: affine.for %{{.*}} = 0 to 64 iter_args(%{{.*}} = %[[cst_0]]) -> (f32) {
145148

146149
// -----
147150

@@ -315,11 +318,16 @@ func.func @same_memref_load_store(%producer : memref<32xf32>, %consumer: memref<
315318
return
316319
}
317320

321+
// -----
322+
318323
// PRODUCER-CONSUMER-LABEL: func @same_memref_load_multiple_stores
324+
// ALL-LABEL: func @same_memref_load_multiple_stores
319325
func.func @same_memref_load_multiple_stores(%producer : memref<32xf32>, %producer_2 : memref<32xf32>, %consumer: memref<16xf32>){
320326
%cst = arith.constant 2.000000e+00 : f32
321-
// Source isn't removed.
327+
// Ensure that source isn't removed during both producer-consumer fusion and
328+
// sibling fusion.
322329
// PRODUCER-CONSUMER: affine.for %{{.*}} = 0 to 32
330+
// ALL: affine.for %{{.*}} = 0 to 32
323331
affine.for %arg3 = 0 to 32 {
324332
%0 = affine.load %producer[%arg3] : memref<32xf32>
325333
%2 = arith.mulf %0, %cst : f32
@@ -343,5 +351,8 @@ func.func @same_memref_load_multiple_stores(%producer : memref<32xf32>, %produce
343351
// PRODUCER-CONSUMER-NEXT: arith.addf
344352
// PRODUCER-CONSUMER-NEXT: affine.store
345353
// PRODUCER-CONSUMER-NEXT: }
354+
// ALL: affine.for %{{.*}} = 0 to 16
355+
// ALL: mulf
356+
// ALL: addf
346357
return
347358
}

mlir/test/Dialect/Affine/loop-fusion.mlir

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1206,6 +1206,9 @@ func.func @should_fuse_with_private_memref() {
12061206
// CHECK: affine.for %{{.*}} = 0 to 17 {
12071207
// CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[0] : memref<1xf32>
12081208
// CHECK-NEXT: affine.load %{{.*}}[0] : memref<1xf32>
1209+
// CHECK-NEXT: }
1210+
// CHECK: affine.for %{{.*}} = 0 to 82 {
1211+
// CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[0] : memref<1xf32>
12091212
// CHECK-NEXT: affine.load %{{.*}}[0] : memref<1xf32>
12101213
// CHECK-NEXT: }
12111214
// CHECK-NEXT: return

0 commit comments

Comments
 (0)