Skip to content

[MLIR][Affine] Add missing check on fusion compute tolerance on a path #128454

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 44 additions & 24 deletions mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
#include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/Utils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/LoopFusionUtils.h"
#include "mlir/Dialect/Affine/LoopUtils.h"
#include "mlir/Dialect/Affine/Utils.h"
Expand Down Expand Up @@ -473,7 +472,8 @@ static Value createPrivateMemRef(AffineForOp forOp,
// is lower.
// TODO: Extend profitability analysis to support scenarios with multiple
// stores.
static bool isFusionProfitable(AffineForOp srcForOp, Operation *srcStoreOpInst,
static bool isFusionProfitable(AffineForOp srcForOp,
ArrayRef<Operation *> producerStores,
AffineForOp dstForOp,
ArrayRef<ComputationSliceState> depthSliceUnions,
unsigned maxLegalFusionDepth,
Expand Down Expand Up @@ -503,6 +503,35 @@ static bool isFusionProfitable(AffineForOp srcForOp, Operation *srcStoreOpInst,
if (!getLoopNestStats(dstForOp, &dstLoopNestStats))
return false;

// We limit profitability analysis to only scenarios with
// a single producer store for now. Note that some multi-store
// producer scenarios will still go through profitability analysis
// if only one of the stores is involved in the producer-consumer
// relationship of the candidate loops.
// TODO: Suppport multiple producer stores in profitability
// analysis.
if (producerStores.size() > 1) {
LLVM_DEBUG(llvm::dbgs() << "Limited profitability analysis. Not "
"supported for multiple producer store case.\n");
int64_t sliceCost;
int64_t fusedLoopNestComputeCost;
// We will still fuse if fusion obeys the specified compute
// tolerance at the max legal depth.
auto fraction = getAdditionalComputeFraction(
srcForOp, dstForOp, maxLegalFusionDepth, depthSliceUnions, sliceCost,
fusedLoopNestComputeCost);
if (!fraction || fraction > computeToleranceThreshold) {
LLVM_DEBUG(llvm::dbgs() << "Additional computation exceeds "
"compute tolerance. Not fusing.\n");
return false;
}
LLVM_DEBUG(llvm::dbgs()
<< "Considering fusion profitable at max legal depth.\n");
return true;
}

Operation *srcStoreOp = producerStores.front();

// Search for min cost value for 'dstLoopDepth'. At each value of
// 'dstLoopDepth' from 'maxLegalLoopDepth' to '1', compute computation slice
// bounds between 'srcOpInst' and each op in 'dstOpinsts' (taking the union
Expand All @@ -516,12 +545,9 @@ static bool isFusionProfitable(AffineForOp srcForOp, Operation *srcStoreOpInst,
// The best loop depth at which to materialize the slice.
std::optional<unsigned> bestDstLoopDepth;

// Compute op instance count for the src loop nest without iteration slicing.
uint64_t srcLoopNestCost = getComputeCost(srcForOp, srcLoopNestStats);

// Compute src loop nest write region size.
MemRefRegion srcWriteRegion(srcStoreOpInst->getLoc());
if (failed(srcWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0))) {
MemRefRegion srcWriteRegion(srcStoreOp->getLoc());
if (failed(srcWriteRegion.compute(srcStoreOp, /*loopDepth=*/0))) {
LLVM_DEBUG(llvm::dbgs()
<< "Unable to compute MemRefRegion for source operation\n");
return false;
Expand All @@ -533,7 +559,10 @@ static bool isFusionProfitable(AffineForOp srcForOp, Operation *srcStoreOpInst,
return false;
int64_t srcWriteRegionSizeBytes = *maybeSrcWriteRegionSizeBytes;

// Compute op instance count for the src loop nest.
// Compute op instance count for the src loop nest without iteration slicing.
uint64_t srcLoopNestCost = getComputeCost(srcForOp, srcLoopNestStats);

// Compute op instance count for the destination loop nest.
uint64_t dstLoopNestCost = getComputeCost(dstForOp, dstLoopNestStats);

// Evaluate all depth choices for materializing the slice in the destination
Expand Down Expand Up @@ -563,9 +592,8 @@ static bool isFusionProfitable(AffineForOp srcForOp, Operation *srcStoreOpInst,
// Determine what the slice write MemRefRegion would be, if the src loop
// nest slice 'slice' were to be inserted into the dst loop nest at loop
// depth 'i'.
MemRefRegion sliceWriteRegion(srcStoreOpInst->getLoc());
if (failed(sliceWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0,
&slice))) {
MemRefRegion sliceWriteRegion(srcStoreOp->getLoc());
if (failed(sliceWriteRegion.compute(srcStoreOp, /*loopDepth=*/0, &slice))) {
LLVM_DEBUG(llvm::dbgs()
<< "Failed to compute slice write region at loopDepth: " << i
<< "\n");
Expand Down Expand Up @@ -1025,21 +1053,13 @@ struct GreedyFusion {
cast<AffineWriteOpInterface>(op).getMemRef()))
producerStores.push_back(op);

// TODO: Suppport multiple producer stores in profitability
// analysis. We limit profitability analysis to only scenarios with
// a single producer store for now. Note that some multi-store
// producer scenarios will still go through profitability analysis
// if only one of the stores is involved the producer-consumer
// relationship of the candidate loops.
assert(!producerStores.empty() && "Expected producer store");
if (producerStores.size() > 1)
LLVM_DEBUG(llvm::dbgs() << "Skipping profitability analysis. Not "
"supported for this case\n");
else if (!isFusionProfitable(srcAffineForOp, producerStores[0],
dstAffineForOp, depthSliceUnions,
maxLegalFusionDepth, &bestDstLoopDepth,
computeToleranceThresholdToUse))
if (!isFusionProfitable(srcAffineForOp, producerStores,
dstAffineForOp, depthSliceUnions,
maxLegalFusionDepth, &bestDstLoopDepth,
computeToleranceThresholdToUse)) {
continue;
}
}

assert(bestDstLoopDepth > 0 && "Unexpected loop fusion depth");
Expand Down
78 changes: 78 additions & 0 deletions mlir/test/Dialect/Affine/loop-fusion-4.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion{mode=producer}))' -split-input-file | FileCheck %s --check-prefix=PRODUCER-CONSUMER
// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion{compute-tolerance=0.0}))' -split-input-file | FileCheck %s --check-prefix=ZERO-TOLERANCE
// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion{mode=producer maximal}))' -split-input-file | FileCheck %s --check-prefix=PRODUCER-CONSUMER-MAXIMAL
// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion{maximal mode=sibling}))' -split-input-file | FileCheck %s --check-prefix=SIBLING-MAXIMAL
// All fusion: producer-consumer and sibling.
Expand Down Expand Up @@ -544,3 +545,80 @@ func.func @sibling_reduction(%input : memref<10xf32>, %output : memref<10xf32>,
// SIBLING-MAXIMAL-NEXT: affine.store
return
}

// -----

// From https://github.com/llvm/llvm-project/issues/54541

#map = affine_map<(d0) -> (d0 mod 65536)>
// ZERO-TOLERANCE-LABEL: func @zero_tolerance
func.func @zero_tolerance(%arg0: memref<65536xcomplex<f64>>, %arg1: memref<30x131072xi64>,
%3 : memref<30xi64>,
%4 : memref<30xi64>,
%5 : memref<30xi64>,
%6 : memref<30xi64>
) {
%c65536 = arith.constant 65536 : index
%cst = arith.constant 0.000000e+00 : f64
%cst_0 = arith.constant 0x4320000000380004 : f64
%cst_1 = arith.constant 5.000000e-01 : f64
%0 = memref.alloc() {alignment = 128 : i64} : memref<30x131072xi64>
%1 = memref.alloc() {alignment = 128 : i64} : memref<131072xi1>
%2 = memref.alloc() {alignment = 128 : i64} : memref<131072xi128>
// This nest nest shouldn't be fused in when a zero tolerance is specified.
// ZERO-TOLERANCE: affine.for %{{.*}} = 0 to 131072
affine.for %arg2 = 0 to 131072 {
%7 = affine.apply #map(%arg2)
%8 = affine.load %arg0[%7] : memref<65536xcomplex<f64>>
%9 = arith.cmpi ult, %arg2, %c65536 : index
%10 = complex.im %8 : complex<f64>
%11 = complex.re %8 : complex<f64>
%12 = arith.select %9, %11, %10 : f64
%13 = arith.cmpf olt, %12, %cst : f64
%14 = arith.negf %12 : f64
%15 = arith.select %13, %14, %12 : f64
%16 = arith.mulf %15, %cst_0 : f64
%17 = arith.addf %16, %cst_1 : f64
%18 = arith.fptosi %17 : f64 to i128
affine.store %18, %2[%arg2] : memref<131072xi128>
affine.store %13, %1[%arg2] : memref<131072xi1>
}
// The next two nests are fused.
// ZERO-TOLERANCE: affine.for %{{.*}} = 0 to 30
// ZERO-TOLERANCE-NEXT: affine.for %{{.*}} = 0 to 131072
// ZERO-TOLERANCE: func.call @__external_reduce_barrett
// ZERO-TOLERANCE: affine.store
// ZERO-TOLERANCE: affine.load
// ZERO-TOLERANCE-NEXT: affine.store
affine.for %arg2 = 0 to 30 {
affine.for %arg3 = 0 to 131072 {
%7 = affine.load %6[%arg2] : memref<30xi64>
%8 = affine.load %3[%arg2] : memref<30xi64>
%9 = affine.load %5[%arg2] : memref<30xi64>
%10 = affine.load %4[%arg2] : memref<30xi64>
%11 = affine.load %2[%arg3] : memref<131072xi128>
%12 = affine.load %1[%arg3] : memref<131072xi1>
%13 = func.call @__external_reduce_barrett(%7, %8, %9, %10, %11) {outputModFac = 1 : i64} : (i64, i64, i64, i64, i128) -> i64
%14 = arith.subi %7, %13 : i64
%15 = arith.select %12, %14, %13 : i64
affine.store %15, %0[%arg2, %arg3] : memref<30x131072xi64>
}
}
func.call @__external_levelwise_forward_ntt(%0) : (memref<30x131072xi64>) -> ()
affine.for %arg2 = 0 to 30 {
affine.for %arg3 = 0 to 131072 {
%7 = affine.load %0[%arg2, %arg3] : memref<30x131072xi64>
affine.store %7, %arg1[%arg2, %arg3] : memref<30x131072xi64>
}
}
// Under maximal fusion, just one nest.
// PRODUCER-CONSUMER-MAXIMAL: affine.for %{{.*}} = 0 to 30
// PRODUCER-CONSUMER-MAXIMAL-NEXT: affine.for %{{.*}} = 0 to 131072
// PRODUCER-CONSUMER-MAXIMAL-NOT: affine.for %{{.*}}
memref.dealloc %2 : memref<131072xi128>
memref.dealloc %1 : memref<131072xi1>
memref.dealloc %0 : memref<30x131072xi64>
return
}
func.func private @__external_levelwise_forward_ntt(memref<30x131072xi64>)
func.func private @__external_reduce_barrett(i64, i64, i64, i64, i128) -> i64