Skip to content

[mlir][transform] Plumb a simplified form of AffineMin folding into t… #145170

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jun 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -410,9 +410,11 @@ void canonicalizeSetAndOperands(IntegerSet *set,
/// other AffineApplyOps supplying those operands. The operands of the resulting
/// AffineApplyOp do not change the length of AffineApplyOp chains.
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map,
ArrayRef<OpFoldResult> operands);
ArrayRef<OpFoldResult> operands,
bool composeAffineMin = false);
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineExpr e,
ArrayRef<OpFoldResult> operands);
ArrayRef<OpFoldResult> operands,
bool composeAffineMin = false);

/// Constructs an AffineApplyOp that applies `map` to `operands` after composing
/// the map with the maps of any other AffineApplyOp supplying the operands,
Expand All @@ -421,16 +423,19 @@ AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineExpr e,
/// map.
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc,
AffineMap map,
ArrayRef<OpFoldResult> operands);
ArrayRef<OpFoldResult> operands,
bool composeAffineMin = false);
/// Variant of `makeComposedFoldedAffineApply` that applies to an expression.
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc,
AffineExpr expr,
ArrayRef<OpFoldResult> operands);
ArrayRef<OpFoldResult> operands,
bool composeAffineMin = false);
/// Variant of `makeComposedFoldedAffineApply` suitable for multi-result maps.
/// Note that this may create as many affine.apply operations as the map has
/// results given that affine.apply must be single-result.
SmallVector<OpFoldResult> makeComposedFoldedMultiResultAffineApply(
OpBuilder &b, Location loc, AffineMap map, ArrayRef<OpFoldResult> operands);
OpBuilder &b, Location loc, AffineMap map, ArrayRef<OpFoldResult> operands,
bool composeAffineMin = false);

/// Returns an AffineMinOp obtained by composing `map` and `operands` with
/// AffineApplyOps supplying those operands.
Expand Down Expand Up @@ -459,7 +464,8 @@ OpFoldResult makeComposedFoldedAffineMax(OpBuilder &b, Location loc,
/// terminal symbol, i.e., a symbol defined at the top level or a block/function
/// argument.
void fullyComposeAffineMapAndOperands(AffineMap *map,
SmallVectorImpl<Value> *operands);
SmallVectorImpl<Value> *operands,
bool composeAffineMin = false);

} // namespace affine
} // namespace mlir
Expand Down
135 changes: 105 additions & 30 deletions mlir/lib/Dialect/Affine/IR/AffineOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,14 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/ShapedOpInterfaces.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "mlir/Transforms/InliningUtils.h"
Expand All @@ -26,7 +28,9 @@
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/MathExtras.h"
#include <limits>
#include <numeric>
#include <optional>

Expand Down Expand Up @@ -1042,6 +1046,62 @@ simplifyMapWithOperands(AffineMap &map, ArrayRef<Value> operands) {
map.getContext());
}

