15
15
#include " mlir/Dialect/Affine/Analysis/AffineStructures.h"
16
16
#include " mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
17
17
#include " mlir/Dialect/Affine/Analysis/Utils.h"
18
- #include " mlir/Dialect/Affine/IR/AffineOps.h"
19
18
#include " mlir/Dialect/Affine/LoopFusionUtils.h"
20
19
#include " mlir/Dialect/Affine/LoopUtils.h"
21
20
#include " mlir/Dialect/Affine/Utils.h"
@@ -274,6 +273,58 @@ getDominanceFilterForPrivateMemRefRepl(Block *sliceInsertionBlock,
274
273
return firstAncestor;
275
274
}
276
275
276
+ // / Returns the amount of additional (redundant) computation that will be done
277
+ // / as a fraction of the total computation if `srcForOp` is fused into
278
+ // / `dstForOp` at depth `depth`. The method returns the compute cost of the
279
+ // / slice and the fused nest's compute cost in the trailing output arguments.
280
+ static std::optional<double > getAdditionalComputeFraction (
281
+ AffineForOp srcForOp, AffineForOp dstForOp, unsigned depth,
282
+ ArrayRef<ComputationSliceState> depthSliceUnions, int64_t &sliceCost,
283
+ int64_t &fusedLoopNestComputeCost) {
284
+ LLVM_DEBUG (llvm::dbgs () << " Determining additional compute fraction...\n " ;);
285
+ // Compute cost of sliced and unsliced src loop nest.
286
+ // Walk src loop nest and collect stats.
287
+ LoopNestStats srcLoopNestStats;
288
+ if (!getLoopNestStats (srcForOp, &srcLoopNestStats)) {
289
+ LLVM_DEBUG (llvm::dbgs () << " Failed to get source loop nest stats.\n " );
290
+ return std::nullopt;
291
+ }
292
+
293
+ // Compute cost of dst loop nest.
294
+ LoopNestStats dstLoopNestStats;
295
+ if (!getLoopNestStats (dstForOp, &dstLoopNestStats)) {
296
+ LLVM_DEBUG (llvm::dbgs () << " Failed to get destination loop nest stats.\n " );
297
+ return std::nullopt;
298
+ }
299
+
300
+ // Compute op instance count for the src loop nest without iteration slicing.
301
+ uint64_t srcLoopNestCost = getComputeCost (srcForOp, srcLoopNestStats);
302
+
303
+ // Compute op cost for the dst loop nest.
304
+ uint64_t dstLoopNestCost = getComputeCost (dstForOp, dstLoopNestStats);
305
+
306
+ const ComputationSliceState &slice = depthSliceUnions[depth - 1 ];
307
+ // Skip slice union if it wasn't computed for this depth.
308
+ if (slice.isEmpty ()) {
309
+ LLVM_DEBUG (llvm::dbgs () << " Slice wasn't computed.\n " );
310
+ return std::nullopt;
311
+ }
312
+
313
+ if (!getFusionComputeCost (srcForOp, srcLoopNestStats, dstForOp,
314
+ dstLoopNestStats, slice,
315
+ &fusedLoopNestComputeCost)) {
316
+ LLVM_DEBUG (llvm::dbgs () << " Unable to compute fusion compute cost\n " );
317
+ return std::nullopt;
318
+ }
319
+
320
+ double additionalComputeFraction =
321
+ fusedLoopNestComputeCost /
322
+ (static_cast <double >(srcLoopNestCost) + dstLoopNestCost) -
323
+ 1 ;
324
+
325
+ return additionalComputeFraction;
326
+ }
327
+
277
328
// Creates and returns a private (single-user) memref for fused loop rooted at
278
329
// 'forOp', with (potentially reduced) memref size based on the memref region
279
330
// written to by `storeOps` at depth 'dstLoopDepth'. 'sliceInsertionBlock'
@@ -384,20 +435,19 @@ static Value createPrivateMemRef(AffineForOp forOp,
384
435
}
385
436
386
437
// Checks the profitability of fusing a backwards slice of the loop nest
387
- // surrounding 'srcOpInst' into the loop nest surrounding 'dstLoadOpInsts'.
388
- // The argument 'srcStoreOpInst' is used to calculate the storage reduction on
389
- // the memref being produced and consumed, which is an input to the cost model.
390
- // For producer-consumer fusion, 'srcStoreOpInst' will be the same as
391
- // 'srcOpInst', as we are slicing w.r.t to that producer. For input-reuse
392
- // fusion, 'srcOpInst' will be the src loop nest LoadOp which reads from the
393
- // same memref as dst loop nest load ops, and 'srcStoreOpInst' will be the
394
- // unique store op in the src node, which will be used to check that the write
395
- // region is the same after input-reuse fusion. Computation slices are provided
396
- // in 'depthSliceUnions' for each legal fusion depth. The maximal depth at which
397
- // fusion is legal is provided in 'maxLegalFusionDepth'. Returns true if it is
398
- // profitable to fuse the candidate loop nests. Returns false otherwise.
399
- // `dstLoopDepth` is set to the most profitable depth at which to materialize
400
- // the source loop nest slice.
438
+ // `srcForOp` into the loop nest surrounding 'dstLoadOpInsts'. The argument
439
+ // 'srcStoreOpInst' is used to calculate the storage reduction on the memref
440
+ // being produced and consumed, which is an input to the cost model. For
441
+ // producer-consumer fusion, 'srcStoreOpInst' will be the same as 'srcOpInst',
442
+ // as we are slicing w.r.t to that producer. For input-reuse fusion, 'srcOpInst'
443
+ // will be the src loop nest LoadOp which reads from the same memref as dst loop
444
+ // nest load ops, and 'srcStoreOpInst' will be the unique store op in the src
445
+ // node, which will be used to check that the write region is the same after
446
+ // input-reuse fusion. Computation slices are provided in 'depthSliceUnions' for
447
+ // each legal fusion depth. The maximal depth at which fusion is legal is
448
+ // provided in 'maxLegalFusionDepth'. Returns true if it is profitable to fuse
449
+ // the candidate loop nests. Returns false otherwise. `dstLoopDepth` is set to
450
+ // the most profitable depth at which to materialize the source loop nest slice.
401
451
// The profitability model executes the following steps:
402
452
// *) Computes the backward computation slice at 'srcOpInst'. This
403
453
// computation slice of the loop nest surrounding 'srcOpInst' is
@@ -422,15 +472,16 @@ static Value createPrivateMemRef(AffineForOp forOp,
422
472
// is lower.
423
473
// TODO: Extend profitability analysis to support scenarios with multiple
424
474
// stores.
425
- static bool isFusionProfitable (Operation *srcOpInst , Operation *srcStoreOpInst,
475
+ static bool isFusionProfitable (AffineForOp srcForOp , Operation *srcStoreOpInst,
426
476
AffineForOp dstForOp,
427
477
ArrayRef<ComputationSliceState> depthSliceUnions,
428
478
unsigned maxLegalFusionDepth,
429
479
unsigned *dstLoopDepth,
430
480
double computeToleranceThreshold) {
431
481
LLVM_DEBUG ({
432
- llvm::dbgs () << " Checking whether fusion is profitable between src op:\n " ;
433
- llvm::dbgs () << ' ' << *srcOpInst << " and destination loop:\n " ;
482
+ llvm::dbgs ()
483
+ << " Checking whether fusion is profitable between source nest:\n " ;
484
+ llvm::dbgs () << ' ' << srcForOp << " and destination nest:\n " ;
434
485
llvm::dbgs () << dstForOp << " \n " ;
435
486
});
436
487
@@ -440,12 +491,10 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
440
491
}
441
492
442
493
// Compute cost of sliced and unsliced src loop nest.
443
- SmallVector<AffineForOp, 4 > srcLoopIVs;
444
- getAffineForIVs (*srcOpInst, &srcLoopIVs);
445
494
446
495
// Walk src loop nest and collect stats.
447
496
LoopNestStats srcLoopNestStats;
448
- if (!getLoopNestStats (srcLoopIVs[ 0 ] , &srcLoopNestStats))
497
+ if (!getLoopNestStats (srcForOp , &srcLoopNestStats))
449
498
return false ;
450
499
451
500
// Compute cost of dst loop nest.
@@ -467,7 +516,7 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
467
516
std::optional<unsigned > bestDstLoopDepth;
468
517
469
518
// Compute op instance count for the src loop nest without iteration slicing.
470
- uint64_t srcLoopNestCost = getComputeCost (srcLoopIVs[ 0 ] , srcLoopNestStats);
519
+ uint64_t srcLoopNestCost = getComputeCost (srcForOp , srcLoopNestStats);
471
520
472
521
// Compute src loop nest write region size.
473
522
MemRefRegion srcWriteRegion (srcStoreOpInst->getLoc ());
@@ -494,18 +543,21 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
494
543
if (slice.isEmpty ())
495
544
continue ;
496
545
546
+ // Compute cost of the slice separately, i.e, the compute cost of the slice
547
+ // if all outer trip counts are one.
548
+ int64_t sliceCost;
549
+
497
550
int64_t fusedLoopNestComputeCost;
498
- if (!getFusionComputeCost (srcLoopIVs[0 ], srcLoopNestStats, dstForOp,
499
- dstLoopNestStats, slice,
500
- &fusedLoopNestComputeCost)) {
501
- LLVM_DEBUG (llvm::dbgs () << " Unable to compute fusion compute cost\n " );
551
+
552
+ auto mayAdditionalComputeFraction =
553
+ getAdditionalComputeFraction (srcForOp, dstForOp, i, depthSliceUnions,
554
+ sliceCost, fusedLoopNestComputeCost);
555
+ if (!mayAdditionalComputeFraction) {
556
+ LLVM_DEBUG (llvm::dbgs ()
557
+ << " Can't determine additional compute fraction.\n " );
502
558
continue ;
503
559
}
504
-
505
- double additionalComputeFraction =
506
- fusedLoopNestComputeCost /
507
- (static_cast <double >(srcLoopNestCost) + dstLoopNestCost) -
508
- 1 ;
560
+ double additionalComputeFraction = *mayAdditionalComputeFraction;
509
561
510
562
// Determine what the slice write MemRefRegion would be, if the src loop
511
563
// nest slice 'slice' were to be inserted into the dst loop nest at loop
@@ -530,14 +582,6 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
530
582
}
531
583
int64_t sliceWriteRegionSizeBytes = *maybeSliceWriteRegionSizeBytes;
532
584
533
- // If we are fusing for reuse, check that write regions remain the same.
534
- // TODO: Write region check should check sizes and offsets in
535
- // each dimension, so that we are sure they are covering the same memref
536
- // region. Also, move this out to a isMemRefRegionSuperSet helper function.
537
- if (srcOpInst != srcStoreOpInst &&
538
- sliceWriteRegionSizeBytes != srcWriteRegionSizeBytes)
539
- continue ;
540
-
541
585
double storageReduction = static_cast <double >(srcWriteRegionSizeBytes) /
542
586
static_cast <double >(sliceWriteRegionSizeBytes);
543
587
@@ -595,7 +639,7 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
595
639
<< minFusedLoopNestComputeCost << " \n " );
596
640
597
641
auto dstMemSize = getMemoryFootprintBytes (dstForOp);
598
- auto srcMemSize = getMemoryFootprintBytes (srcLoopIVs[ 0 ] );
642
+ auto srcMemSize = getMemoryFootprintBytes (srcForOp );
599
643
600
644
std::optional<double > storageReduction;
601
645
@@ -840,6 +884,8 @@ struct GreedyFusion {
840
884
LLVM_DEBUG (llvm::dbgs ()
841
885
<< " Trying to fuse producer loop nest " << srcId
842
886
<< " with consumer loop nest " << dstId << " \n " );
887
+ LLVM_DEBUG (llvm::dbgs () << " Compute tolerance threshold: "
888
+ << computeToleranceThreshold << ' \n ' );
843
889
LLVM_DEBUG (llvm::dbgs ()
844
890
<< " Producer loop nest:\n "
845
891
<< *srcNode->op << " \n and consumer loop nest:\n "
@@ -926,6 +972,9 @@ struct GreedyFusion {
926
972
continue ;
927
973
}
928
974
975
+ LLVM_DEBUG (llvm::dbgs () << " Max legal depth for fusion: "
976
+ << maxLegalFusionDepth << ' \n ' );
977
+
929
978
// Check if fusion would be profitable. We skip profitability analysis
930
979
// for maximal fusion since we already know the maximal legal depth to
931
980
// fuse.
@@ -945,14 +994,28 @@ struct GreedyFusion {
945
994
// if only one of the stores is involved the producer-consumer
946
995
// relationship of the candidate loops.
947
996
assert (!producerStores.empty () && " Expected producer store" );
948
- if (producerStores.size () > 1 )
997
+ if (producerStores.size () > 1 ) {
949
998
LLVM_DEBUG (llvm::dbgs () << " Skipping profitability analysis. Not "
950
999
" supported for this case\n " );
951
- else if (!isFusionProfitable (producerStores[0 ], producerStores[0 ],
952
- dstAffineForOp, depthSliceUnions,
953
- maxLegalFusionDepth, &bestDstLoopDepth,
954
- computeToleranceThreshold))
1000
+ // We will still fuse if fusion obeys the specified compute
1001
+ // tolerance at the max legal depth.
1002
+ int64_t sliceCost;
1003
+ int64_t fusedLoopNestComputeCost;
1004
+ auto fraction = getAdditionalComputeFraction (
1005
+ srcAffineForOp, dstAffineForOp, maxLegalFusionDepth,
1006
+ depthSliceUnions, sliceCost, fusedLoopNestComputeCost);
1007
+ if (!fraction || fraction > computeToleranceThreshold) {
1008
+ LLVM_DEBUG (llvm::dbgs () << " Additional computation exceeds "
1009
+ " compute tolerance. Not fusing.\n " );
1010
+ continue ;
1011
+ }
1012
+ }
1013
+ if (!isFusionProfitable (srcAffineForOp, producerStores[0 ],
1014
+ dstAffineForOp, depthSliceUnions,
1015
+ maxLegalFusionDepth, &bestDstLoopDepth,
1016
+ computeToleranceThreshold)) {
955
1017
continue ;
1018
+ }
956
1019
}
957
1020
958
1021
assert (bestDstLoopDepth > 0 && " Unexpected loop fusion depth" );
@@ -1169,7 +1232,7 @@ struct GreedyFusion {
1169
1232
// load op is treated as the src "store" op for fusion profitability
1170
1233
// purposes. The footprint of the load in the slice relative to the
1171
1234
// unfused source's determines reuse.
1172
- if (!isFusionProfitable (sibLoadOpInst , sibLoadOpInst, dstAffineForOp,
1235
+ if (!isFusionProfitable (sibAffineForOp , sibLoadOpInst, dstAffineForOp,
1173
1236
depthSliceUnions, maxLegalFusionDepth,
1174
1237
&bestDstLoopDepth, computeToleranceThreshold))
1175
1238
continue ;
0 commit comments