Skip to content

Commit 8127a75

Browse files
committed
[mlir][scf] Consider affine.apply when fusing scf::ParallelOp
When checking whether the load indices of the second loop coincide with the store indices of the first loop, the pass only considers whether the index values are the same. However, in some cases the index values are produced by affine.apply operations. In these cases, the indices are treated as different even when the affine maps are the same and the affine operands are the same. We already check that the iteration spaces are the same in isFusionLegal(). When checking affine.apply, we only need to check that the operands come from the same induction variables. If so, we know the results of the affine.apply operations are the same.
1 parent f96e85b commit 8127a75

File tree

2 files changed

+77
-2
lines changed

2 files changed

+77
-2
lines changed

mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
//
1111
//===----------------------------------------------------------------------===//
1212

13+
#include "mlir/Dialect/Affine/IR/AffineOps.h"
1314
#include "mlir/Dialect/SCF/Transforms/Passes.h"
1415

1516
#include "mlir/Analysis/AliasAnalysis.h"
@@ -27,6 +28,7 @@ namespace mlir {
2728
} // namespace mlir
2829

2930
using namespace mlir;
31+
using namespace mlir::affine;
3032
using namespace mlir::scf;
3133

3234
/// Verify there are no nested ParallelOps.
@@ -54,6 +56,16 @@ static bool equalIterationSpaces(ParallelOp firstPloop,
5456
matchOperands(firstPloop.getStep(), secondPloop.getStep());
5557
}
5658

59+
/// Returns the zero-based position of `operand` among the induction variables
/// of `loop`, or -1 if `operand` is not one of them.
///
/// NOTE(review): -1 is a shared "not found" sentinel. A caller that compares
/// two results of this function for equality will see two unrelated
/// non-induction operands as equal (-1 == -1); confirm that such callers
/// handle that case explicitly.
static int getInductionVarIndex(Value operand, ParallelOp loop) {
60+
auto indVars = loop.getInductionVars();
61+
// Linear search for `operand` in the loop's induction variables.
auto it = std::find(indVars.begin(), indVars.end(), operand);
62+
63+
if (it != indVars.end())
64+
return static_cast<int>(std::distance(indVars.begin(), it));
65+
66+
// `operand` is not an induction variable of `loop`.
return -1;
67+
}
68+
5769
/// Checks if the parallel loops have mixed access to the same buffers. Returns
5870
/// `true` if the first parallel loop writes to the same indices that the second
5971
/// loop reads.
@@ -102,8 +114,25 @@ static bool haveNoReadsAfterWriteExceptSameIndex(
102114
return WalkResult::interrupt();
103115
for (int i = 0, e = storeIndices.size(); i < e; ++i) {
104116
if (firstToSecondPloopIndices.lookupOrDefault(storeIndices[i]) !=
105-
loadIndices[i])
106-
return WalkResult::interrupt();
117+
loadIndices[i]) {
118+
auto storeIndexDef = storeIndices[i].getDefiningOp<AffineApplyOp>();
119+
auto loadIndexDef = loadIndices[i].getDefiningOp<AffineApplyOp>();
120+
if (storeIndexDef && loadIndexDef) {
121+
// When two indices come from affine.apply, we check the results of
122+
// these two affine.apply are the same or not.
123+
if (storeIndexDef.getAffineMap() != loadIndexDef.getAffineMap())
124+
return WalkResult::interrupt();
125+
if (storeIndexDef.getNumOperands() != loadIndexDef.getNumOperands())
126+
return WalkResult::interrupt();
127+
for (unsigned i = 0; i < storeIndexDef.getNumOperands(); ++i) {
128+
if (getInductionVarIndex(storeIndexDef.getOperand(i), firstPloop) !=
129+
getInductionVarIndex(loadIndexDef.getOperand(i), secondPloop))
130+
return WalkResult::interrupt();
131+
}
132+
} else {
133+
return WalkResult::interrupt();
134+
}
135+
}
107136
}
108137
return WalkResult::advance();
109138
});