/// Assuming `dimOrSym` is a quantity in `map` that is defined by `minOp`.
/// Assuming that the quantity is of the form:
/// `affine_min(f(x, y), symbolic_cst)`.
/// This function checks that `0 < affine_min(f(x, y), symbolic_cst)` and
/// proceeds with replacing the patterns:
/// ```
/// dimOrSym.ceildiv(symbolic_cst)
/// (dimOrSym + symbolic_cst - 1).floordiv(symbolic_cst)
/// ```
/// by `1`.
///
/// Additionally, allows the caller to pass `affineMinKnownToBeNonNegative` to
/// inject static information that may not be statically discoverable.
///
/// Warning: ValueBoundsConstraintSet::computeConstantBound is needed to check
/// for the nonnegative case, if `affineMinKnownToBeNonNegative` is false.
static LogicalResult replaceAffineMinBoundingBoxExpression(
AffineMinOp minOp, AffineExpr dimOrSym, AffineMap *map,
bool affineMinKnownToBeNonNegative = false) {
auto affineMinMap = minOp.getAffineMap();
if (!affineMinKnownToBeNonNegative) {
ValueRange values = minOp->getOperands();
for (unsigned i = 0, e = affineMinMap.getNumResults(); i < e; ++i) {
AffineMap row = affineMinMap.getSubMap(ArrayRef<unsigned>{i});
FailureOr<int64_t> lowerBound =
ValueBoundsConstraintSet::computeConstantBound(
presburger::BoundType::LB, {row, values},
/*stopCondition=*/nullptr,
/*closedUB=*/true);
if (failed(lowerBound) || lowerBound.value() <= 0)
return failure();
}
}

AffineMap initialMap = *map;
for (unsigned i = 0, e = affineMinMap.getNumResults(); i != e; ++i) {
auto m = affineMinMap.getSubMap(ArrayRef<unsigned>{i});
AffineExpr expr = m.getResult(0);
if (!expr.isSymbolicOrConstant())
continue;

DenseMap<AffineExpr, AffineExpr> repl;
// dimOrSym.ceilDiv(expr) -> 1
repl[dimOrSym.ceilDiv(expr)] = getAffineConstantExpr(1, minOp.getContext());
// (dimOrSym + expr - 1).floorDiv(expr) -> 1
repl[(dimOrSym + expr - 1).floorDiv(expr)] =
getAffineConstantExpr(1, minOp.getContext());
auto newMap = map->replace(repl);
if (newMap == *map)
continue;
*map = newMap;
}

return success(*map != initialMap);
}

