[MLIR][Affine] Fix private memref creation bug in affine fusion #126028

Merged
8 changes: 8 additions & 0 deletions mlir/include/mlir/Dialect/Affine/Analysis/Utils.h
@@ -610,6 +610,14 @@ FailureOr<AffineValueMap>
simplifyConstrainedMinMaxOp(Operation *op,
                            FlatAffineValueConstraints constraints);

/// Find the innermost common `Block` of `a` and `b` in the affine scope
/// that `a` and `b` are part of. Return nullptr if they belong to different
/// affine scopes. Also return nullptr if they do not have a common `Block`
/// ancestor (e.g., when they are part of the `then` and `else` regions
/// of an op that itself starts an affine scope).
mlir::Block *findInnermostCommonBlockInScope(mlir::Operation *a,
                                             mlir::Operation *b);

} // namespace affine
} // namespace mlir

39 changes: 39 additions & 0 deletions mlir/lib/Dialect/Affine/Analysis/Utils.cpp
@@ -12,6 +12,7 @@
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/Analysis/Utils.h"

#include "mlir/Analysis/Presburger/PresburgerRelation.h"
#include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
@@ -2297,3 +2298,41 @@ FailureOr<AffineValueMap> mlir::affine::simplifyConstrainedMinMaxOp(
  affine::canonicalizeMapAndOperands(&newMap, &newOperands);
  return AffineValueMap(newMap, newOperands);
}

Block *mlir::affine::findInnermostCommonBlockInScope(Operation *a,
                                                     Operation *b) {
  Region *aScope = mlir::affine::getAffineScope(a);
  Region *bScope = mlir::affine::getAffineScope(b);
  if (aScope != bScope)
    return nullptr;

  // Get the block ancestry of `op` while stopping at the affine scope `aScope`
  // and store them in `ancestry`.
  auto getBlockAncestry = [&](Operation *op,
                              SmallVectorImpl<Block *> &ancestry) {
    Operation *curOp = op;
    do {
      ancestry.push_back(curOp->getBlock());
      if (curOp->getParentRegion() == aScope)
        break;
      curOp = curOp->getParentOp();
    } while (curOp);
    assert(curOp && "can't reach root op without passing through affine scope");
    std::reverse(ancestry.begin(), ancestry.end());
  };

  SmallVector<Block *, 4> aAncestors, bAncestors;
  getBlockAncestry(a, aAncestors);
  getBlockAncestry(b, bAncestors);
  assert(!aAncestors.empty() && !bAncestors.empty() &&
         "at least one Block ancestor expected");

  Block *innermostCommonBlock = nullptr;
  for (unsigned a = 0, b = 0, e = aAncestors.size(), f = bAncestors.size();
       a < e && b < f; ++a, ++b) {
    if (aAncestors[a] != bAncestors[b])
      break;
    innermostCommonBlock = aAncestors[a];
  }
  return innermostCommonBlock;
}
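For intuition, the function above is a lowest-common-ancestor walk over the two root-to-op block chains, cut off at the affine scope. Below is a minimal standalone sketch of the same idea, using a hypothetical `Node` type with a single parent link in place of MLIR's alternating `Operation`/`Block` hierarchy (not the MLIR API):

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for a block with a parent link; nullptr marks the
// root of the scope.
struct Node {
  Node *parent;
};

// Collect the ancestor chain of `n` (excluding `n` itself), outermost first.
static std::vector<Node *> ancestry(Node *n) {
  std::vector<Node *> chain;
  for (Node *cur = n->parent; cur; cur = cur->parent)
    chain.push_back(cur);
  std::reverse(chain.begin(), chain.end());
  return chain;
}

// Walk both chains in lockstep from the root; the innermost common ancestor
// is the last position where they still agree (nullptr if they never agree).
static Node *innermostCommon(Node *a, Node *b) {
  std::vector<Node *> aChain = ancestry(a), bChain = ancestry(b);
  Node *common = nullptr;
  for (std::size_t i = 0; i < aChain.size() && i < bChain.size(); ++i) {
    if (aChain[i] != bChain[i])
      break;
    common = aChain[i];
  }
  return common;
}

int main() {
  Node root{nullptr}, thenBlock{&root}, elseBlock{&root};
  Node x{&thenBlock}, y{&elseBlock}, z{&thenBlock};
  assert(innermostCommon(&x, &z) == &thenBlock); // share the `then` block
  assert(innermostCommon(&x, &y) == &root);      // diverge right below root
}
```

The lockstep comparison remembers the last block on which both chains agree, which plays the role of `innermostCommonBlock` above; if the chains share no entry at all, the result stays null.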
83 changes: 63 additions & 20 deletions mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
@@ -41,7 +41,7 @@ namespace affine {
} // namespace affine
} // namespace mlir

#define DEBUG_TYPE "affine-loop-fusion"
#define DEBUG_TYPE "affine-fusion"

using namespace mlir;
using namespace mlir::affine;
@@ -237,29 +237,67 @@ static void sinkSequentialLoops(MemRefDependenceGraph::Node *node) {
  node->op = newRootForOp;
}

// Creates and returns a private (single-user) memref for fused loop rooted
// at 'forOp', with (potentially reduced) memref size based on the
// MemRefRegion written to by 'srcStoreOpInst' at depth 'dstLoopDepth'.
// TODO: consider refactoring the common code from generateDma and
// this one.
static Value createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst,
/// Get the operation that should act as a dominance filter while replacing
/// memref uses with a private memref for which `producerStores` and
/// `sliceInsertionBlock` are provided. This effectively determines in what
/// part of the IR we should be performing the replacement.
static Operation *
getDominanceFilterForPrivateMemRefRepl(Block *sliceInsertionBlock,
                                       ArrayRef<Operation *> producerStores) {
  assert(!producerStores.empty() && "expected producer store");

  // We first find the common block that contains the producer stores and
  // the slice computation. The first ancestor among the ancestors of the
  // producer stores in that common block is the dominance filter to use for
  // replacement.
  Block *commonBlock = nullptr;
  // Find the common block of all relevant operations.
  for (Operation *store : producerStores) {
    Operation *otherOp =
        !commonBlock ? &*sliceInsertionBlock->begin() : &*commonBlock->begin();
    commonBlock = findInnermostCommonBlockInScope(store, otherOp);
  }
  assert(commonBlock &&
         "common block of producer stores and slice should exist");

  // Find the first ancestor among the ancestors of `producerStores` in
  // `commonBlock`.
  Operation *firstAncestor = nullptr;
  for (Operation *store : producerStores) {
    Operation *ancestor = commonBlock->findAncestorOpInBlock(*store);
    assert(ancestor && "producer store should be contained in common block");
    firstAncestor = !firstAncestor || ancestor->isBeforeInBlock(firstAncestor)
                        ? ancestor
                        : firstAncestor;
  }
  return firstAncestor;
}
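The second loop's choice of `firstAncestor` boils down to picking the textually earliest ancestor in the common block, which dominates everything that follows it there. A tiny sketch of just that selection step, modeling each ancestor by its position in the block (hypothetical helper, not the MLIR API):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Model each producer store's ancestor in the common block by its position;
// `isBeforeInBlock` then reduces to integer comparison, and the dominance
// filter is the minimum position.
static std::size_t pickDominanceFilter(const std::vector<std::size_t> &ancestorPos) {
  assert(!ancestorPos.empty() && "expected at least one producer store");
  std::size_t first = ancestorPos[0];
  for (std::size_t pos : ancestorPos)
    if (pos < first) // pos "isBeforeInBlock" first
      first = pos;
  return first;
}

int main() {
  // Ancestors of three producer stores sit at block positions 4, 1, and 7;
  // the op at position 1 becomes the dominance filter.
  assert(pickDominanceFilter({4, 1, 7}) == 1);
}
```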

// Creates and returns a private (single-user) memref for fused loop rooted at
// 'forOp', with (potentially reduced) memref size based on the memref region
// written to by `storeOps` at depth 'dstLoopDepth'. 'sliceInsertionBlock'
// specifies the block in which the slice was/will be inserted.
static Value createPrivateMemRef(AffineForOp forOp,
                                 ArrayRef<Operation *> storeOps,
                                 unsigned dstLoopDepth,
                                 std::optional<unsigned> fastMemorySpace,
                                 Block *sliceInsertionBlock,
                                 uint64_t localBufSizeThreshold) {
  Operation *forInst = forOp.getOperation();
  assert(!storeOps.empty() && "no source stores supplied");
  Operation *srcStoreOp = storeOps[0];

  // Create builder to insert alloc op just before 'forOp'.
  OpBuilder b(forInst);
  OpBuilder b(forOp);
  // Builder to create constants at the top level.
  OpBuilder top(forInst->getParentRegion());
  OpBuilder top(forOp->getParentRegion());
  // Create new memref type based on slice bounds.
  auto oldMemRef = cast<AffineWriteOpInterface>(srcStoreOpInst).getMemRef();
  auto oldMemRef = cast<AffineWriteOpInterface>(srcStoreOp).getMemRef();
  auto oldMemRefType = cast<MemRefType>(oldMemRef.getType());
  unsigned rank = oldMemRefType.getRank();

  // Compute MemRefRegion for 'srcStoreOpInst' at depth 'dstLoopDepth'.
  MemRefRegion region(srcStoreOpInst->getLoc());
  bool validRegion = succeeded(region.compute(srcStoreOpInst, dstLoopDepth));
  MemRefRegion region(srcStoreOp->getLoc());
  bool validRegion = succeeded(region.compute(srcStoreOp, dstLoopDepth));
  (void)validRegion;
  assert(validRegion && "unexpected memref region failure");
  SmallVector<int64_t, 4> newShape;
@@ -332,11 +370,12 @@ static Value createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst,
      AffineMap::get(outerIVs.size() + rank, 0, remapExprs, forOp.getContext());

  // Replace all users of 'oldMemRef' with 'newMemRef'.
  LogicalResult res =
      replaceAllMemRefUsesWith(oldMemRef, newMemRef, {}, indexRemap,
                               /*extraOperands=*/outerIVs,
                               /*symbolOperands=*/{},
                               /*domOpFilter=*/&*forOp.getBody()->begin());
  Operation *domFilter =
      getDominanceFilterForPrivateMemRefRepl(sliceInsertionBlock, storeOps);
  LogicalResult res = replaceAllMemRefUsesWith(
      oldMemRef, newMemRef, /*extraIndices=*/{}, indexRemap,
      /*extraOperands=*/outerIVs,
      /*symbolOperands=*/{}, domFilter);
  assert(succeeded(res) &&
         "replaceAllMemrefUsesWith should always succeed here");
  (void)res;
@@ -944,6 +983,10 @@ struct GreedyFusion {

    // Create private memrefs.
    if (!privateMemrefs.empty()) {
      // Note the block into which fusion was performed. This can be used to
      // place `alloc`s that create private memrefs.
      Block *sliceInsertionBlock = bestSlice.insertPoint->getBlock();

      // Gather stores for all the private-to-be memrefs.
      DenseMap<Value, SmallVector<Operation *, 4>> privateMemRefToStores;
      dstAffineForOp.walk([&](AffineWriteOpInterface storeOp) {
@@ -962,8 +1005,8 @@ struct GreedyFusion {
        SmallVector<Operation *, 4> &storesForMemref =
            memrefToStoresPair.second;
        Value newMemRef = createPrivateMemRef(
            dstAffineForOp, storesForMemref[0], bestDstLoopDepth,
            fastMemorySpace, localBufSizeThreshold);
            dstAffineForOp, storesForMemref, bestDstLoopDepth,
            fastMemorySpace, sliceInsertionBlock, localBufSizeThreshold);
        // Create new node in dependence graph for 'newMemRef' alloc op.
        unsigned newMemRefNodeId = mdg->addNode(newMemRef.getDefiningOp());
        // Add edge from 'newMemRef' node to dstNode.
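For intuition about what the `domFilter` argument does here, consider a standalone sketch (toy single-block IR with hypothetical types, not the real `replaceAllMemRefUsesWith`): only uses dominated by the filter op are redirected to the private memref, so earlier, unrelated uses of the original buffer are left untouched.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Toy model: each use of a buffer is an op at some position in one block.
// Within a single block, "op A dominates op B" reduces to A's position
// being no greater than B's.
struct Use {
  std::size_t pos;    // position of the using op in the block
  std::string memref; // buffer the op refers to
};

// Redirect uses of `oldRef` to `newRef`, but only those dominated by the
// filter position -- the toy analogue of passing a `domOpFilter`.
static void replaceUsesDominatedBy(std::vector<Use> &uses,
                                   std::size_t filterPos,
                                   const std::string &oldRef,
                                   const std::string &newRef) {
  for (Use &u : uses)
    if (u.memref == oldRef && u.pos >= filterPos)
      u.memref = newRef;
}

int main() {
  // Position 0 is an unrelated earlier read of %buf; positions 2 and 3 are
  // the producer store and the fused consumer.
  std::vector<Use> uses = {{0, "%buf"}, {2, "%buf"}, {3, "%buf"}};
  replaceUsesDominatedBy(uses, /*filterPos=*/2, "%buf", "%priv");
  assert(uses[0].memref == "%buf"); // not dominated: untouched
  assert(uses[1].memref == "%priv");
  assert(uses[2].memref == "%priv");
}
```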
60 changes: 60 additions & 0 deletions mlir/test/Dialect/Affine/loop-fusion-4.mlir
@@ -285,3 +285,63 @@ module {
    spirv.ReturnValue %3 : !spirv.array<8192 x f32>
  }
}

// -----

// PRODUCER-CONSUMER-LABEL: func @same_memref_load_store
func.func @same_memref_load_store(%producer : memref<32xf32>, %consumer: memref<16xf32>){
  %cst = arith.constant 2.000000e+00 : f32
  // Source isn't removed.
  // PRODUCER-CONSUMER: affine.for %{{.*}} = 0 to 32
  affine.for %arg3 = 0 to 32 {
    %0 = affine.load %producer[%arg3] : memref<32xf32>
    %2 = arith.mulf %0, %cst : f32
    affine.store %2, %producer[%arg3] : memref<32xf32>
  }
  affine.for %arg3 = 0 to 16 {
    %0 = affine.load %producer[%arg3] : memref<32xf32>
    %2 = arith.addf %0, %cst : f32
    affine.store %2, %consumer[%arg3] : memref<16xf32>
  }
  // Fused nest.
  // PRODUCER-CONSUMER: affine.for %{{.*}} = 0 to 16
  // PRODUCER-CONSUMER-NEXT: affine.load %{{.*}}[%{{.*}}] : memref<32xf32>
  // PRODUCER-CONSUMER-NEXT: arith.mulf
  // PRODUCER-CONSUMER-NEXT: affine.store %{{.*}}, %{{.*}}[0] : memref<1xf32>
  // PRODUCER-CONSUMER-NEXT: affine.load %{{.*}}[0] : memref<1xf32>
  // PRODUCER-CONSUMER-NEXT: arith.addf
  // PRODUCER-CONSUMER-NEXT: affine.store
  // PRODUCER-CONSUMER-NEXT: }
  return
}

// PRODUCER-CONSUMER-LABEL: func @same_memref_load_multiple_stores
func.func @same_memref_load_multiple_stores(%producer : memref<32xf32>, %producer_2 : memref<32xf32>, %consumer: memref<16xf32>){
  %cst = arith.constant 2.000000e+00 : f32
  // Source isn't removed.
  // PRODUCER-CONSUMER: affine.for %{{.*}} = 0 to 32
  affine.for %arg3 = 0 to 32 {
    %0 = affine.load %producer[%arg3] : memref<32xf32>
    %2 = arith.mulf %0, %cst : f32
    affine.store %2, %producer[%arg3] : memref<32xf32>
    affine.store %2, %producer_2[%arg3] : memref<32xf32>
  }
  affine.for %arg3 = 0 to 16 {
    %0 = affine.load %producer[%arg3] : memref<32xf32>
    %1 = affine.load %producer_2[%arg3] : memref<32xf32>
    %2 = arith.addf %0, %1 : f32
    affine.store %2, %consumer[%arg3] : memref<16xf32>
  }
  // Fused nest.
  // PRODUCER-CONSUMER: affine.for %{{.*}} = 0 to 16
  // PRODUCER-CONSUMER-NEXT: affine.load %{{.*}}[%{{.*}}] : memref<32xf32>
  // PRODUCER-CONSUMER-NEXT: arith.mulf
  // PRODUCER-CONSUMER-NEXT: affine.store %{{.*}}, %{{.*}}[0] : memref<1xf32>
  // PRODUCER-CONSUMER-NEXT: affine.store %{{.*}}, %{{.*}}[0] : memref<1xf32>
  // PRODUCER-CONSUMER-NEXT: affine.load %{{.*}}[0] : memref<1xf32>
  // PRODUCER-CONSUMER-NEXT: affine.load %{{.*}}[0] : memref<1xf32>
  // PRODUCER-CONSUMER-NEXT: arith.addf
  // PRODUCER-CONSUMER-NEXT: affine.store
  // PRODUCER-CONSUMER-NEXT: }
  return
}