Skip to content

Commit 203d5ee

Browse files
tunglddcaballe
authored and committed
[MLIR][affine-loop-fusion] Handle defining ops between the source and dest loops
This patch handles defining ops between the source and dest loop nests, and prevents loop nests with `iter_args` from being fused. If there is any SSA value in the dest loop nest whose defining op has a dependence from the source loop nest, we cannot fuse the loop nests. If there is an `affine.for` with `iter_args`, prevent it from being fused. Reviewed By: dcaballe, bondhugula Differential Revision: https://reviews.llvm.org/D97030
1 parent b368fc7 commit 203d5ee

File tree

2 files changed

+187
-5
lines changed

2 files changed

+187
-5
lines changed

mlir/lib/Transforms/LoopFusion.cpp

Lines changed: 37 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -179,8 +179,8 @@ struct MemRefDependenceGraph {
179179
// which contain accesses to the same memref 'value'. If the value is a
180180
// non-memref value, then the dependence is between a graph node which
181181
// defines an SSA value and another graph node which uses the SSA value
182-
// (e.g. a constant operation defining a value which is used inside a loop
183-
// nest).
182+
// (e.g. a constant or load operation defining a value which is used inside
183+
// a loop nest).
184184
Value value;
185185
};
186186

@@ -369,13 +369,35 @@ struct MemRefDependenceGraph {
369369
return outEdgeCount;
370370
}
371371