mlir/test/Dialect/SCF/parallel-loop-fusion.mlir

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -480,3 +480,49 @@ func.func @do_not_fuse_multiple_stores_on_diff_indices(
480480
// CHECK: scf.reduce
481481
// CHECK: }
482482
// CHECK: memref.dealloc [[SUM]]
483+
484+
// -----
485+
486+
// Two scf.parallel loops over identical bounds and steps. The first stores to
// %sum and the second loads from %sum, each through its own affine.apply with
// the same map applied to that loop's induction variables. The CHECK lines
// for this function expect a single scf.parallel containing both bodies.
func.func @fuse_same_indices_by_affine_apply(
487+
%A: memref<2x2xf32>, %B: memref<2x2xf32>) {
488+
%c0 = arith.constant 0 : index
489+
%c1 = arith.constant 1 : index
490+
%c2 = arith.constant 2 : index
491+
%sum = memref.alloc() : memref<2x3xf32>
492+
// First loop: writes %sum[%i, %i + %j].
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
493+
%B_elem = memref.load %B[%i, %j] : memref<2x2xf32>
494+
%1 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%i, %j)
495+
memref.store %B_elem, %sum[%i, %1] : memref<2x3xf32>
496+
scf.reduce
497+
}
498+
// Second loop: reads %sum[%i, %i + %j] via a separate affine.apply op.
scf.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
499+
%1 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%i, %j)
500+
%sum_elem = memref.load %sum[%i, %1] : memref<2x3xf32>
501+
%A_elem = memref.load %A[%i, %j] : memref<2x2xf32>
502+
%product = arith.mulf %sum_elem, %A_elem : f32
503+
memref.store %product, %B[%i, %j] : memref<2x2xf32>
504+
scf.reduce
505+
}
506+
memref.dealloc %sum : memref<2x3xf32>
507+
return
508+
}
509+
// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1) -> (d0 + d1)>
510+
// CHECK-LABEL: fuse_same_indices_by_affine_apply
511+
// CHECK-SAME: (%[[ARG0:.+]]: memref<2x2xf32>, %[[ARG1:.+]]: memref<2x2xf32>) {
512+
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
513+
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
514+
// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
515+
// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<2x3xf32>
516+
// CHECK-NEXT: scf.parallel (%[[ARG2:.+]], %[[ARG3:.+]]) = (%[[C0]], %[[C0]]) to (%[[C2]], %[[C2]]) step (%[[C1]], %[[C1]]) {
517+
// CHECK-NEXT: %[[S0:.+]] = memref.load %[[ARG1]][%[[ARG2]], %[[ARG3]]] : memref<2x2xf32>
518+
// CHECK-NEXT: %[[S1:.+]] = affine.apply #[[$MAP]](%[[ARG2]], %[[ARG3]])
519+
// CHECK-NEXT: memref.store %[[S0]], %[[ALLOC]][%[[ARG2]], %[[S1]]] : memref<2x3xf32>
520+
// CHECK-NEXT: %[[S2:.+]] = affine.apply #[[$MAP]](%[[ARG2]], %[[ARG3]])
521+
// CHECK-NEXT: %[[S3:.+]] = memref.load %[[ALLOC]][%[[ARG2]], %[[S2]]] : memref<2x3xf32>
522+
// CHECK-NEXT: %[[S4:.+]] = memref.load %[[ARG0]][%[[ARG2]], %[[ARG3]]] : memref<2x2xf32>
523+
// CHECK-NEXT: %[[S5:.+]] = arith.mulf %[[S3]], %[[S4]] : f32
524+
// CHECK-NEXT: memref.store %[[S5]], %[[ARG1]][%[[ARG2]], %[[ARG3]]] : memref<2x2xf32>
525+
// CHECK-NEXT: scf.reduce
526+
// CHECK-NEXT: }
527+
// CHECK-NEXT: memref.dealloc %[[ALLOC]] : memref<2x3xf32>
528+
// CHECK-NEXT: return

0 commit comments

Comments
 (0)