/// Replace all occurrences of AffineExpr at position `pos` in `map` by the
/// defining AffineApplyOp expression and operands.
/// When `dimOrSymbolPosition < dims.size()`, AffineDimExpr@[pos] is replaced.
Expand All @@ -1052,10 +1112,13 @@ simplifyMapWithOperands(AffineMap &map, ArrayRef<Value> operands) {
/// 2. `map` dim and symbols are gradually shifted to higher positions.
/// 3. Old `dim` and `sym` entries are replaced by nullptr
/// This avoids the need for any bookkeeping.
/// If `replaceAffineMin` is set to true, additionally triggers more expensive
/// replacements involving affine_min operations.
static LogicalResult replaceDimOrSym(AffineMap *map,
unsigned dimOrSymbolPosition,
SmallVectorImpl<Value> &dims,
SmallVectorImpl<Value> &syms) {
SmallVectorImpl<Value> &syms,
bool replaceAffineMin) {
MLIRContext *ctx = map->getContext();
bool isDimReplacement = (dimOrSymbolPosition < dims.size());
unsigned pos = isDimReplacement ? dimOrSymbolPosition
Expand All @@ -1064,6 +1127,13 @@ static LogicalResult replaceDimOrSym(AffineMap *map,
if (!v)
return failure();

auto minOp = v.getDefiningOp<AffineMinOp>();
if (minOp && replaceAffineMin) {
AffineExpr dimOrSym = isDimReplacement ? getAffineDimExpr(pos, ctx)
: getAffineSymbolExpr(pos, ctx);
return replaceAffineMinBoundingBoxExpression(minOp, dimOrSym, map);
}

auto affineApply = v.getDefiningOp<AffineApplyOp>();
if (!affineApply)
return failure();
Expand Down Expand Up @@ -1101,7 +1171,8 @@ static LogicalResult replaceDimOrSym(AffineMap *map,
/// iteratively. Perform canonicalization of map and operands as well as
/// AffineMap simplification. `map` and `operands` are mutated in place.
static void composeAffineMapAndOperands(AffineMap *map,
SmallVectorImpl<Value> *operands) {
SmallVectorImpl<Value> *operands,
bool composeAffineMin = false) {
if (map->getNumResults() == 0) {
canonicalizeMapAndOperands(map, operands);
*map = simplifyAffineMap(*map);
Expand All @@ -1122,7 +1193,8 @@ static void composeAffineMapAndOperands(AffineMap *map,
while (true) {
bool changed = false;
for (unsigned pos = 0; pos != dims.size() + syms.size(); ++pos)
if ((changed |= succeeded(replaceDimOrSym(map, pos, dims, syms))))
if ((changed |=
succeeded(replaceDimOrSym(map, pos, dims, syms, composeAffineMin))))
break;
if (!changed)
break;
Expand Down Expand Up @@ -1163,38 +1235,41 @@ static void composeAffineMapAndOperands(AffineMap *map,
}

void mlir::affine::fullyComposeAffineMapAndOperands(
AffineMap *map, SmallVectorImpl<Value> *operands) {
AffineMap *map, SmallVectorImpl<Value> *operands, bool composeAffineMin) {
while (llvm::any_of(*operands, [](Value v) {
return isa_and_nonnull<AffineApplyOp>(v.getDefiningOp());
})) {
composeAffineMapAndOperands(map, operands);
composeAffineMapAndOperands(map, operands, composeAffineMin);
}
}

AffineApplyOp
mlir::affine::makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map,
ArrayRef<OpFoldResult> operands) {
ArrayRef<OpFoldResult> operands,
bool composeAffineMin) {
SmallVector<Value> valueOperands;
map = foldAttributesIntoMap(b, map, operands, valueOperands);
composeAffineMapAndOperands(&map, &valueOperands);
composeAffineMapAndOperands(&map, &valueOperands, composeAffineMin);
assert(map);
return b.create<AffineApplyOp>(loc, map, valueOperands);
}

AffineApplyOp
mlir::affine::makeComposedAffineApply(OpBuilder &b, Location loc, AffineExpr e,
ArrayRef<OpFoldResult> operands) {
ArrayRef<OpFoldResult> operands,
bool composeAffineMin) {
return makeComposedAffineApply(
b, loc,
AffineMap::inferFromExprList(ArrayRef<AffineExpr>{e}, b.getContext())
.front(),
operands);
operands, composeAffineMin);
}

/// Composes the given affine map with the given list of operands, pulling in
/// the maps from any affine.apply operations that supply the operands.
static void composeMultiResultAffineMap(AffineMap &map,
SmallVectorImpl<Value> &operands) {
SmallVectorImpl<Value> &operands,
bool composeAffineMin = false) {
// Compose and canonicalize each expression in the map individually because
// composition only applies to single-result maps, collecting potentially
// duplicate operands in a single list with shifted dimensions and symbols.
Expand All @@ -1203,7 +1278,8 @@ static void composeMultiResultAffineMap(AffineMap &map,
for (unsigned i : llvm::seq<unsigned>(0, map.getNumResults())) {
SmallVector<Value> submapOperands(operands.begin(), operands.end());
AffineMap submap = map.getSubMap({i});
fullyComposeAffineMapAndOperands(&submap, &submapOperands);
fullyComposeAffineMapAndOperands(&submap, &submapOperands,
composeAffineMin);
canonicalizeMapAndOperands(&submap, &submapOperands);
unsigned numNewDims = submap.getNumDims();
submap = submap.shiftDims(dims.size()).shiftSymbols(symbols.size());
Expand All @@ -1221,10 +1297,9 @@ static void composeMultiResultAffineMap(AffineMap &map,
canonicalizeMapAndOperands(&map, &operands);
}

OpFoldResult
mlir::affine::makeComposedFoldedAffineApply(OpBuilder &b, Location loc,
AffineMap map,
ArrayRef<OpFoldResult> operands) {
OpFoldResult mlir::affine::makeComposedFoldedAffineApply(
OpBuilder &b, Location loc, AffineMap map, ArrayRef<OpFoldResult> operands,
bool composeAffineMin) {
assert(map.getNumResults() == 1 && "building affine.apply with !=1 result");

// Create new builder without a listener, so that no notification is
Expand All @@ -1236,7 +1311,7 @@ mlir::affine::makeComposedFoldedAffineApply(OpBuilder &b, Location loc,

// Create op.
AffineApplyOp applyOp =
makeComposedAffineApply(newBuilder, loc, map, operands);
makeComposedAffineApply(newBuilder, loc, map, operands, composeAffineMin);

// Get constant operands.
SmallVector<Attribute> constOperands(applyOp->getNumOperands());
Expand All @@ -1256,26 +1331,25 @@ mlir::affine::makeComposedFoldedAffineApply(OpBuilder &b, Location loc,
return llvm::getSingleElement(foldResults);
}

OpFoldResult
mlir::affine::makeComposedFoldedAffineApply(OpBuilder &b, Location loc,
AffineExpr expr,
ArrayRef<OpFoldResult> operands) {
OpFoldResult mlir::affine::makeComposedFoldedAffineApply(
OpBuilder &b, Location loc, AffineExpr expr,
ArrayRef<OpFoldResult> operands, bool composeAffineMin) {
return makeComposedFoldedAffineApply(
b, loc,
AffineMap::inferFromExprList(ArrayRef<AffineExpr>{expr}, b.getContext())
.front(),
operands);
operands, composeAffineMin);
}

SmallVector<OpFoldResult>
mlir::affine::makeComposedFoldedMultiResultAffineApply(
OpBuilder &b, Location loc, AffineMap map,
ArrayRef<OpFoldResult> operands) {
return llvm::map_to_vector(llvm::seq<unsigned>(0, map.getNumResults()),
[&](unsigned i) {
return makeComposedFoldedAffineApply(
b, loc, map.getSubMap({i}), operands);
});
OpBuilder &b, Location loc, AffineMap map, ArrayRef<OpFoldResult> operands,
bool composeAffineMin) {
return llvm::map_to_vector(
llvm::seq<unsigned>(0, map.getNumResults()), [&](unsigned i) {
return makeComposedFoldedAffineApply(b, loc, map.getSubMap({i}),
operands, composeAffineMin);
});
}

template <typename OpTy>
Expand Down Expand Up @@ -3024,7 +3098,8 @@ void AffineIfOp::build(OpBuilder &builder, OperationState &result,
/// `set` by composing the maps of such affine.apply ops with the integer
/// set constraints.
static void composeSetAndOperands(IntegerSet &set,
SmallVectorImpl<Value> &operands) {
SmallVectorImpl<Value> &operands,
bool composeAffineMin = false) {
// We will simply reuse the API of the map composition by viewing the LHSs of
// the equalities and inequalities of `set` as the affine exprs of an affine
// map. Convert to equivalent map, compose, and convert back to set.
Expand All @@ -3035,7 +3110,7 @@ static void composeSetAndOperands(IntegerSet &set,
[](Value v) { return v.getDefiningOp<AffineApplyOp>(); }))
return;

composeAffineMapAndOperands(&map, &operands);
composeAffineMapAndOperands(&map, &operands, composeAffineMin);
set = IntegerSet::get(map.getNumDims(), map.getNumSymbols(), map.getResults(),
set.getEqFlags());
}
Expand Down
5 changes: 3 additions & 2 deletions mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ SmallVector<OpFoldResult> linalg::computePaddedShape(
getDimsToSize(rewriter, indexingSizes, options);

// For each dimension in the operand's shape, iterate over indexingSizes and
// add
// add the various term contributions.
for (const auto &enResults : enumerate(indexingMap.getResults())) {
int64_t resultIndex = enResults.index();
AffineMap partialIndexingMap = indexingMap.getSubMap(
Expand Down Expand Up @@ -122,7 +122,8 @@ SmallVector<OpFoldResult> linalg::computePaddedShape(
AffineMap composedMap = projectedMap.compose(ceilMap);
OpFoldResult paddingDimOfr = affine::makeComposedFoldedAffineApply(
rewriter, loc, composedMap,
{indexingSizes[paddingDim], paddingSize});
{indexingSizes[paddingDim], paddingSize},
/*composeAffineMin=*/true);
terms.push_back(paddingDimOfr);
} else {
// Otherwise just set to paddingSize.
Expand Down
Loading
Loading