372+
/// Return all nodes which define SSA values used in node 'id'.
373+
void gatherDefiningNodes(unsigned id, DenseSet<unsigned> &definingNodes) {
374+
for (MemRefDependenceGraph::Edge edge : inEdges[id])
375+
// By definition of edge, if the edge value is a non-memref value,
376+
// then the dependence is between a graph node which defines an SSA value
377+
// and another graph node which uses the SSA value.
378+
if (!edge.value.getType().isa<MemRefType>())
379+
definingNodes.insert(edge.id);
380+
}
381+
372382
// Computes and returns an insertion point operation, before which the
373383
fused <srcId, dstId> loop nest can be inserted while preserving
374384
// dependences. Returns nullptr if no such insertion point is found.
375385
Operation *getFusedLoopNestInsertionPoint(unsigned srcId, unsigned dstId) {
376386
if (outEdges.count(srcId) == 0)
377387
return getNode(dstId)->op;
378388

389+
// Skip if there is any defining node of 'dstId' that depends on 'srcId'.
390+
DenseSet<unsigned> definingNodes;
391+
gatherDefiningNodes(dstId, definingNodes);
392+
if (llvm::any_of(definingNodes, [&](unsigned id) {
393+
return hasDependencePath(srcId, id);
394+
})) {
395+
LLVM_DEBUG(llvm::dbgs()
396+
<< "Can't fuse: a defining op with a user in the dst "
397+
"loop has dependence from the src loop\n");
398+
return nullptr;
399+
}
400+
379401
// Build set of insts in range (srcId, dstId) which depend on 'srcId'.
380402
SmallPtrSet<Operation *, 2> srcDepInsts;
381403
for (auto &outEdge : outEdges[srcId])
@@ -784,10 +806,11 @@ bool MemRefDependenceGraph::init(FuncOp f) {
784806
}
785807

786808
// Add dependence edges between nodes which produce SSA values and their
787-
// users.
809+
// users. Load ops can be considered as the ones producing SSA values.
788810
for (auto &idAndNode : nodes) {
789811
const Node &node = idAndNode.second;
790-
if (!node.loads.empty() || !node.stores.empty())
812+
// Stores don't define SSA values, skip them.
813+
if (!node.stores.empty())
791814
continue;
792815
auto *opInst = node.op;
793816
for (auto value : opInst->getResults()) {
@@ -956,7 +979,7 @@ static Value createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst,
956979

957980
/// Walking from node 'srcId' to node 'dstId' (exclusive of 'srcId' and
958981
/// 'dstId'), if there is any non-affine operation accessing 'memref', return
959-
/// false. Otherwise, return true.
982+
/// true. Otherwise, return false.
960983
static bool hasNonAffineUsersOnThePath(unsigned srcId, unsigned dstId,
961984
Value memref,
962985
MemRefDependenceGraph *mdg) {
@@ -1389,6 +1412,10 @@ struct GreedyFusion {
13891412
// Skip if 'dstNode' is not a loop nest.
13901413
if (!isa<AffineForOp>(dstNode->op))
13911414
continue;
1415+
// Skip if 'dstNode' is a loop nest returning values.
1416+
// TODO: support loop nests that return values.
1417+
if (dstNode->op->getNumResults() > 0)
1418+
continue;
13921419

13931420
LLVM_DEBUG(llvm::dbgs() << "Evaluating dst loop " << dstId << "\n");
13941421

@@ -1419,6 +1446,11 @@ struct GreedyFusion {
14191446
LLVM_DEBUG(llvm::dbgs() << "Evaluating src loop " << srcId
14201447
<< " for dst loop " << dstId << "\n");
14211448

1449+
// Skip if 'srcNode' is a loop nest returning values.
1450+
// TODO: support loop nests that return values.
1451+
if (isa<AffineForOp>(srcNode->op) && srcNode->op->getNumResults() > 0)
1452+
continue;
1453+
14221454
DenseSet<Value> producerConsumerMemrefs;
14231455
gatherProducerConsumerMemrefs(srcId, dstId, mdg,
14241456
producerConsumerMemrefs);

mlir/test/Transforms/loop-fusion.mlir

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2837,6 +2837,7 @@ func @should_fuse_multi_store_producer_with_scaping_memrefs_and_preserve_src(
28372837
}
28382838

28392839
// -----
2840+
28402841
func @should_not_fuse_due_to_dealloc(%arg0: memref<16xf32>){
28412842
%A = alloc() : memref<16xf32>
28422843
%C = alloc() : memref<16xf32>
@@ -2866,3 +2867,152 @@ func @should_not_fuse_due_to_dealloc(%arg0: memref<16xf32>){
28662867
// CHECK-NEXT: affine.load
28672868
// CHECK-NEXT: addf
28682869
// CHECK-NEXT: affine.store
2870+
2871+
// -----
2872+
2873+
// CHECK-LABEL: func @should_fuse_defining_node_has_no_dependence_from_source_node
2874+
func @should_fuse_defining_node_has_no_dependence_from_source_node(
2875+
%a : memref<10xf32>, %b : memref<f32>) -> () {
2876+
affine.for %i0 = 0 to 10 {
2877+
%0 = affine.load %b[] : memref<f32>
2878+
affine.store %0, %a[%i0] : memref<10xf32>
2879+
}
2880+
%0 = affine.load %b[] : memref<f32>
2881+
affine.for %i1 = 0 to 10 {
2882+
%1 = affine.load %a[%i1] : memref<10xf32>
2883+
%2 = divf %0, %1 : f32
2884+
}
2885+
2886+
// Loops '%i0' and '%i1' should be fused even though there is a defining
2887+
// node between the loops. It is because the node has no dependence from '%i0'.
2888+
// CHECK: affine.load %{{.*}}[] : memref<f32>
2889+
// CHECK-NEXT: affine.for %{{.*}} = 0 to 10 {
2890+
// CHECK-NEXT: affine.load %{{.*}}[] : memref<f32>
2891+
// CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
2892+
// CHECK-NEXT: affine.load %{{.*}}[%{{.*}}] : memref<10xf32>
2893+
// CHECK-NEXT: divf
2894+
// CHECK-NEXT: }
2895+
// CHECK-NOT: affine.for
2896+
return
2897+
}
2898+
2899+
// -----
2900+
2901+
// CHECK-LABEL: func @should_not_fuse_defining_node_has_dependence_from_source_loop
2902+
func @should_not_fuse_defining_node_has_dependence_from_source_loop(
2903+
%a : memref<10xf32>, %b : memref<f32>) -> () {
2904+
%cst = constant 0.000000e+00 : f32
2905+
affine.for %i0 = 0 to 10 {
2906+
affine.store %cst, %b[] : memref<f32>
2907+
affine.store %cst, %a[%i0] : memref<10xf32>
2908+
}
2909+
%0 = affine.load %b[] : memref<f32>
2910+
affine.for %i1 = 0 to 10 {
2911+
%1 = affine.load %a[%i1] : memref<10xf32>
2912+
%2 = divf %0, %1 : f32
2913+
}
2914+
2915+
// Loops '%i0' and '%i1' should not be fused because the defining node
2916+
// of '%0' used in '%i1' has dependence from loop '%i0'.
2917+
// CHECK: affine.for %{{.*}} = 0 to 10 {
2918+
// CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[] : memref<f32>
2919+
// CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
2920+
// CHECK-NEXT: }
2921+
// CHECK-NEXT: affine.load %{{.*}}[] : memref<f32>
2922+
// CHECK: affine.for %{{.*}} = 0 to 10 {
2923+
// CHECK-NEXT: affine.load %{{.*}}[%{{.*}}] : memref<10xf32>
2924+
// CHECK-NEXT: divf
2925+
// CHECK-NEXT: }
2926+
return
2927+
}
2928+
2929+
// -----
2930+
2931+
// CHECK-LABEL: func @should_not_fuse_defining_node_has_transitive_dependence_from_source_loop
2932+
func @should_not_fuse_defining_node_has_transitive_dependence_from_source_loop(
2933+
%a : memref<10xf32>, %b : memref<10xf32>, %c : memref<f32>) -> () {
2934+
%cst = constant 0.000000e+00 : f32
2935+
affine.for %i0 = 0 to 10 {
2936+
affine.store %cst, %a[%i0] : memref<10xf32>
2937+
affine.store %cst, %b[%i0] : memref<10xf32>
2938+
}
2939+
affine.for %i1 = 0 to 10 {
2940+
%1 = affine.load %b[%i1] : memref<10xf32>
2941+
affine.store %1, %c[] : memref<f32>
2942+
}
2943+
%0 = affine.load %c[] : memref<f32>
2944+
affine.for %i2 = 0 to 10 {
2945+
%1 = affine.load %a[%i2] : memref<10xf32>
2946+
%2 = divf %0, %1 : f32
2947+
}
2948+
2949+
// When loops '%i0' and '%i2' are evaluated first, they should not be
2950+
// fused. The defining node of '%0' in loop '%i2' has transitive dependence
2951+
// from loop '%i0'. After that, loops '%i0' and '%i1' are evaluated, and they
2952+
// will be fused as usual.
2953+
// CHECK: affine.for %{{.*}} = 0 to 10 {
2954+
// CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
2955+
// CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
2956+
// CHECK-NEXT: affine.load %{{.*}}[%{{.*}}] : memref<10xf32>
2957+
// CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[] : memref<f32>
2958+
// CHECK-NEXT: }
2959+
// CHECK-NEXT: affine.load %{{.*}}[] : memref<f32>
2960+
// CHECK: affine.for %{{.*}} = 0 to 10 {
2961+
// CHECK-NEXT: affine.load %{{.*}}[%{{.*}}] : memref<10xf32>
2962+
// CHECK-NEXT: divf
2963+
// CHECK-NEXT: }
2964+
// CHECK-NOT: affine.for
2965+
return
2966+
}
2967+
2968+
// -----
2969+
2970+
// CHECK-LABEL: func @should_not_fuse_dest_loop_nest_return_value
2971+
func @should_not_fuse_dest_loop_nest_return_value(
2972+
%a : memref<10xf32>) -> () {
2973+
%cst = constant 0.000000e+00 : f32
2974+
affine.for %i0 = 0 to 10 {
2975+
affine.store %cst, %a[%i0] : memref<10xf32>
2976+
}
2977+
%b = affine.for %i1 = 0 to 10 step 2 iter_args(%b_iter = %cst) -> f32 {
2978+
%load_a = affine.load %a[%i1] : memref<10xf32>
2979+
affine.yield %load_a: f32
2980+
}
2981+
2982+
// CHECK: affine.for %{{.*}} = 0 to 10 {
2983+
// CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
2984+
// CHECK-NEXT: }
2985+
// CHECK: affine.for %{{.*}} = 0 to 10 step 2 iter_args(%{{.*}} = %{{.*}}) -> (f32) {
2986+
// CHECK-NEXT: affine.load
2987+
// CHECK-NEXT: affine.yield
2988+
// CHECK-NEXT: }
2989+
2990+
return
2991+
}
2992+
2993+
// -----
2994+
2995+
// CHECK-LABEL: func @should_not_fuse_src_loop_nest_return_value
2996+
func @should_not_fuse_src_loop_nest_return_value(
2997+
%a : memref<10xf32>) -> () {
2998+
%cst = constant 1.000000e+00 : f32
2999+
%b = affine.for %i = 0 to 10 step 2 iter_args(%b_iter = %cst) -> f32 {
3000+
%c = addf %b_iter, %b_iter : f32
3001+
affine.store %c, %a[%i] : memref<10xf32>
3002+
affine.yield %c: f32
3003+
}
3004+
affine.for %i1 = 0 to 10 {
3005+
%1 = affine.load %a[%i1] : memref<10xf32>
3006+
}
3007+
3008+
// CHECK: %{{.*}} = affine.for %{{.*}} = 0 to 10 step 2 iter_args(%{{.*}} = %{{.*}}) -> (f32) {
3009+
// CHECK-NEXT: %{{.*}} = addf %{{.*}}, %{{.*}} : f32
3010+
// CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32>
3011+
// CHECK-NEXT: affine.yield %{{.*}} : f32
3012+
// CHECK-NEXT: }
3013+
// CHECK: affine.for %{{.*}} = 0 to 10 {
3014+
// CHECK-NEXT: affine.load %{{.*}}[%{{.*}}] : memref<10xf32>
3015+
// CHECK-NEXT: }
3016+
3017+
return
3018+
}

0 commit comments

Comments
 (0)