Skip to content

Commit a7349f8

Browse files
committed
[MLIR][Affine] Add missing check on fusion compute tolerance on a path
When profitability analysis can't be performed, we should still respect the specified compute tolerance. Refactor the computation of the additional compute fraction into a helper so that the tolerance check can also be applied on that path. Fixes: #54541
1 parent 8b1d384 commit a7349f8

File tree

2 files changed

+184
-46
lines changed

2 files changed

+184
-46
lines changed

mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp

Lines changed: 109 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
#include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
1616
#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
1717
#include "mlir/Dialect/Affine/Analysis/Utils.h"
18-
#include "mlir/Dialect/Affine/IR/AffineOps.h"
1918
#include "mlir/Dialect/Affine/LoopFusionUtils.h"
2019
#include "mlir/Dialect/Affine/LoopUtils.h"
2120
#include "mlir/Dialect/Affine/Utils.h"
@@ -274,6 +273,58 @@ getDominanceFilterForPrivateMemRefRepl(Block *sliceInsertionBlock,
274273
return firstAncestor;
275274
}
276275

276+
/// Returns the amount of additional (redundant) computation that will be done
277+
/// as a fraction of the total computation if `srcForOp` is fused into
278+
/// `dstForOp` at depth `depth`. The method returns the compute cost of the
279+
/// slice and the fused nest's compute cost in the trailing output arguments.
280+
static std::optional<double> getAdditionalComputeFraction(
281+
AffineForOp srcForOp, AffineForOp dstForOp, unsigned depth,
282+
ArrayRef<ComputationSliceState> depthSliceUnions, int64_t &sliceCost,
283+
int64_t &fusedLoopNestComputeCost) {
284+
LLVM_DEBUG(llvm::dbgs() << "Determining additional compute fraction...\n";);
285+
// Compute cost of sliced and unsliced src loop nest.
286+
// Walk src loop nest and collect stats.
287+
LoopNestStats srcLoopNestStats;
288+
if (!getLoopNestStats(srcForOp, &srcLoopNestStats)) {
289+
LLVM_DEBUG(llvm::dbgs() << "Failed to get source loop nest stats.\n");
290+
return std::nullopt;
291+
}
292+
293+
// Compute cost of dst loop nest.
294+
LoopNestStats dstLoopNestStats;
295+
if (!getLoopNestStats(dstForOp, &dstLoopNestStats)) {
296+
LLVM_DEBUG(llvm::dbgs() << "Failed to get destination loop nest stats.\n");
297+
return std::nullopt;
298+
}
299+
300+
// Compute op instance count for the src loop nest without iteration slicing.
301+
uint64_t srcLoopNestCost = getComputeCost(srcForOp, srcLoopNestStats);
302+
303+
// Compute op cost for the dst loop nest.
304+
uint64_t dstLoopNestCost = getComputeCost(dstForOp, dstLoopNestStats);
305+
306+
const ComputationSliceState &slice = depthSliceUnions[depth - 1];
307+
// Skip slice union if it wasn't computed for this depth.
308+
if (slice.isEmpty()) {
309+
LLVM_DEBUG(llvm::dbgs() << "Slice wasn't computed.\n");
310+
return std::nullopt;
311+
}
312+
313+
if (!getFusionComputeCost(srcForOp, srcLoopNestStats, dstForOp,
314+
dstLoopNestStats, slice,
315+
&fusedLoopNestComputeCost)) {
316+
LLVM_DEBUG(llvm::dbgs() << "Unable to compute fusion compute cost\n");
317+
return std::nullopt;
318+
}
319+
320+
double additionalComputeFraction =
321+
fusedLoopNestComputeCost /
322+
(static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
323+
1;
324+
325+
return additionalComputeFraction;
326+
}
327+
277328
// Creates and returns a private (single-user) memref for fused loop rooted at
278329
// 'forOp', with (potentially reduced) memref size based on the memref region
279330
// written to by `storeOps` at depth 'dstLoopDepth'. 'sliceInsertionBlock'
@@ -384,20 +435,19 @@ static Value createPrivateMemRef(AffineForOp forOp,
384435
}
385436

386437
// Checks the profitability of fusing a backwards slice of the loop nest
387-
// surrounding 'srcOpInst' into the loop nest surrounding 'dstLoadOpInsts'.
388-
// The argument 'srcStoreOpInst' is used to calculate the storage reduction on
389-
// the memref being produced and consumed, which is an input to the cost model.
390-
// For producer-consumer fusion, 'srcStoreOpInst' will be the same as
391-
// 'srcOpInst', as we are slicing w.r.t to that producer. For input-reuse
392-
// fusion, 'srcOpInst' will be the src loop nest LoadOp which reads from the
393-
// same memref as dst loop nest load ops, and 'srcStoreOpInst' will be the
394-
// unique store op in the src node, which will be used to check that the write
395-
// region is the same after input-reuse fusion. Computation slices are provided
396-
// in 'depthSliceUnions' for each legal fusion depth. The maximal depth at which
397-
// fusion is legal is provided in 'maxLegalFusionDepth'. Returns true if it is
398-
// profitable to fuse the candidate loop nests. Returns false otherwise.
399-
// `dstLoopDepth` is set to the most profitable depth at which to materialize
400-
// the source loop nest slice.
438+
// `srcForOp` into the loop nest surrounding 'dstLoadOpInsts'. The argument
439+
// 'srcStoreOpInst' is used to calculate the storage reduction on the memref
440+
// being produced and consumed, which is an input to the cost model. For
441+
// producer-consumer fusion, 'srcStoreOpInst' will be the same as 'srcOpInst',
442+
// as we are slicing w.r.t to that producer. For input-reuse fusion, 'srcOpInst'
443+
// will be the src loop nest LoadOp which reads from the same memref as dst loop
444+
// nest load ops, and 'srcStoreOpInst' will be the unique store op in the src
445+
// node, which will be used to check that the write region is the same after
446+
// input-reuse fusion. Computation slices are provided in 'depthSliceUnions' for
447+
// each legal fusion depth. The maximal depth at which fusion is legal is
448+
// provided in 'maxLegalFusionDepth'. Returns true if it is profitable to fuse
449+
// the candidate loop nests. Returns false otherwise. `dstLoopDepth` is set to
450+
// the most profitable depth at which to materialize the source loop nest slice.
401451
// The profitability model executes the following steps:
402452
// *) Computes the backward computation slice at 'srcOpInst'. This
403453
// computation slice of the loop nest surrounding 'srcOpInst' is
@@ -422,15 +472,16 @@ static Value createPrivateMemRef(AffineForOp forOp,
422472
// is lower.
423473
// TODO: Extend profitability analysis to support scenarios with multiple
424474
// stores.
425-
static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
475+
static bool isFusionProfitable(AffineForOp srcForOp, Operation *srcStoreOpInst,
426476
AffineForOp dstForOp,
427477
ArrayRef<ComputationSliceState> depthSliceUnions,
428478
unsigned maxLegalFusionDepth,
429479
unsigned *dstLoopDepth,
430480
double computeToleranceThreshold) {
431481
LLVM_DEBUG({
432-
llvm::dbgs() << "Checking whether fusion is profitable between src op:\n";
433-
llvm::dbgs() << ' ' << *srcOpInst << " and destination loop:\n";
482+
llvm::dbgs()
483+
<< "Checking whether fusion is profitable between source nest:\n";
484+
llvm::dbgs() << ' ' << srcForOp << " and destination nest:\n";
434485
llvm::dbgs() << dstForOp << "\n";
435486
});
436487

@@ -440,12 +491,10 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
440491
}
441492

442493
// Compute cost of sliced and unsliced src loop nest.
443-
SmallVector<AffineForOp, 4> srcLoopIVs;
444-
getAffineForIVs(*srcOpInst, &srcLoopIVs);
445494

446495
// Walk src loop nest and collect stats.
447496
LoopNestStats srcLoopNestStats;
448-
if (!getLoopNestStats(srcLoopIVs[0], &srcLoopNestStats))
497+
if (!getLoopNestStats(srcForOp, &srcLoopNestStats))
449498
return false;
450499

451500
// Compute cost of dst loop nest.
@@ -467,7 +516,7 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
467516
std::optional<unsigned> bestDstLoopDepth;
468517

469518
// Compute op instance count for the src loop nest without iteration slicing.
470-
uint64_t srcLoopNestCost = getComputeCost(srcLoopIVs[0], srcLoopNestStats);
519+
uint64_t srcLoopNestCost = getComputeCost(srcForOp, srcLoopNestStats);
471520

472521
// Compute src loop nest write region size.
473522
MemRefRegion srcWriteRegion(srcStoreOpInst->getLoc());
@@ -494,18 +543,21 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
494543
if (slice.isEmpty())
495544
continue;
496545

546+
// Compute cost of the slice separately, i.e, the compute cost of the slice
547+
// if all outer trip counts are one.
548+
int64_t sliceCost;
549+
497550
int64_t fusedLoopNestComputeCost;
498-
if (!getFusionComputeCost(srcLoopIVs[0], srcLoopNestStats, dstForOp,
499-
dstLoopNestStats, slice,
500-
&fusedLoopNestComputeCost)) {
501-
LLVM_DEBUG(llvm::dbgs() << "Unable to compute fusion compute cost\n");
551+
552+
auto mayAdditionalComputeFraction =
553+
getAdditionalComputeFraction(srcForOp, dstForOp, i, depthSliceUnions,
554+
sliceCost, fusedLoopNestComputeCost);
555+
if (!mayAdditionalComputeFraction) {
556+
LLVM_DEBUG(llvm::dbgs()
557+
<< "Can't determine additional compute fraction.\n");
502558
continue;
503559
}
504-
505-
double additionalComputeFraction =
506-
fusedLoopNestComputeCost /
507-
(static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
508-
1;
560+
double additionalComputeFraction = *mayAdditionalComputeFraction;
509561

510562
// Determine what the slice write MemRefRegion would be, if the src loop
511563
// nest slice 'slice' were to be inserted into the dst loop nest at loop
@@ -530,14 +582,6 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
530582
}
531583
int64_t sliceWriteRegionSizeBytes = *maybeSliceWriteRegionSizeBytes;
532584

533-
// If we are fusing for reuse, check that write regions remain the same.
534-
// TODO: Write region check should check sizes and offsets in
535-
// each dimension, so that we are sure they are covering the same memref
536-
// region. Also, move this out to a isMemRefRegionSuperSet helper function.
537-
if (srcOpInst != srcStoreOpInst &&
538-
sliceWriteRegionSizeBytes != srcWriteRegionSizeBytes)
539-
continue;
540-
541585
double storageReduction = static_cast<double>(srcWriteRegionSizeBytes) /
542586
static_cast<double>(sliceWriteRegionSizeBytes);
543587

@@ -595,7 +639,7 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
595639
<< minFusedLoopNestComputeCost << "\n");
596640

597641
auto dstMemSize = getMemoryFootprintBytes(dstForOp);
598-
auto srcMemSize = getMemoryFootprintBytes(srcLoopIVs[0]);
642+
auto srcMemSize = getMemoryFootprintBytes(srcForOp);
599643

600644
std::optional<double> storageReduction;
601645

@@ -840,6 +884,8 @@ struct GreedyFusion {
840884
LLVM_DEBUG(llvm::dbgs()
841885
<< "Trying to fuse producer loop nest " << srcId
842886
<< " with consumer loop nest " << dstId << "\n");
887+
LLVM_DEBUG(llvm::dbgs() << "Compute tolerance threshold: "
888+
<< computeToleranceThreshold << '\n');
843889
LLVM_DEBUG(llvm::dbgs()
844890
<< "Producer loop nest:\n"
845891
<< *srcNode->op << "\n and consumer loop nest:\n"
@@ -926,6 +972,9 @@ struct GreedyFusion {
926972
continue;
927973
}
928974

975+
LLVM_DEBUG(llvm::dbgs() << "Max legal depth for fusion: "
976+
<< maxLegalFusionDepth << '\n');
977+
929978
// Check if fusion would be profitable. We skip profitability analysis
930979
// for maximal fusion since we already know the maximal legal depth to
931980
// fuse.
@@ -945,14 +994,28 @@ struct GreedyFusion {
945994
// if only one of the stores is involved the producer-consumer
946995
// relationship of the candidate loops.
947996
assert(!producerStores.empty() && "Expected producer store");
948-
if (producerStores.size() > 1)
997+
if (producerStores.size() > 1) {
949998
LLVM_DEBUG(llvm::dbgs() << "Skipping profitability analysis. Not "
950999
"supported for this case\n");
951-
else if (!isFusionProfitable(producerStores[0], producerStores[0],
952-
dstAffineForOp, depthSliceUnions,
953-
maxLegalFusionDepth, &bestDstLoopDepth,
954-
computeToleranceThreshold))
1000+
// We will still fuse if fusion obeys the specified compute
1001+
// tolerance at the max legal depth.
1002+
int64_t sliceCost;
1003+
int64_t fusedLoopNestComputeCost;
1004+
auto fraction = getAdditionalComputeFraction(
1005+
srcAffineForOp, dstAffineForOp, maxLegalFusionDepth,
1006+
depthSliceUnions, sliceCost, fusedLoopNestComputeCost);
1007+
if (!fraction || fraction > computeToleranceThreshold) {
1008+
LLVM_DEBUG(llvm::dbgs() << "Additional computation exceeds "
1009+
"compute tolerance. Not fusing.\n");
1010+
continue;
1011+
}
1012+
}
1013+
if (!isFusionProfitable(srcAffineForOp, producerStores[0],
1014+
dstAffineForOp, depthSliceUnions,
1015+
maxLegalFusionDepth, &bestDstLoopDepth,
1016+
computeToleranceThreshold)) {
9551017
continue;
1018+
}
9561019
}
9571020

9581021
assert(bestDstLoopDepth > 0 && "Unexpected loop fusion depth");
@@ -1169,7 +1232,7 @@ struct GreedyFusion {
11691232
// load op is treated as the src "store" op for fusion profitability
11701233
// purposes. The footprint of the load in the slice relative to the
11711234
// unfused source's determines reuse.
1172-
if (!isFusionProfitable(sibLoadOpInst, sibLoadOpInst, dstAffineForOp,
1235+
if (!isFusionProfitable(sibAffineForOp, sibLoadOpInst, dstAffineForOp,
11731236
depthSliceUnions, maxLegalFusionDepth,
11741237
&bestDstLoopDepth, computeToleranceThreshold))
11751238
continue;

mlir/test/Dialect/Affine/loop-fusion-4.mlir

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion{mode=producer}))' -split-input-file | FileCheck %s --check-prefix=PRODUCER-CONSUMER
2+
// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion{compute-tolerance=0.0}))' -split-input-file | FileCheck %s --check-prefix=ZERO-TOLERANCE
23
// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion{mode=producer maximal}))' -split-input-file | FileCheck %s --check-prefix=PRODUCER-CONSUMER-MAXIMAL
34
// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion{maximal mode=sibling}))' -split-input-file | FileCheck %s --check-prefix=SIBLING-MAXIMAL
45
// All fusion: producer-consumer and sibling.
@@ -495,3 +496,77 @@ func.func @test_add_slice_bounds() {
495496
}
496497
return
497498
}
499+
500+
// -----
501+
502+
// From https://github.com/llvm/llvm-project/issues/54541
503+
504+
#map = affine_map<(d0) -> (d0 mod 65536)>
505+
// ZERO-TOLERANCE-LABEL: func @zero_tolerance
506+
func.func @zero_tolerance(%arg0: memref<65536xcomplex<f64>>, %arg1: memref<30x131072xi64>,
507+
%3 : memref<30xi64>,
508+
%4 : memref<30xi64>,
509+
%5 : memref<30xi64>,
510+
%6 : memref<30xi64>
511+
) {
512+
%c65536 = arith.constant 65536 : index
513+
%cst = arith.constant 0.000000e+00 : f64
514+
%cst_0 = arith.constant 0x4320000000380004 : f64
515+
%cst_1 = arith.constant 5.000000e-01 : f64
516+
%0 = memref.alloc() {alignment = 128 : i64} : memref<30x131072xi64>
517+
%1 = memref.alloc() {alignment = 128 : i64} : memref<131072xi1>
518+
%2 = memref.alloc() {alignment = 128 : i64} : memref<131072xi128>
519+
// The two nests shouldn't be fused when a zero tolerance is specified.
520+
// ZERO-TOLERANCE: affine.for %{{.*}} = 0 to 131072
521+
affine.for %arg2 = 0 to 131072 {
522+
%7 = affine.apply #map(%arg2)
523+
%8 = affine.load %arg0[%7] : memref<65536xcomplex<f64>>
524+
%9 = arith.cmpi ult, %arg2, %c65536 : index
525+
%10 = complex.im %8 : complex<f64>
526+
%11 = complex.re %8 : complex<f64>
527+
%12 = arith.select %9, %11, %10 : f64
528+
%13 = arith.cmpf olt, %12, %cst : f64
529+
%14 = arith.negf %12 : f64
530+
%15 = arith.select %13, %14, %12 : f64
531+
%16 = arith.mulf %15, %cst_0 : f64
532+
%17 = arith.addf %16, %cst_1 : f64
533+
%18 = arith.fptosi %17 : f64 to i128
534+
affine.store %18, %2[%arg2] : memref<131072xi128>
535+
affine.store %13, %1[%arg2] : memref<131072xi1>
536+
}
537+
// ZERO-TOLERANCE: affine.for %{{.*}} = 0 to 30
538+
// ZERO-TOLERANCE-NEXT: affine.for %{{.*}} = 0 to 131072
539+
affine.for %arg2 = 0 to 30 {
540+
affine.for %arg3 = 0 to 131072 {
541+
%7 = affine.load %6[%arg2] : memref<30xi64>
542+
%8 = affine.load %3[%arg2] : memref<30xi64>
543+
%9 = affine.load %5[%arg2] : memref<30xi64>
544+
%10 = affine.load %4[%arg2] : memref<30xi64>
545+
%11 = affine.load %2[%arg3] : memref<131072xi128>
546+
%12 = affine.load %1[%arg3] : memref<131072xi1>
547+
%13 = func.call @__external_reduce_barrett(%7, %8, %9, %10, %11) {outputModFac = 1 : i64} : (i64, i64, i64, i64, i128) -> i64
548+
%14 = arith.subi %7, %13 : i64
549+
%15 = arith.select %12, %14, %13 : i64
550+
affine.store %15, %0[%arg2, %arg3] : memref<30x131072xi64>
551+
}
552+
}
553+
func.call @__external_levelwise_forward_ntt(%0) : (memref<30x131072xi64>) -> ()
554+
// ZERO-TOLERANCE: affine.for %{{.*}} = 0 to 30
555+
// ZERO-TOLERANCE-NEXT: affine.for %{{.*}} = 0 to 131072
556+
affine.for %arg2 = 0 to 30 {
557+
affine.for %arg3 = 0 to 131072 {
558+
%7 = affine.load %0[%arg2, %arg3] : memref<30x131072xi64>
559+
affine.store %7, %arg1[%arg2, %arg3] : memref<30x131072xi64>
560+
}
561+
}
562+
// Under maximal fusion, just one nest.
563+
// PRODUCER-CONSUMER-MAXIMAL: affine.for %{{.*}} = 0 to 30
564+
// PRODUCER-CONSUMER-MAXIMAL-NEXT: affine.for %{{.*}} = 0 to 131072
565+
// PRODUCER-CONSUMER-MAXIMAL-NOT: affine.for %{{.*}}
566+
memref.dealloc %2 : memref<131072xi128>
567+
memref.dealloc %1 : memref<131072xi1>
568+
memref.dealloc %0 : memref<30x131072xi64>
569+
return
570+
}
571+
func.func private @__external_levelwise_forward_ntt(memref<30x131072xi64>)
572+
func.func private @__external_reduce_barrett(i64, i64, i64, i64, i128) -> i64

0 commit comments

Comments
 (0)