Skip to content

Commit b40461d

Browse files
committed
Compute affine expression bounds
1 parent 2b2e860 commit b40461d

File tree

7 files changed

+675
-0
lines changed

7 files changed

+675
-0
lines changed
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
//===- AffineExprBounds.h - Compute bounds of affine expressions *- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This header file defines an analysis of affine expressions to compute their
10+
// ranges (lower/upper bounds) in a given context.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
#ifndef MLIR_ANALYSIS_AFFINEEXPRBOUNDS_H
14+
#define MLIR_ANALYSIS_AFFINEEXPRBOUNDS_H
15+
16+
#include "mlir/IR/AffineExprVisitor.h"
17+
#include "mlir/IR/Attributes.h"
18+
#include "mlir/IR/BuiltinAttributes.h"
19+
#include "mlir/Interfaces/InferIntRangeInterface.h"
20+
21+
#include "mlir/IR/AffineExpr.h"
22+
#include "mlir/IR/AffineMap.h"
23+
#include "mlir/Support/LogicalResult.h"
24+
25+
using namespace mlir;
26+
27+
/// This visitor computes the bounds of affine expressions, using as context the
28+
/// bounds of the dimensions of the expression.
29+
///
30+
/// Example:
31+
/// Given bounds 0 <= d0 <= 99 and 0 <= d1 <= 199, we can compute the bounds
32+
/// of the following expression:
33+
/// lb(2 * d0 + 3 * d1) = 0
34+
/// ub(2 * d0 + 3 * d1) = 795
35+
///
36+
/// * The bounds given in the context are inclusive, and the bounds returned
37+
/// are also inclusive.
38+
/// * If bounds are not available for a dimension, std::nullopt can be used
39+
/// instead. The bounds of an expression that involves it will be std::nullopt.
40+
/// * Limitations:
41+
/// - Parametric expressions (using symbols) are not supported.
42+
/// - Unsigned FloorDiv is currently not supported.
43+
class AffineExprBoundsVisitor
44+
: public AffineExprVisitor<AffineExprBoundsVisitor, LogicalResult> {
45+
public:
46+
/// Initialize the context (bounds) with APInt. All bounds must have the same
47+
/// signedness and bit width.
48+
AffineExprBoundsVisitor(ArrayRef<std::optional<APInt>> constLowerBounds,
49+
ArrayRef<std::optional<APInt>> constUpperBounds,
50+
bool boundsSigned, uint64_t bitWidth,
51+
MLIRContext *context);
52+
53+
/// Initialize the context (bounds) with 64-bit signed integers. This allows
54+
/// to directly map index-type values such as Linalg op bounds, which are
55+
/// represented as int64_t.
56+
AffineExprBoundsVisitor(ArrayRef<std::optional<int64_t>> constLowerBounds,
57+
ArrayRef<std::optional<int64_t>> constUpperBounds,
58+
MLIRContext *context);
59+
60+
/// Get the upper bound of \p expr using the context bounds.
61+
std::optional<APInt> getUpperBound(AffineExpr expr);
62+
std::optional<int64_t> getIndexUpperBound(AffineExpr expr);
63+
64+
/// Get the lower bound of \p expr using the context bounds.
65+
std::optional<APInt> getLowerBound(AffineExpr expr);
66+
std::optional<int64_t> getIndexLowerBound(AffineExpr expr);
67+
68+
// These methods are directly called by the AffineExprVisitor base class.
69+
LogicalResult visitMulExpr(AffineBinaryOpExpr expr);
70+
LogicalResult visitAddExpr(AffineBinaryOpExpr expr);
71+
LogicalResult visitDimExpr(AffineDimExpr expr);
72+
LogicalResult visitSymbolExpr(AffineSymbolExpr expr);
73+
LogicalResult visitConstantExpr(AffineConstantExpr expr);
74+
LogicalResult visitCeilDivExpr(AffineBinaryOpExpr expr);
75+
LogicalResult visitFloorDivExpr(AffineBinaryOpExpr expr);
76+
LogicalResult visitModExpr(AffineBinaryOpExpr expr);
77+
78+
private:
79+
bool boundsSigned;
80+
uint64_t bitWidth;
81+
void
82+
inferBinOpRange(AffineBinaryOpExpr expr,
83+
std::function<ConstantIntRanges(ArrayRef<ConstantIntRanges>)>
84+
opInference);
85+
86+
/// Bounds that have been computed for subexpressions are memoized and reused.
87+
llvm::DenseMap<AffineExpr, APInt> lb;
88+
llvm::DenseMap<AffineExpr, APInt> ub;
89+
};
90+
91+
#endif // MLIR_ANALYSIS_AFFINEEXPRBOUNDS_H
Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
//===- AffineExprBounds.h - Compute bounds of affine expressions *- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements an analysis of affine expressions to compute their
10+
// ranges (lower/upper bounds) in a given context.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
#include "mlir/Analysis/AffineExprBounds.h"
14+
15+
#include "mlir/IR/AffineExprVisitor.h"
16+
#include "mlir/IR/AffineMap.h"
17+
#include "mlir/IR/BuiltinAttributes.h"
18+
#include "mlir/Interfaces/InferIntRangeInterface.h"
19+
#include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
20+
21+
#include <cstdint>
22+
23+
using namespace mlir;
24+
25+
AffineExprBoundsVisitor::AffineExprBoundsVisitor(
26+
ArrayRef<std::optional<APInt>> constLowerBounds,
27+
ArrayRef<std::optional<APInt>> constUpperBounds, bool boundsSigned,
28+
uint64_t bitWidth, MLIRContext *context)
29+
: boundsSigned(boundsSigned), bitWidth(bitWidth) {
30+
assert(constLowerBounds.size() == constUpperBounds.size());
31+
for (unsigned i = 0; i < constLowerBounds.size(); i++) {
32+
if (constLowerBounds[i].has_value()) {
33+
lb[getAffineDimExpr(i, context)] = constLowerBounds[i].value();
34+
}
35+
if (constUpperBounds[i].has_value()) {
36+
ub[getAffineDimExpr(i, context)] = constUpperBounds[i].value();
37+
}
38+
}
39+
}
40+
41+
AffineExprBoundsVisitor::AffineExprBoundsVisitor(
42+
ArrayRef<std::optional<int64_t>> constLowerBounds,
43+
ArrayRef<std::optional<int64_t>> constUpperBounds, MLIRContext *context) {
44+
assert(constLowerBounds.size() == constUpperBounds.size());
45+
// Convert int64_ts to APInts.
46+
for (unsigned i = 0; i < constLowerBounds.size(); i++) {
47+
if (constLowerBounds[i].has_value()) {
48+
lb[getAffineDimExpr(i, context)] =
49+
APInt(64, constLowerBounds[i].value(), /*isSigned=*/true);
50+
}
51+
if (constUpperBounds[i].has_value()) {
52+
ub[getAffineDimExpr(i, context)] =
53+
APInt(64, constUpperBounds[i].value(), /*isSigned=*/true);
54+
}
55+
}
56+
}
57+
58+
std::optional<APInt> AffineExprBoundsVisitor::getUpperBound(AffineExpr expr) {
59+
// Use memoized bound if available.
60+
auto i = ub.find(expr);
61+
if (i != ub.end()) {
62+
return i->second;
63+
}
64+
// Compute the bound otherwise.
65+
if (failed(walkPostOrder(expr))) {
66+
return std::nullopt;
67+
}
68+
return ub[expr];
69+
}
70+
71+
std::optional<APInt> AffineExprBoundsVisitor::getLowerBound(AffineExpr expr) {
72+
// Use memoized bound if available.
73+
auto i = lb.find(expr);
74+
if (i != lb.end()) {
75+
return i->second;
76+
}
77+
// Compute the bound otherwise.
78+
if (failed(walkPostOrder(expr))) {
79+
return std::nullopt;
80+
}
81+
return lb[expr];
82+
}
83+
84+
std::optional<int64_t>
85+
AffineExprBoundsVisitor::getIndexUpperBound(AffineExpr expr) {
86+
std::optional<APInt> apIntResult = getUpperBound(expr);
87+
if (!apIntResult)
88+
return std::nullopt;
89+
90+
return apIntResult->getSExtValue();
91+
}
92+
93+
std::optional<int64_t>
94+
AffineExprBoundsVisitor::getIndexLowerBound(AffineExpr expr) {
95+
std::optional<APInt> apIntResult = getLowerBound(expr);
96+
if (!apIntResult)
97+
return std::nullopt;
98+
99+
return apIntResult->getSExtValue();
100+
}
101+
102+
ConstantIntRanges getRange(APInt lb, APInt ub, bool boundsSigned) {
103+
return ConstantIntRanges::range(lb, ub, boundsSigned);
104+
}
105+
106+
/// Wrapper around the intrange::infer* functions that infers the range of
107+
/// binary operations on two ranges.
108+
void AffineExprBoundsVisitor::inferBinOpRange(
109+
AffineBinaryOpExpr expr,
110+
std::function<ConstantIntRanges(ArrayRef<ConstantIntRanges>)> opInference) {
111+
ConstantIntRanges lhsRange =
112+
getRange(lb[expr.getLHS()], ub[expr.getLHS()], boundsSigned);
113+
ConstantIntRanges rhsRange =
114+
getRange(lb[expr.getRHS()], ub[expr.getRHS()], boundsSigned);
115+
ConstantIntRanges result = opInference({lhsRange, rhsRange});
116+
117+
lb[expr] = (boundsSigned) ? result.smin() : result.umin();
118+
ub[expr] = (boundsSigned) ? result.smax() : result.umax();
119+
}
120+
121+
// Visitor method overrides.
122+
LogicalResult AffineExprBoundsVisitor::visitMulExpr(AffineBinaryOpExpr expr) {
123+
inferBinOpRange(expr, [](ArrayRef<ConstantIntRanges> ranges) {
124+
return intrange::inferMul(ranges);
125+
});
126+
return success();
127+
}
128+
LogicalResult AffineExprBoundsVisitor::visitAddExpr(AffineBinaryOpExpr expr) {
129+
inferBinOpRange(expr, [](ArrayRef<ConstantIntRanges> ranges) {
130+
return intrange::inferAdd(ranges);
131+
});
132+
return success();
133+
}
134+
LogicalResult
135+
AffineExprBoundsVisitor::visitCeilDivExpr(AffineBinaryOpExpr expr) {
136+
inferBinOpRange(
137+
expr, [boundsSigned = boundsSigned](ArrayRef<ConstantIntRanges> ranges) {
138+
if (boundsSigned) {
139+
return intrange::inferCeilDivS(ranges);
140+
}
141+
return intrange::inferCeilDivU(ranges);
142+
});
143+
return success();
144+
}
145+
LogicalResult
146+
AffineExprBoundsVisitor::visitFloorDivExpr(AffineBinaryOpExpr expr) {
147+
// There is no inferFloorDivU in the intrange library. We only offer
148+
// computation of bounds for signed floordiv operations.
149+
if (boundsSigned) {
150+
inferBinOpRange(expr, [](ArrayRef<ConstantIntRanges> ranges) {
151+
return intrange::inferFloorDivS(ranges);
152+
});
153+
return success();
154+
}
155+
return failure();
156+
}
157+
LogicalResult AffineExprBoundsVisitor::visitModExpr(AffineBinaryOpExpr expr) {
158+
inferBinOpRange(
159+
expr, [boundsSigned = boundsSigned](ArrayRef<ConstantIntRanges> ranges) {
160+
if (boundsSigned) {
161+
return intrange::inferRemS(ranges);
162+
}
163+
return intrange::inferRemU(ranges);
164+
});
165+
return success();
166+
}
167+
LogicalResult AffineExprBoundsVisitor::visitDimExpr(AffineDimExpr expr) {
168+
if (lb.find(expr) == lb.end() || ub.find(expr) == ub.end()) {
169+
return failure();
170+
}
171+
return success();
172+
}
173+
LogicalResult AffineExprBoundsVisitor::visitSymbolExpr(AffineSymbolExpr expr) {
174+
return failure();
175+
}
176+
LogicalResult
177+
AffineExprBoundsVisitor::visitConstantExpr(AffineConstantExpr expr) {
178+
APInt apIntVal =
179+
APInt(bitWidth, static_cast<uint64_t>(expr.getValue()), boundsSigned);
180+
lb[expr] = apIntVal;
181+
ub[expr] = apIntVal;
182+
return success();
183+
}

mlir/lib/Analysis/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ set(LLVM_OPTIONAL_SOURCES
2121
add_subdirectory(Presburger)
2222

2323
add_mlir_library(MLIRAnalysis
24+
AffineExprBounds.cpp
2425
AliasAnalysis.cpp
2526
CallGraph.cpp
2627
DataFlowFramework.cpp

0 commit comments

Comments
 (0)