Skip to content

[FXML-5704] Compute affine expression bounds #482

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 91 additions & 0 deletions mlir/include/mlir/Analysis/AffineExprBounds.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
//===- AffineExprBounds.h - Compute bounds of affine expressions *- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This header file defines an analysis of affine expressions to compute their
// ranges (lower/upper bounds) in a given context.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_ANALYSIS_AFFINEEXPRBOUNDS_H
#define MLIR_ANALYSIS_AFFINEEXPRBOUNDS_H

#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Interfaces/InferIntRangeInterface.h"

#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/Support/LogicalResult.h"

using namespace mlir;

/// This visitor computes the bounds of affine expressions, using as context the
/// bounds of the dimensions of the expression.
///
/// Example:
/// Given bounds 0 <= d0 <= 99 and 0 <= d1 <= 199, we can compute the bounds
/// of the following expression:
/// lb(2 * d0 + 3 * d1) = 0
/// ub(2 * d0 + 3 * d1) = 795
///
/// * The bounds given in the context are inclusive, and the bounds returned
/// are also inclusive.
/// * If bounds are not available for a dimension, std::nullopt can be used
/// instead. The bounds of an expression that involves it will be std::nullopt.
/// * Limitations:
/// - Parametric expressions (using symbols) are not supported.
/// - Unsigned FloorDiv is currently not supported.
class AffineExprBoundsVisitor
: public AffineExprVisitor<AffineExprBoundsVisitor, LogicalResult> {
public:
/// Initialize the context (bounds) with APInt. All bounds must have the same
/// signedness and bit width.
AffineExprBoundsVisitor(ArrayRef<std::optional<APInt>> constLowerBounds,
ArrayRef<std::optional<APInt>> constUpperBounds,
bool boundsSigned, uint64_t bitWidth,
MLIRContext *context);

/// Initialize the context (bounds) with 64-bit signed integers. This allows
/// to directly map index-type values such as Linalg op bounds, which are
/// represented as int64_t.
AffineExprBoundsVisitor(ArrayRef<std::optional<int64_t>> constLowerBounds,
ArrayRef<std::optional<int64_t>> constUpperBounds,
MLIRContext *context);

/// Get the upper bound of \p expr using the context bounds.
std::optional<APInt> getUpperBound(AffineExpr expr);
std::optional<int64_t> getIndexUpperBound(AffineExpr expr);

/// Get the lower bound of \p expr using the context bounds.
std::optional<APInt> getLowerBound(AffineExpr expr);
std::optional<int64_t> getIndexLowerBound(AffineExpr expr);

// These methods are directly called by the AffineExprVisitor base class.
LogicalResult visitMulExpr(AffineBinaryOpExpr expr);
LogicalResult visitAddExpr(AffineBinaryOpExpr expr);
LogicalResult visitDimExpr(AffineDimExpr expr);
LogicalResult visitSymbolExpr(AffineSymbolExpr expr);
LogicalResult visitConstantExpr(AffineConstantExpr expr);
LogicalResult visitCeilDivExpr(AffineBinaryOpExpr expr);
LogicalResult visitFloorDivExpr(AffineBinaryOpExpr expr);
LogicalResult visitModExpr(AffineBinaryOpExpr expr);

private:
bool boundsSigned;
uint64_t bitWidth;
void inferBinOpRange(
AffineBinaryOpExpr expr,
const std::function<ConstantIntRanges(ArrayRef<ConstantIntRanges>)>
&opInference);

/// Bounds that have been computed for subexpressions are memoized and reused.
llvm::DenseMap<AffineExpr, APInt> lb;
llvm::DenseMap<AffineExpr, APInt> ub;
};

#endif // MLIR_ANALYSIS_AFFINEEXPRBOUNDS_H
186 changes: 186 additions & 0 deletions mlir/lib/Analysis/AffineExprBounds.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
//===- AffineExprBounds.h - Compute bounds of affine expressions *- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements an analysis of affine expressions to compute their
// ranges (lower/upper bounds) in a given context.
//
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/AffineExprBounds.h"

#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
#include "llvm/ADT/APInt.h"

#include <cstdint>

using namespace mlir;

AffineExprBoundsVisitor::AffineExprBoundsVisitor(
ArrayRef<std::optional<APInt>> constLowerBounds,
ArrayRef<std::optional<APInt>> constUpperBounds, bool boundsSigned,
uint64_t bitWidth, MLIRContext *context)
: boundsSigned(boundsSigned), bitWidth(bitWidth) {
assert(constLowerBounds.size() == constUpperBounds.size());
for (unsigned i = 0; i < constLowerBounds.size(); i++) {
if (constLowerBounds[i].has_value()) {
lb[getAffineDimExpr(i, context)] = constLowerBounds[i].value();
}
if (constUpperBounds[i].has_value()) {
ub[getAffineDimExpr(i, context)] = constUpperBounds[i].value();
}
}
}

AffineExprBoundsVisitor::AffineExprBoundsVisitor(
ArrayRef<std::optional<int64_t>> constLowerBounds,
ArrayRef<std::optional<int64_t>> constUpperBounds, MLIRContext *context)
: boundsSigned(true), bitWidth(64) {
assert(constLowerBounds.size() == constUpperBounds.size());
// Convert int64_ts to APInts.
for (unsigned i = 0; i < constLowerBounds.size(); i++) {
if (constLowerBounds[i].has_value()) {
lb[getAffineDimExpr(i, context)] =
APInt(64, constLowerBounds[i].value(), /*isSigned=*/true);
}
if (constUpperBounds[i].has_value()) {
ub[getAffineDimExpr(i, context)] =
APInt(64, constUpperBounds[i].value(), /*isSigned=*/true);
}
}
}

