Skip to content

Commit 8421ad7

Browse files
authored
[MLIR][Affine] Fix sibling fusion - missing check (llvm#126626)
Add the slice maximality check that sibling fusion was missing. Producer-consumer fusion already had this check, but sibling fusion did not. Because the sibling nest is always removed after fusion, sibling fusion must not be performed unless the slice is "maximal" (i.e., it covers the whole of the source). Fixes: llvm#48703
1 parent a6f7cb5 commit 8421ad7

File tree

4 files changed

+51
-18
lines changed

4 files changed

+51
-18
lines changed

mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1165,24 +1165,48 @@ struct GreedyFusion {
11651165
}
11661166

11671167
assert(bestDstLoopDepth > 0 && "Unexpected loop fusion depth");
1168-
assert(!depthSliceUnions[bestDstLoopDepth - 1].isEmpty() &&
1168+
1169+
const ComputationSliceState &bestSlice =
1170+
depthSliceUnions[bestDstLoopDepth - 1];
1171+
assert(!bestSlice.isEmpty() &&
11691172
"Fusion depth has no computed slice union");
1173+
1174+
// Do not perform sibling fusion unless the slice is known to be maximal. We
1175+
// always remove the sibling node after fusion, so fusion is only valid when
1176+
// the whole of the sibling nest, not just a part of it, is fused into the destination.
1177+
auto isMaximal = bestSlice.isMaximal();
1178+
if (!isMaximal.value_or(false)) {
1179+
LLVM_DEBUG(llvm::dbgs()
1180+
<< "Slice isn't maximal; not performing sibling fusion.\n");
1181+
continue;
1182+
}
1183+
11701184
// Check if source loop is being inserted in the innermost
11711185
// destination loop. Based on this, the fused loop may be optimized
11721186
// further inside `fuseLoops`.
11731187
bool isInnermostInsertion = (bestDstLoopDepth == dstLoopDepthTest);
11741188
// Fuse computation slice of 'sibLoopNest' into 'dstLoopNest'.
1175-
affine::fuseLoops(sibAffineForOp, dstAffineForOp,
1176-
depthSliceUnions[bestDstLoopDepth - 1],
1189+
affine::fuseLoops(sibAffineForOp, dstAffineForOp, bestSlice,
11771190
isInnermostInsertion);
11781191

11791192
auto dstForInst = cast<AffineForOp>(dstNode->op);
11801193
// Update operation position of fused loop nest (if needed).
1181-
if (insertPointInst != dstForInst) {
1194+
if (insertPointInst != dstForInst)
11821195
dstForInst->moveBefore(insertPointInst);
1183-
}
1196+
1197+
LLVM_DEBUG(llvm::dbgs()
1198+
<< "Fused sibling nest " << sibId << " into destination nest "
1199+
<< dstNode->id << " at depth " << bestDstLoopDepth << ":\n"
1200+
<< dstAffineForOp << "\n");
1201+
11841202
// Update data dependence graph state post fusion.
11851203
updateStateAfterSiblingFusion(sibNode, dstNode);
1204+
1205+
// Remove old sibling loop nest.
1206+
// Get op before we invalidate the MDG node.
1207+
Operation *op = sibNode->op;
1208+
mdg->removeNode(sibNode->id);
1209+
op->erase();
11861210
}
11871211
}
11881212

@@ -1324,13 +1348,6 @@ struct GreedyFusion {
13241348
mdg->addToNode(dstNode->id, dstLoopCollector.loadOpInsts,
13251349
dstLoopCollector.storeOpInsts, dstLoopCollector.memrefLoads,
13261350
dstLoopCollector.memrefStores, dstLoopCollector.memrefFrees);
1327-
// Remove old sibling loop nest if it no longer has outgoing dependence
1328-
// edges, and it does not write to a memref which escapes the block.
1329-
if (mdg->getOutEdgeCount(sibNode->id) == 0) {
1330-
Operation *op = sibNode->op;
1331-
mdg->removeNode(sibNode->id);
1332-
op->erase();
1333-
}
13341351
}
13351352

13361353
// Clean up any allocs with no users.

mlir/test/Dialect/Affine/loop-fusion-2.mlir

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -389,6 +389,8 @@ func.func @should_fuse_init_loops_siblings_then_shared_producer(%arg0: memref<10
389389

390390
// -----
391391

392+
// Test sibling fusion of two matrix-vector products sharing the input matrix.
393+
392394
func.func @two_matrix_vector_products() {
393395
%in_matrix = memref.alloc() : memref<10x10xf32>
394396
%in_vec0 = memref.alloc() : memref<10xf32>

mlir/test/Dialect/Affine/loop-fusion-4.mlir

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion{mode=producer}))' -split-input-file | FileCheck %s --check-prefix=PRODUCER-CONSUMER
22
// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion{mode=producer fusion-maximal}))' -split-input-file | FileCheck %s --check-prefix=PRODUCER-CONSUMER-MAXIMAL
33
// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion{fusion-maximal mode=sibling}))' -split-input-file | FileCheck %s --check-prefix=SIBLING-MAXIMAL
4+
// All fusion: producer-consumer and sibling.
5+
// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion))' -split-input-file | FileCheck %s --check-prefix=ALL
46
// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(spirv.func(affine-loop-fusion{mode=producer}))' -split-input-file | FileCheck %s --check-prefix=SPIRV
57

68
// Part I of fusion tests in mlir/test/Transforms/loop-fusion.mlir.
@@ -109,6 +111,7 @@ func.func @check_src_dst_step(%m : memref<100xf32>,
109111
func.func @reduce_add_non_maximal_f32_f32(%arg0: memref<64x64xf32, 1>, %arg1 : memref<1x64xf32, 1>, %arg2 : memref<1x64xf32, 1>) {
110112
%cst_0 = arith.constant 0.000000e+00 : f32
111113
%cst_1 = arith.constant 1.000000e+00 : f32
114+
// This nest writes to %arg1 but can be eliminated post sibling fusion.
112115
affine.for %arg3 = 0 to 1 {
113116
affine.for %arg4 = 0 to 64 {
114117
%accum = affine.for %arg5 = 0 to 64 iter_args (%prevAccum = %cst_0) -> f32 {
@@ -138,11 +141,11 @@ func.func @reduce_add_non_maximal_f32_f32(%arg0: memref<64x64xf32, 1>, %arg1 : m
138141
// since the destination loop and source loop trip counts do not
139142
// match.
140143
// SIBLING-MAXIMAL: %[[cst_0:.*]] = arith.constant 0.000000e+00 : f32
141-
// SIBLING-MAXIMAL-NEXT: %[[cst_1:.*]] = arith.constant 1.000000e+00 : f32
142-
// SIBLING-MAXIMAL-NEXT: affine.for %[[idx_0:.*]]= 0 to 1 {
143-
// SIBLING-MAXIMAL-NEXT: affine.for %[[idx_1:.*]] = 0 to 64 {
144-
// SIBLING-MAXIMAL-NEXT: %[[result_1:.*]] = affine.for %[[idx_2:.*]] = 0 to 32 iter_args(%[[iter_0:.*]] = %[[cst_1]]) -> (f32) {
145-
// SIBLING-MAXIMAL-NEXT: %[[result_0:.*]] = affine.for %[[idx_3:.*]] = 0 to 64 iter_args(%[[iter_1:.*]] = %[[cst_0]]) -> (f32) {
144+
// SIBLING-MAXIMAL-NEXT: %[[cst_1:.*]] = arith.constant 1.000000e+00 : f32
145+
// SIBLING-MAXIMAL-NEXT: affine.for %{{.*}} = 0 to 1 {
146+
// SIBLING-MAXIMAL-NEXT: affine.for %{{.*}} = 0 to 64 {
147+
// SIBLING-MAXIMAL-NEXT: affine.for %{{.*}} = 0 to 32 iter_args(%{{.*}} = %[[cst_1]]) -> (f32) {
148+
// SIBLING-MAXIMAL-NEXT: affine.for %{{.*}} = 0 to 64 iter_args(%{{.*}} = %[[cst_0]]) -> (f32) {
146149

147150
// -----
148151

@@ -316,11 +319,16 @@ func.func @same_memref_load_store(%producer : memref<32xf32>, %consumer: memref<
316319
return
317320
}
318321

322+
// -----
323+
319324
// PRODUCER-CONSUMER-LABEL: func @same_memref_load_multiple_stores
325+
// ALL-LABEL: func @same_memref_load_multiple_stores
320326
func.func @same_memref_load_multiple_stores(%producer : memref<32xf32>, %producer_2 : memref<32xf32>, %consumer: memref<16xf32>){
321327
%cst = arith.constant 2.000000e+00 : f32
322-
// Source isn't removed.
328+
// Ensure that source isn't removed during both producer-consumer fusion and
329+
// sibling fusion.
323330
// PRODUCER-CONSUMER: affine.for %{{.*}} = 0 to 32
331+
// ALL: affine.for %{{.*}} = 0 to 32
324332
affine.for %arg3 = 0 to 32 {
325333
%0 = affine.load %producer[%arg3] : memref<32xf32>
326334
%2 = arith.mulf %0, %cst : f32
@@ -344,6 +352,9 @@ func.func @same_memref_load_multiple_stores(%producer : memref<32xf32>, %produce
344352
// PRODUCER-CONSUMER-NEXT: arith.addf
345353
// PRODUCER-CONSUMER-NEXT: affine.store
346354
// PRODUCER-CONSUMER-NEXT: }
355+
// ALL: affine.for %{{.*}} = 0 to 16
356+
// ALL: mulf
357+
// ALL: addf
347358
return
348359
}
349360

mlir/test/Dialect/Affine/loop-fusion.mlir

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1206,6 +1206,9 @@ func.func @should_fuse_with_private_memref() {
12061206
// CHECK: affine.for %{{.*}} = 0 to 17 {
12071207
// CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[0] : memref<1xf32>
12081208
// CHECK-NEXT: affine.load %{{.*}}[0] : memref<1xf32>
1209+
// CHECK-NEXT: }
1210+
// CHECK: affine.for %{{.*}} = 0 to 82 {
1211+
// CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[0] : memref<1xf32>
12091212
// CHECK-NEXT: affine.load %{{.*}}[0] : memref<1xf32>
12101213
// CHECK-NEXT: }
12111214
// CHECK-NEXT: return

0 commit comments

Comments
 (0)