@@ -41,7 +41,7 @@ namespace affine {
41
41
} // namespace affine
42
42
} // namespace mlir
43
43
44
- #define DEBUG_TYPE " affine-loop- fusion"
44
+ #define DEBUG_TYPE " affine-fusion"
45
45
46
46
using namespace mlir ;
47
47
using namespace mlir ::affine;
@@ -237,29 +237,67 @@ static void sinkSequentialLoops(MemRefDependenceGraph::Node *node) {
237
237
node->op = newRootForOp;
238
238
}
239
239
240
- // Creates and returns a private (single-user) memref for fused loop rooted
241
- // at 'forOp', with (potentially reduced) memref size based on the
242
- // MemRefRegion written to by 'srcStoreOpInst' at depth 'dstLoopDepth'.
243
- // TODO: consider refactoring the common code from generateDma and
244
- // this one.
245
- static Value createPrivateMemRef (AffineForOp forOp, Operation *srcStoreOpInst,
240
+ // / Get the operation that should act as a dominance filter while replacing
241
+ // / memref uses with a private memref for which `producerStores` and
242
+ // / `sliceInsertionBlock` are provided. This effectively determines in what
243
+ // / part of the IR we should be performing the replacement.
244
+ static Operation *
245
+ getDominanceFilterForPrivateMemRefRepl (Block *sliceInsertionBlock,
246
+ ArrayRef<Operation *> producerStores) {
247
+ assert (!producerStores.empty () && " expected producer store" );
248
+
249
+ // We first find the common block that contains the producer stores and
250
+ // the slice computation. The first ancestor among the ancestors of the
251
+ // producer stores in that common block is the dominance filter to use for
252
+ // replacement.
253
+ Block *commonBlock = nullptr ;
254
+ // Find the common block of all relevant operations.
255
+ for (Operation *store : producerStores) {
256
+ Operation *otherOp =
257
+ !commonBlock ? &*sliceInsertionBlock->begin () : &*commonBlock->begin ();
258
+ commonBlock = findInnermostCommonBlockInScope (store, otherOp);
259
+ }
260
+ assert (commonBlock &&
261
+ " common block of producer stores and slice should exist" );
262
+
263
+ // Find the first ancestor among the ancestors of `producerStores` in
264
+ // `commonBlock`.
265
+ Operation *firstAncestor = nullptr ;
266
+ for (Operation *store : producerStores) {
267
+ Operation *ancestor = commonBlock->findAncestorOpInBlock (*store);
268
+ assert (ancestor && " producer store should be contained in common block" );
269
+ firstAncestor = !firstAncestor || ancestor->isBeforeInBlock (firstAncestor)
270
+ ? ancestor
271
+ : firstAncestor;
272
+ }
273
+ return firstAncestor;
274
+ }
275
+
276
+ // Creates and returns a private (single-user) memref for fused loop rooted at
277
+ // 'forOp', with (potentially reduced) memref size based on the memref region
278
+ // written to by `storeOps` at depth 'dstLoopDepth'. 'sliceInsertionBlock'
279
+ // specifies the block in which the slice was/will be inserted.
280
+ static Value createPrivateMemRef (AffineForOp forOp,
281
+ ArrayRef<Operation *> storeOps,
246
282
unsigned dstLoopDepth,
247
283
std::optional<unsigned > fastMemorySpace,
284
+ Block *sliceInsertionBlock,
248
285
uint64_t localBufSizeThreshold) {
249
- Operation *forInst = forOp.getOperation ();
286
+ assert (!storeOps.empty () && " no source stores supplied" );
287
+ Operation *srcStoreOp = storeOps[0 ];
250
288
251
289
// Create builder to insert alloc op just before 'forOp'.
252
- OpBuilder b (forInst );
290
+ OpBuilder b (forOp );
253
291
// Builder to create constants at the top level.
254
- OpBuilder top (forInst ->getParentRegion ());
292
+ OpBuilder top (forOp ->getParentRegion ());
255
293
// Create new memref type based on slice bounds.
256
- auto oldMemRef = cast<AffineWriteOpInterface>(srcStoreOpInst ).getMemRef ();
294
+ auto oldMemRef = cast<AffineWriteOpInterface>(srcStoreOp ).getMemRef ();
257
295
auto oldMemRefType = cast<MemRefType>(oldMemRef.getType ());
258
296
unsigned rank = oldMemRefType.getRank ();
259
297
260
298
// Compute MemRefRegion for 'srcStoreOpInst' at depth 'dstLoopDepth'.
261
- MemRefRegion region (srcStoreOpInst ->getLoc ());
262
- bool validRegion = succeeded (region.compute (srcStoreOpInst , dstLoopDepth));
299
+ MemRefRegion region (srcStoreOp ->getLoc ());
300
+ bool validRegion = succeeded (region.compute (srcStoreOp , dstLoopDepth));
263
301
(void )validRegion;
264
302
assert (validRegion && " unexpected memref region failure" );
265
303
SmallVector<int64_t , 4 > newShape;
@@ -332,11 +370,12 @@ static Value createPrivateMemRef(AffineForOp forOp, Operation *srcStoreOpInst,
332
370
AffineMap::get (outerIVs.size () + rank, 0 , remapExprs, forOp.getContext ());
333
371
334
372
// Replace all users of 'oldMemRef' with 'newMemRef'.
335
- LogicalResult res =
336
- replaceAllMemRefUsesWith (oldMemRef, newMemRef, {}, indexRemap,
337
- /* extraOperands=*/ outerIVs,
338
- /* symbolOperands=*/ {},
339
- /* domOpFilter=*/ &*forOp.getBody ()->begin ());
373
+ Operation *domFilter =
374
+ getDominanceFilterForPrivateMemRefRepl (sliceInsertionBlock, storeOps);
375
+ LogicalResult res = replaceAllMemRefUsesWith (
376
+ oldMemRef, newMemRef, /* extraIndices=*/ {}, indexRemap,
377
+ /* extraOperands=*/ outerIVs,
378
+ /* symbolOperands=*/ {}, domFilter);
340
379
assert (succeeded (res) &&
341
380
" replaceAllMemrefUsesWith should always succeed here" );
342
381
(void )res;
@@ -944,6 +983,10 @@ struct GreedyFusion {
944
983
945
984
// Create private memrefs.
946
985
if (!privateMemrefs.empty ()) {
986
+ // Note the block into which fusion was performed. This can be used to
987
+ // place `alloc`s that create private memrefs.
988
+ Block *sliceInsertionBlock = bestSlice.insertPoint ->getBlock ();
989
+
947
990
// Gather stores for all the private-to-be memrefs.
948
991
DenseMap<Value, SmallVector<Operation *, 4 >> privateMemRefToStores;
949
992
dstAffineForOp.walk ([&](AffineWriteOpInterface storeOp) {
@@ -962,8 +1005,8 @@ struct GreedyFusion {
962
1005
SmallVector<Operation *, 4 > &storesForMemref =
963
1006
memrefToStoresPair.second ;
964
1007
Value newMemRef = createPrivateMemRef (
965
- dstAffineForOp, storesForMemref[ 0 ] , bestDstLoopDepth,
966
- fastMemorySpace, localBufSizeThreshold);
1008
+ dstAffineForOp, storesForMemref, bestDstLoopDepth,
1009
+ fastMemorySpace, sliceInsertionBlock, localBufSizeThreshold);
967
1010
// Create new node in dependence graph for 'newMemRef' alloc op.
968
1011
unsigned newMemRefNodeId = mdg->addNode (newMemRef.getDefiningOp ());
969
1012
// Add edge from 'newMemRef' node to dstNode.
0 commit comments