std::optional<APInt> AffineExprBoundsVisitor::getUpperBound(AffineExpr expr) {
// Use memoized bound if available.
auto i = ub.find(expr);
if (i != ub.end()) {
return i->second;
}
// Compute the bound otherwise.
if (failed(walkPostOrder(expr))) {
return std::nullopt;
}
return ub[expr];
}

std::optional<APInt> AffineExprBoundsVisitor::getLowerBound(AffineExpr expr) {
// Use memoized bound if available.
auto i = lb.find(expr);
if (i != lb.end()) {
return i->second;
}
// Compute the bound otherwise.
if (failed(walkPostOrder(expr))) {
return std::nullopt;
}
return lb[expr];
}

std::optional<int64_t>
AffineExprBoundsVisitor::getIndexUpperBound(AffineExpr expr) {
std::optional<APInt> apIntResult = getUpperBound(expr);
if (!apIntResult)
return std::nullopt;

return apIntResult->getSExtValue();
}

std::optional<int64_t>
AffineExprBoundsVisitor::getIndexLowerBound(AffineExpr expr) {
std::optional<APInt> apIntResult = getLowerBound(expr);
if (!apIntResult)
return std::nullopt;

return apIntResult->getSExtValue();
}

ConstantIntRanges getRange(APInt lb, APInt ub, bool boundsSigned) {
return ConstantIntRanges::range(lb, ub, boundsSigned);
}

/// Wrapper around the intrange::infer* functions that infers the range of
/// binary operations on two ranges.
void AffineExprBoundsVisitor::inferBinOpRange(
AffineBinaryOpExpr expr,
const std::function<ConstantIntRanges(ArrayRef<ConstantIntRanges>)>
&opInference) {
ConstantIntRanges lhsRange =
getRange(lb[expr.getLHS()], ub[expr.getLHS()], boundsSigned);
ConstantIntRanges rhsRange =
getRange(lb[expr.getRHS()], ub[expr.getRHS()], boundsSigned);
ConstantIntRanges result = opInference({lhsRange, rhsRange});

lb[expr] = (boundsSigned) ? result.smin() : result.umin();
ub[expr] = (boundsSigned) ? result.smax() : result.umax();
}

// Visitor method overrides.
LogicalResult AffineExprBoundsVisitor::visitMulExpr(AffineBinaryOpExpr expr) {
inferBinOpRange(expr, [](ArrayRef<ConstantIntRanges> ranges) {
return intrange::inferMul(ranges);
});
return success();
}
LogicalResult AffineExprBoundsVisitor::visitAddExpr(AffineBinaryOpExpr expr) {
inferBinOpRange(expr, [](ArrayRef<ConstantIntRanges> ranges) {
return intrange::inferAdd(ranges);
});
return success();
}
LogicalResult
AffineExprBoundsVisitor::visitCeilDivExpr(AffineBinaryOpExpr expr) {
inferBinOpRange(
expr, [boundsSigned = boundsSigned](ArrayRef<ConstantIntRanges> ranges) {
if (boundsSigned) {
return intrange::inferCeilDivS(ranges);
}
return intrange::inferCeilDivU(ranges);
});
return success();
}
LogicalResult
AffineExprBoundsVisitor::visitFloorDivExpr(AffineBinaryOpExpr expr) {
// There is no inferFloorDivU in the intrange library. We only offer
// computation of bounds for signed floordiv operations.
if (boundsSigned) {
inferBinOpRange(expr, [](ArrayRef<ConstantIntRanges> ranges) {
return intrange::inferFloorDivS(ranges);
});
return success();
}
return failure();
}
LogicalResult AffineExprBoundsVisitor::visitModExpr(AffineBinaryOpExpr expr) {
inferBinOpRange(
expr, [boundsSigned = boundsSigned](ArrayRef<ConstantIntRanges> ranges) {
if (boundsSigned) {
return intrange::inferRemS(ranges);
}
return intrange::inferRemU(ranges);
});
return success();
}
LogicalResult AffineExprBoundsVisitor::visitDimExpr(AffineDimExpr expr) {
if (lb.find(expr) == lb.end() || ub.find(expr) == ub.end()) {
return failure();
}
return success();
}
LogicalResult AffineExprBoundsVisitor::visitSymbolExpr(AffineSymbolExpr expr) {
return failure();
}
LogicalResult
AffineExprBoundsVisitor::visitConstantExpr(AffineConstantExpr expr) {
APInt apIntVal =
APInt(bitWidth, static_cast<uint64_t>(expr.getValue()), boundsSigned);
lb[expr] = apIntVal;
ub[expr] = apIntVal;
return success();
}
1 change: 1 addition & 0 deletions mlir/lib/Analysis/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ set(LLVM_OPTIONAL_SOURCES
add_subdirectory(Presburger)

add_mlir_library(MLIRAnalysis
AffineExprBounds.cpp
AliasAnalysis.cpp
CallGraph.cpp
DataFlowFramework.cpp
Expand Down
Loading