|
| 1 | +//===- AffineExprBounds.h - Compute bounds of affine expressions *- C++ -*-===// |
| 2 | +// |
| 3 | +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | +// See https://llvm.org/LICENSE.txt for license information. |
| 5 | +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | +// |
| 7 | +//===----------------------------------------------------------------------===// |
| 8 | +// |
| 9 | +// This file implements an analysis of affine expressions to compute their |
| 10 | +// ranges (lower/upper bounds) in a given context. |
| 11 | +// |
| 12 | +//===----------------------------------------------------------------------===// |
| 13 | +#include "mlir/Analysis/AffineExprBounds.h" |
| 14 | + |
| 15 | +#include "mlir/IR/AffineExprVisitor.h" |
| 16 | +#include "mlir/IR/AffineMap.h" |
| 17 | +#include "mlir/IR/BuiltinAttributes.h" |
| 18 | +#include "mlir/Interfaces/InferIntRangeInterface.h" |
| 19 | +#include "mlir/Interfaces/Utils/InferIntRangeCommon.h" |
| 20 | +#include "llvm/ADT/APInt.h" |
| 21 | + |
| 22 | +#include <cstdint> |
| 23 | + |
| 24 | +using namespace mlir; |
| 25 | + |
| 26 | +AffineExprBoundsVisitor::AffineExprBoundsVisitor( |
| 27 | + ArrayRef<std::optional<APInt>> constLowerBounds, |
| 28 | + ArrayRef<std::optional<APInt>> constUpperBounds, bool boundsSigned, |
| 29 | + uint64_t bitWidth, MLIRContext *context) |
| 30 | + : boundsSigned(boundsSigned), bitWidth(bitWidth) { |
| 31 | + assert(constLowerBounds.size() == constUpperBounds.size()); |
| 32 | + for (unsigned i = 0; i < constLowerBounds.size(); i++) { |
| 33 | + if (constLowerBounds[i].has_value()) { |
| 34 | + lb[getAffineDimExpr(i, context)] = constLowerBounds[i].value(); |
| 35 | + } |
| 36 | + if (constUpperBounds[i].has_value()) { |
| 37 | + ub[getAffineDimExpr(i, context)] = constUpperBounds[i].value(); |
| 38 | + } |
| 39 | + } |
| 40 | +} |
| 41 | + |
| 42 | +AffineExprBoundsVisitor::AffineExprBoundsVisitor( |
| 43 | + ArrayRef<std::optional<int64_t>> constLowerBounds, |
| 44 | + ArrayRef<std::optional<int64_t>> constUpperBounds, MLIRContext *context) |
| 45 | + : boundsSigned(true), bitWidth(64) { |
| 46 | + assert(constLowerBounds.size() == constUpperBounds.size()); |
| 47 | + // Convert int64_ts to APInts. |
| 48 | + for (unsigned i = 0; i < constLowerBounds.size(); i++) { |
| 49 | + if (constLowerBounds[i].has_value()) { |
| 50 | + lb[getAffineDimExpr(i, context)] = |
| 51 | + APInt(64, constLowerBounds[i].value(), /*isSigned=*/true); |
| 52 | + } |
| 53 | + if (constUpperBounds[i].has_value()) { |
| 54 | + ub[getAffineDimExpr(i, context)] = |
| 55 | + APInt(64, constUpperBounds[i].value(), /*isSigned=*/true); |
| 56 | + } |
| 57 | + } |
| 58 | +} |
| 59 | + |
| 60 | +std::optional<APInt> AffineExprBoundsVisitor::getUpperBound(AffineExpr expr) { |
| 61 | + // Use memoized bound if available. |
| 62 | + auto i = ub.find(expr); |
| 63 | + if (i != ub.end()) { |
| 64 | + return i->second; |
| 65 | + } |
| 66 | + // Compute the bound otherwise. |
| 67 | + if (failed(walkPostOrder(expr))) { |
| 68 | + return std::nullopt; |
| 69 | + } |
| 70 | + return ub[expr]; |
| 71 | +} |
| 72 | + |
| 73 | +std::optional<APInt> AffineExprBoundsVisitor::getLowerBound(AffineExpr expr) { |
| 74 | + // Use memoized bound if available. |
| 75 | + auto i = lb.find(expr); |
| 76 | + if (i != lb.end()) { |
| 77 | + return i->second; |
| 78 | + } |
| 79 | + // Compute the bound otherwise. |
| 80 | + if (failed(walkPostOrder(expr))) { |
| 81 | + return std::nullopt; |
| 82 | + } |
| 83 | + return lb[expr]; |
| 84 | +} |
| 85 | + |
| 86 | +std::optional<int64_t> |
| 87 | +AffineExprBoundsVisitor::getIndexUpperBound(AffineExpr expr) { |
| 88 | + std::optional<APInt> apIntResult = getUpperBound(expr); |
| 89 | + if (!apIntResult) |
| 90 | + return std::nullopt; |
| 91 | + |
| 92 | + return apIntResult->getSExtValue(); |
| 93 | +} |
| 94 | + |
| 95 | +std::optional<int64_t> |
| 96 | +AffineExprBoundsVisitor::getIndexLowerBound(AffineExpr expr) { |
| 97 | + std::optional<APInt> apIntResult = getLowerBound(expr); |
| 98 | + if (!apIntResult) |
| 99 | + return std::nullopt; |
| 100 | + |
| 101 | + return apIntResult->getSExtValue(); |
| 102 | +} |
| 103 | + |
| 104 | +ConstantIntRanges getRange(APInt lb, APInt ub, bool boundsSigned) { |
| 105 | + return ConstantIntRanges::range(lb, ub, boundsSigned); |
| 106 | +} |
| 107 | + |
| 108 | +/// Wrapper around the intrange::infer* functions that infers the range of |
| 109 | +/// binary operations on two ranges. |
| 110 | +void AffineExprBoundsVisitor::inferBinOpRange( |
| 111 | + AffineBinaryOpExpr expr, |
| 112 | + const std::function<ConstantIntRanges(ArrayRef<ConstantIntRanges>)> |
| 113 | + &opInference) { |
| 114 | + ConstantIntRanges lhsRange = |
| 115 | + getRange(lb[expr.getLHS()], ub[expr.getLHS()], boundsSigned); |
| 116 | + ConstantIntRanges rhsRange = |
| 117 | + getRange(lb[expr.getRHS()], ub[expr.getRHS()], boundsSigned); |
| 118 | + ConstantIntRanges result = opInference({lhsRange, rhsRange}); |
| 119 | + |
| 120 | + lb[expr] = (boundsSigned) ? result.smin() : result.umin(); |
| 121 | + ub[expr] = (boundsSigned) ? result.smax() : result.umax(); |
| 122 | +} |
| 123 | + |
| 124 | +// Visitor method overrides. |
| 125 | +LogicalResult AffineExprBoundsVisitor::visitMulExpr(AffineBinaryOpExpr expr) { |
| 126 | + inferBinOpRange(expr, [](ArrayRef<ConstantIntRanges> ranges) { |
| 127 | + return intrange::inferMul(ranges); |
| 128 | + }); |
| 129 | + return success(); |
| 130 | +} |
| 131 | +LogicalResult AffineExprBoundsVisitor::visitAddExpr(AffineBinaryOpExpr expr) { |
| 132 | + inferBinOpRange(expr, [](ArrayRef<ConstantIntRanges> ranges) { |
| 133 | + return intrange::inferAdd(ranges); |
| 134 | + }); |
| 135 | + return success(); |
| 136 | +} |
| 137 | +LogicalResult |
| 138 | +AffineExprBoundsVisitor::visitCeilDivExpr(AffineBinaryOpExpr expr) { |
| 139 | + inferBinOpRange( |
| 140 | + expr, [boundsSigned = boundsSigned](ArrayRef<ConstantIntRanges> ranges) { |
| 141 | + if (boundsSigned) { |
| 142 | + return intrange::inferCeilDivS(ranges); |
| 143 | + } |
| 144 | + return intrange::inferCeilDivU(ranges); |
| 145 | + }); |
| 146 | + return success(); |
| 147 | +} |
| 148 | +LogicalResult |
| 149 | +AffineExprBoundsVisitor::visitFloorDivExpr(AffineBinaryOpExpr expr) { |
| 150 | + // There is no inferFloorDivU in the intrange library. We only offer |
| 151 | + // computation of bounds for signed floordiv operations. |
| 152 | + if (boundsSigned) { |
| 153 | + inferBinOpRange(expr, [](ArrayRef<ConstantIntRanges> ranges) { |
| 154 | + return intrange::inferFloorDivS(ranges); |
| 155 | + }); |
| 156 | + return success(); |
| 157 | + } |
| 158 | + return failure(); |
| 159 | +} |
| 160 | +LogicalResult AffineExprBoundsVisitor::visitModExpr(AffineBinaryOpExpr expr) { |
| 161 | + inferBinOpRange( |
| 162 | + expr, [boundsSigned = boundsSigned](ArrayRef<ConstantIntRanges> ranges) { |
| 163 | + if (boundsSigned) { |
| 164 | + return intrange::inferRemS(ranges); |
| 165 | + } |
| 166 | + return intrange::inferRemU(ranges); |
| 167 | + }); |
| 168 | + return success(); |
| 169 | +} |
| 170 | +LogicalResult AffineExprBoundsVisitor::visitDimExpr(AffineDimExpr expr) { |
| 171 | + if (lb.find(expr) == lb.end() || ub.find(expr) == ub.end()) { |
| 172 | + return failure(); |
| 173 | + } |
| 174 | + return success(); |
| 175 | +} |
| 176 | +LogicalResult AffineExprBoundsVisitor::visitSymbolExpr(AffineSymbolExpr expr) { |
| 177 | + return failure(); |
| 178 | +} |
| 179 | +LogicalResult |
| 180 | +AffineExprBoundsVisitor::visitConstantExpr(AffineConstantExpr expr) { |
| 181 | + APInt apIntVal = |
| 182 | + APInt(bitWidth, static_cast<uint64_t>(expr.getValue()), boundsSigned); |
| 183 | + lb[expr] = apIntVal; |
| 184 | + ub[expr] = apIntVal; |
| 185 | + return success(); |
| 186 | +} |
0 commit comments