Skip to content

[mlir][Interfaces][NFC] ValueBoundsConstraintSet: Pass stop condition in the constructor #86099

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,12 @@ struct ValueBoundsConstraintSet : protected ::mlir::ValueBoundsConstraintSet {
struct ScalableValueBoundsConstraintSet
: public llvm::RTTIExtends<ScalableValueBoundsConstraintSet,
detail::ValueBoundsConstraintSet> {
ScalableValueBoundsConstraintSet(MLIRContext *context, unsigned vscaleMin,
unsigned vscaleMax)
: RTTIExtends(context), vscaleMin(vscaleMin), vscaleMax(vscaleMax){};
ScalableValueBoundsConstraintSet(
MLIRContext *context,
ValueBoundsConstraintSet::StopConditionFn stopCondition,
unsigned vscaleMin, unsigned vscaleMax)
: RTTIExtends(context, stopCondition), vscaleMin(vscaleMin),
vscaleMax(vscaleMax) {};

using RTTIExtends::bound;
using RTTIExtends::StopConditionFn;
Expand Down
16 changes: 9 additions & 7 deletions mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,9 @@ class ValueBoundsConstraintSet
///
/// The first parameter of the function is the shaped value/index-typed
/// value. The second parameter is the dimension in case of a shaped value.
using StopConditionFn =
function_ref<bool(Value, std::optional<int64_t> /*dim*/)>;
/// The third parameter is this constraint set.
using StopConditionFn = std::function<bool(
Value, std::optional<int64_t> /*dim*/, ValueBoundsConstraintSet &cstr)>;
Comment on lines +120 to +122
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment is not directly related to the changes in this PR, but It's not immediately clear to me if which value indicates stop/continuation. I like the walk function much more which exposes WalkResult::interrup() and WalkRedult::advance(), but I think that changing to LogicalResult and renaming this could also help.
Would be cool if we could revisit this separately.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point, I am also thinking of splitting up stopCondition into two functions, so that IR traversal boundary and values that are projected out can be specified separately. We had a use case for that recently.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm going to prepare a follow-up PR for this, after my chain of PRs has landed.


/// Compute a bound for the given index-typed value or shape dimension size.
/// The computed bound is stored in `resultMap`. The operands of the bound are
Expand Down Expand Up @@ -271,22 +272,20 @@ class ValueBoundsConstraintSet
/// An index-typed value or the dimension of a shaped-type value.
using ValueDim = std::pair<Value, int64_t>;

ValueBoundsConstraintSet(MLIRContext *ctx);
ValueBoundsConstraintSet(MLIRContext *ctx, StopConditionFn stopCondition);

/// Populates the constraint set for a value/map without actually computing
/// the bound. Returns the position for the value/map (via the return value
/// and `posOut` output parameter).
int64_t populateConstraintsSet(Value value,
std::optional<int64_t> dim = std::nullopt,
StopConditionFn stopCondition = nullptr);
std::optional<int64_t> dim = std::nullopt);
int64_t populateConstraintsSet(AffineMap map, ValueDimList mapOperands,
StopConditionFn stopCondition = nullptr,
int64_t *posOut = nullptr);

/// Iteratively process all elements on the worklist until an index-typed
/// value or shaped value meets `stopCondition`. Such values are not processed
/// any further.
void processWorklist(StopConditionFn stopCondition);
void processWorklist();

/// Bound the given column in the underlying constraint set by the given
/// expression.
Expand Down Expand Up @@ -333,6 +332,9 @@ class ValueBoundsConstraintSet

/// Builder for constructing affine expressions.
Builder builder;

/// The current stop condition function.
StopConditionFn stopCondition = nullptr;
};

} // namespace mlir
Expand Down
6 changes: 4 additions & 2 deletions mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,8 @@ FailureOr<OpFoldResult> mlir::affine::reifyShapedValueDimBound(
OpBuilder &b, Location loc, presburger::BoundType type, Value value,
int64_t dim, ValueBoundsConstraintSet::StopConditionFn stopCondition,
bool closedUB) {
auto reifyToOperands = [&](Value v, std::optional<int64_t> d) {
auto reifyToOperands = [&](Value v, std::optional<int64_t> d,
ValueBoundsConstraintSet &cstr) {
// We are trying to reify a bound for `value` in terms of the owning op's
// operands. Construct a stop condition that evaluates to "true" for any SSA
// value except for `value`. I.e., the bound will be computed in terms of
Expand All @@ -100,7 +101,8 @@ FailureOr<OpFoldResult> mlir::affine::reifyShapedValueDimBound(
FailureOr<OpFoldResult> mlir::affine::reifyIndexValueBound(
OpBuilder &b, Location loc, presburger::BoundType type, Value value,
ValueBoundsConstraintSet::StopConditionFn stopCondition, bool closedUB) {
auto reifyToOperands = [&](Value v, std::optional<int64_t> d) {
auto reifyToOperands = [&](Value v, std::optional<int64_t> d,
ValueBoundsConstraintSet &cstr) {
return v != value;
};
return reifyValueBound(b, loc, type, value, /*dim=*/std::nullopt,
Expand Down
6 changes: 4 additions & 2 deletions mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,8 @@ FailureOr<OpFoldResult> mlir::arith::reifyShapedValueDimBound(
OpBuilder &b, Location loc, presburger::BoundType type, Value value,
int64_t dim, ValueBoundsConstraintSet::StopConditionFn stopCondition,
bool closedUB) {
auto reifyToOperands = [&](Value v, std::optional<int64_t> d) {
auto reifyToOperands = [&](Value v, std::optional<int64_t> d,
ValueBoundsConstraintSet &cstr) {
// We are trying to reify a bound for `value` in terms of the owning op's
// operands. Construct a stop condition that evaluates to "true" for any SSA
// value expect for `value`. I.e., the bound will be computed in terms of
Expand All @@ -135,7 +136,8 @@ FailureOr<OpFoldResult> mlir::arith::reifyShapedValueDimBound(
FailureOr<OpFoldResult> mlir::arith::reifyIndexValueBound(
OpBuilder &b, Location loc, presburger::BoundType type, Value value,
ValueBoundsConstraintSet::StopConditionFn stopCondition, bool closedUB) {
auto reifyToOperands = [&](Value v, std::optional<int64_t> d) {
auto reifyToOperands = [&](Value v, std::optional<int64_t> d,
ValueBoundsConstraintSet &cstr) {
return v != value;
};
return reifyValueBound(b, loc, type, value, /*dim=*/std::nullopt,
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,7 @@ HoistPaddingAnalysis::getHoistedPackedTensorSizes(RewriterBase &rewriter,
FailureOr<OpFoldResult> loopUb = affine::reifyIndexValueBound(
rewriter, loc, presburger::BoundType::UB, forOp.getUpperBound(),
/*stopCondition=*/
[&](Value v, std::optional<int64_t> d) {
[&](Value v, std::optional<int64_t> d, ValueBoundsConstraintSet &cstr) {
if (v == forOp.getUpperBound())
return false;
// Compute a bound that is independent of any affine op results.
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ struct ForOpInterface
ValueDimList boundOperands;
LogicalResult status = ValueBoundsConstraintSet::computeBound(
bound, boundOperands, BoundType::EQ, yieldedValue, dim,
[&](Value v, std::optional<int64_t> d) {
[&](Value v, std::optional<int64_t> d, ValueBoundsConstraintSet &cstr) {
// Stop when reaching a block argument of the loop body.
if (auto bbArg = llvm::dyn_cast<BlockArgument>(v))
return bbArg.getOwner()->getParentOp() == forOp;
Expand Down
21 changes: 15 additions & 6 deletions mlir/lib/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,17 +47,26 @@ ScalableValueBoundsConstraintSet::computeScalableBound(
unsigned vscaleMax, presburger::BoundType boundType, bool closedUB,
StopConditionFn stopCondition) {
using namespace presburger;

assert(vscaleMin <= vscaleMax);
ScalableValueBoundsConstraintSet scalableCstr(value.getContext(), vscaleMin,
vscaleMax);

int64_t pos = scalableCstr.populateConstraintsSet(value, dim, stopCondition);
// No stop condition specified: Keep adding constraints until the worklist
// is empty.
auto defaultStopCondition = [&](Value v, std::optional<int64_t> dim,
mlir::ValueBoundsConstraintSet &cstr) {
return false;
};

ScalableValueBoundsConstraintSet scalableCstr(
value.getContext(), stopCondition ? stopCondition : defaultStopCondition,
vscaleMin, vscaleMax);
int64_t pos = scalableCstr.populateConstraintsSet(value, dim);

// Project out all variables apart from vscale.
// This should result in constraints in terms of vscale only.
scalableCstr.projectOut(
[&](ValueDim p) { return p.first != scalableCstr.getVscaleValue(); });
auto projectOutFn = [&](ValueDim p) {
return p.first != scalableCstr.getVscaleValue();
};
scalableCstr.projectOut(projectOutFn);

assert(scalableCstr.cstr.getNumDimAndSymbolVars() ==
scalableCstr.positionToValueDim.size() &&
Expand Down
82 changes: 44 additions & 38 deletions mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,11 @@ static std::optional<int64_t> getConstantIntValue(OpFoldResult ofr) {
return std::nullopt;
}

ValueBoundsConstraintSet::ValueBoundsConstraintSet(MLIRContext *ctx)
: builder(ctx) {}
ValueBoundsConstraintSet::ValueBoundsConstraintSet(
MLIRContext *ctx, StopConditionFn stopCondition)
: builder(ctx), stopCondition(stopCondition) {
assert(stopCondition && "expected non-null stop condition");
}

char ValueBoundsConstraintSet::ID = 0;

Expand Down Expand Up @@ -193,7 +196,8 @@ static Operation *getOwnerOfValue(Value value) {
return value.getDefiningOp();
}

void ValueBoundsConstraintSet::processWorklist(StopConditionFn stopCondition) {
void ValueBoundsConstraintSet::processWorklist() {
LLVM_DEBUG(llvm::dbgs() << "Processing value bounds worklist...\n");
while (!worklist.empty()) {
int64_t pos = worklist.front();
worklist.pop();
Expand All @@ -214,13 +218,19 @@ void ValueBoundsConstraintSet::processWorklist(StopConditionFn stopCondition) {

// Do not process any further if the stop condition is met.
auto maybeDim = dim == kIndexValue ? std::nullopt : std::make_optional(dim);
if (stopCondition(value, maybeDim))
if (stopCondition(value, maybeDim, *this)) {
LLVM_DEBUG(llvm::dbgs() << "Stop condition met for: " << value
<< " (dim: " << maybeDim << ")\n");
continue;
}

// Query `ValueBoundsOpInterface` for constraints. New items may be added to
// the worklist.
auto valueBoundsOp =
dyn_cast<ValueBoundsOpInterface>(getOwnerOfValue(value));
LLVM_DEBUG(llvm::dbgs()
<< "Query value bounds for: " << value
<< " (owner: " << getOwnerOfValue(value)->getName() << ")\n");
if (valueBoundsOp) {
if (dim == kIndexValue) {
valueBoundsOp.populateBoundsForIndexValue(value, *this);
Expand All @@ -229,6 +239,7 @@ void ValueBoundsConstraintSet::processWorklist(StopConditionFn stopCondition) {
}
continue;
}
LLVM_DEBUG(llvm::dbgs() << "--> ValueBoundsOpInterface not implemented\n");

// If the op does not implement `ValueBoundsOpInterface`, check if it
// implements the `DestinationStyleOpInterface`. OpResults of such ops are
Expand Down Expand Up @@ -278,8 +289,6 @@ LogicalResult ValueBoundsConstraintSet::computeBound(
bool closedUB) {
#ifndef NDEBUG
assertValidValueDim(value, dim);
assert(!stopCondition(value, dim) &&
"stop condition should not be satisfied for starting point");
#endif // NDEBUG

int64_t ubAdjustment = closedUB ? 0 : 1;
Expand All @@ -289,9 +298,11 @@ LogicalResult ValueBoundsConstraintSet::computeBound(
// Process the backward slice of `value` (i.e., reverse use-def chain) until
// `stopCondition` is met.
ValueDim valueDim = std::make_pair(value, dim.value_or(kIndexValue));
ValueBoundsConstraintSet cstr(value.getContext());
ValueBoundsConstraintSet cstr(value.getContext(), stopCondition);
assert(!stopCondition(value, dim, cstr) &&
"stop condition should not be satisfied for starting point");
int64_t pos = cstr.insert(value, dim, /*isSymbol=*/false);
cstr.processWorklist(stopCondition);
cstr.processWorklist();

// Project out all variables (apart from `valueDim`) that do not match the
// stop condition.
Expand All @@ -301,7 +312,7 @@ LogicalResult ValueBoundsConstraintSet::computeBound(
return false;
auto maybeDim =
p.second == kIndexValue ? std::nullopt : std::make_optional(p.second);
return !stopCondition(p.first, maybeDim);
return !stopCondition(p.first, maybeDim, cstr);
});

// Compute lower and upper bounds for `valueDim`.
Expand Down Expand Up @@ -407,7 +418,7 @@ LogicalResult ValueBoundsConstraintSet::computeDependentBound(
bool closedUB) {
return computeBound(
resultMap, mapOperands, type, value, dim,
[&](Value v, std::optional<int64_t> d) {
[&](Value v, std::optional<int64_t> d, ValueBoundsConstraintSet &cstr) {
return llvm::is_contained(dependencies, std::make_pair(v, d));
},
closedUB);
Expand Down Expand Up @@ -443,7 +454,9 @@ LogicalResult ValueBoundsConstraintSet::computeIndependentBound(
// Reify bounds in terms of any independent values.
return computeBound(
resultMap, mapOperands, type, value, dim,
[&](Value v, std::optional<int64_t> d) { return isIndependent(v); },
[&](Value v, std::optional<int64_t> d, ValueBoundsConstraintSet &cstr) {
return isIndependent(v);
},
closedUB);
}

Expand Down Expand Up @@ -476,43 +489,42 @@ FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
presburger::BoundType type, AffineMap map, ValueDimList operands,
StopConditionFn stopCondition, bool closedUB) {
assert(map.getNumResults() == 1 && "expected affine map with one result");
ValueBoundsConstraintSet cstr(map.getContext());

int64_t pos = 0;
if (stopCondition) {
cstr.populateConstraintsSet(map, operands, stopCondition, &pos);
} else {
// No stop condition specified: Keep adding constraints until a bound could
// be computed.
cstr.populateConstraintsSet(
map, operands,
[&](Value v, std::optional<int64_t> dim) {
return cstr.cstr.getConstantBound64(type, pos).has_value();
},
&pos);
}
// Default stop condition if none was specified: Keep adding constraints until
// a bound could be computed.
int64_t pos;
auto defaultStopCondition = [&](Value v, std::optional<int64_t> dim,
ValueBoundsConstraintSet &cstr) {
return cstr.cstr.getConstantBound64(type, pos).has_value();
};

ValueBoundsConstraintSet cstr(
map.getContext(), stopCondition ? stopCondition : defaultStopCondition);
cstr.populateConstraintsSet(map, operands, &pos);

// Compute constant bound for `valueDim`.
int64_t ubAdjustment = closedUB ? 0 : 1;
if (auto bound = cstr.cstr.getConstantBound64(type, pos))
return type == BoundType::UB ? *bound + ubAdjustment : *bound;
return failure();
}

int64_t ValueBoundsConstraintSet::populateConstraintsSet(
Value value, std::optional<int64_t> dim, StopConditionFn stopCondition) {
int64_t
ValueBoundsConstraintSet::populateConstraintsSet(Value value,
std::optional<int64_t> dim) {
#ifndef NDEBUG
assertValidValueDim(value, dim);
#endif // NDEBUG

AffineMap map =
AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0,
Builder(value.getContext()).getAffineDimExpr(0));
return populateConstraintsSet(map, {{value, dim}}, stopCondition);
return populateConstraintsSet(map, {{value, dim}});
}

int64_t ValueBoundsConstraintSet::populateConstraintsSet(
AffineMap map, ValueDimList operands, StopConditionFn stopCondition,
int64_t *posOut) {
int64_t ValueBoundsConstraintSet::populateConstraintsSet(AffineMap map,
ValueDimList operands,
int64_t *posOut) {
assert(map.getNumResults() == 1 && "expected affine map with one result");
int64_t pos = insert(/*isSymbol=*/false);
if (posOut)
Expand All @@ -533,13 +545,7 @@ int64_t ValueBoundsConstraintSet::populateConstraintsSet(

// Process the backward slice of `operands` (i.e., reverse use-def chain)
// until `stopCondition` is met.
if (stopCondition) {
processWorklist(stopCondition);
} else {
// No stop condition specified: Keep adding constraints until the worklist
// is empty.
processWorklist([](Value v, std::optional<int64_t> dim) { return false; });
}
processWorklist();

return pos;
}
Expand Down
9 changes: 6 additions & 3 deletions mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,14 +117,17 @@ static LogicalResult testReifyValueBounds(func::FuncOp funcOp,

// Prepare stop condition. By default, reify in terms of the op's
// operands. No stop condition is used when a constant was requested.
std::function<bool(Value, std::optional<int64_t>)> stopCondition =
[&](Value v, std::optional<int64_t> d) {
std::function<bool(Value, std::optional<int64_t>,
ValueBoundsConstraintSet & cstr)>
stopCondition = [&](Value v, std::optional<int64_t> d,
ValueBoundsConstraintSet &cstr) {
// Reify in terms of SSA values that are different from `value`.
return v != value;
};
if (reifyToFuncArgs) {
// Reify in terms of function block arguments.
stopCondition = stopCondition = [](Value v, std::optional<int64_t> d) {
stopCondition = [](Value v, std::optional<int64_t> d,
ValueBoundsConstraintSet &cstr) {
auto bbArg = dyn_cast<BlockArgument>(v);
if (!bbArg)
return false;
Expand Down