-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][Vector] Add utility for computing scalable value bounds #83876
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
104 changes: 104 additions & 0 deletions
104
mlir/include/mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
//===- ScalableValueBoundsConstraintSet.h - Scalable Value Bounds ---------===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#ifndef MLIR_DIALECT_VECTOR_IR_SCALABLEVALUEBOUNDSCONSTRAINTSET_H | ||
#define MLIR_DIALECT_VECTOR_IR_SCALABLEVALUEBOUNDSCONSTRAINTSET_H | ||
|
||
#include "mlir/Analysis/Presburger/IntegerRelation.h" | ||
#include "mlir/Dialect/Vector/IR/VectorOps.h" | ||
#include "mlir/Interfaces/ValueBoundsOpInterface.h" | ||
|
||
namespace mlir::vector { | ||
|
||
namespace detail { | ||
|
||
/// Parent class for the value bounds RTTIExtends. Uses protected inheritance to | ||
/// hide all ValueBoundsConstraintSet methods by default (as some do not use the | ||
/// ScalableValueBoundsConstraintSet, so may produce unexpected results). | ||
struct ValueBoundsConstraintSet : protected ::mlir::ValueBoundsConstraintSet { | ||
using ::mlir::ValueBoundsConstraintSet::ValueBoundsConstraintSet; | ||
}; | ||
} // namespace detail | ||
|
||
/// A version of `ValueBoundsConstraintSet` that can solve for scalable bounds. | ||
struct ScalableValueBoundsConstraintSet | ||
: public llvm::RTTIExtends<ScalableValueBoundsConstraintSet, | ||
detail::ValueBoundsConstraintSet> { | ||
ScalableValueBoundsConstraintSet(MLIRContext *context, unsigned vscaleMin, | ||
unsigned vscaleMax) | ||
: RTTIExtends(context), vscaleMin(vscaleMin), vscaleMax(vscaleMax){}; | ||
|
||
using RTTIExtends::bound; | ||
using RTTIExtends::StopConditionFn; | ||
|
||
/// A thin wrapper over an `AffineMap` which can represent a constant bound, | ||
/// or a scalable bound (in terms of vscale). The `AffineMap` will always | ||
/// take at most one parameter, vscale, and returns a single result, which is | ||
/// the bound of value. | ||
struct ConstantOrScalableBound { | ||
AffineMap map; | ||
|
||
struct BoundSize { | ||
int64_t baseSize{0}; | ||
bool scalable{false}; | ||
}; | ||
|
||
/// Get the (possibly) scalable size of the bound, returns failure if | ||
/// the bound cannot be represented as a single quantity. | ||
FailureOr<BoundSize> getSize() const; | ||
}; | ||
|
||
/// Computes a (possibly) scalable bound for a given value. This is | ||
/// similar to `ValueBoundsConstraintSet::computeConstantBound()`, but | ||
/// uses knowledge of the range of vscale to compute either a constant | ||
/// bound, an expression in terms of vscale, or failure if no bound can | ||
/// be computed. | ||
/// | ||
/// The resulting `AffineMap` will always take at most one parameter, | ||
/// vscale, and return a single result, which is the bound of `value`. | ||
/// | ||
/// Note: `vscaleMin` must be `<=` to `vscaleMax`. If `vscaleMin` == | ||
/// `vscaleMax`, the resulting bound (if found), will be constant. | ||
static FailureOr<ConstantOrScalableBound> | ||
MacDue marked this conversation as resolved.
Show resolved
Hide resolved
|
||
computeScalableBound(Value value, std::optional<int64_t> dim, | ||
unsigned vscaleMin, unsigned vscaleMax, | ||
presburger::BoundType boundType, bool closedUB = true, | ||
StopConditionFn stopCondition = nullptr); | ||
|
||
/// Get the value of vscale. Returns `nullptr` vscale as not been encountered. | ||
Value getVscaleValue() const { return vscale; } | ||
|
||
/// Sets the value of vscale. Asserts if vscale has already been set. | ||
void setVscale(vector::VectorScaleOp vscaleOp) { | ||
assert(!vscale && "expected vscale to be unset"); | ||
vscale = vscaleOp.getResult(); | ||
} | ||
|
||
/// The minimum possible value of vscale. | ||
unsigned getVscaleMin() const { return vscaleMin; } | ||
|
||
/// The maximum possible value of vscale. | ||
unsigned getVscaleMax() const { return vscaleMax; } | ||
|
||
static char ID; | ||
|
||
private: | ||
const unsigned vscaleMin; | ||
const unsigned vscaleMax; | ||
|
||
// This will be set when the first `vector.vscale` operation is found within | ||
// the `ValueBoundsOpInterface` implementation then reused from there on. | ||
Value vscale = nullptr; | ||
MacDue marked this conversation as resolved.
Show resolved
Hide resolved
|
||
}; | ||
|
||
using ConstantOrScalableBound = | ||
ScalableValueBoundsConstraintSet::ConstantOrScalableBound; | ||
|
||
} // namespace mlir::vector | ||
|
||
#endif // MLIR_DIALECT_VECTOR_IR_SCALABLEVALUEBOUNDSCONSTRAINTSET_H |
20 changes: 20 additions & 0 deletions
20
mlir/include/mlir/Dialect/Vector/IR/ValueBoundsOpInterfaceImpl.h
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
//===- ValueBoundsOpInterfaceImpl.h - Impl. of ValueBoundsOpInterface -----===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#ifndef MLIR_DIALECT_VECTOR_IR_VALUEBOUNDSOPINTERFACEIMPL_H | ||
#define MLIR_DIALECT_VECTOR_IR_VALUEBOUNDSOPINTERFACEIMPL_H | ||
|
||
namespace mlir { | ||
class DialectRegistry; | ||
|
||
namespace vector { | ||
void registerValueBoundsOpInterfaceExternalModels(DialectRegistry ®istry); | ||
} // namespace vector | ||
} // namespace mlir | ||
|
||
#endif // MLIR_DIALECT_VECTOR_IR_VALUEBOUNDSOPINTERFACEIMPL_H |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
103 changes: 103 additions & 0 deletions
103
mlir/lib/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.cpp
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
//===- ScalableValueBoundsConstraintSet.cpp - Scalable Value Bounds -------===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h" | ||
|
||
#include "mlir/Dialect/Vector/IR/VectorOps.h" | ||
|
||
namespace mlir::vector { | ||
|
||
FailureOr<ConstantOrScalableBound::BoundSize> | ||
ConstantOrScalableBound::getSize() const { | ||
if (map.isSingleConstant()) | ||
return BoundSize{map.getSingleConstantResult(), /*scalable=*/false}; | ||
if (map.getNumResults() != 1 || map.getNumInputs() != 1) | ||
return failure(); | ||
auto binop = dyn_cast<AffineBinaryOpExpr>(map.getResult(0)); | ||
if (!binop || binop.getKind() != AffineExprKind::Mul) | ||
return failure(); | ||
auto matchConstant = [&](AffineExpr expr, int64_t &constant) -> bool { | ||
if (auto cst = dyn_cast<AffineConstantExpr>(expr)) { | ||
constant = cst.getValue(); | ||
return true; | ||
} | ||
return false; | ||
}; | ||
// Match `s0 * cst` or `cst * s0`: | ||
banach-space marked this conversation as resolved.
Show resolved
Hide resolved
|
||
int64_t cst = 0; | ||
auto lhs = binop.getLHS(); | ||
auto rhs = binop.getRHS(); | ||
if ((matchConstant(lhs, cst) && isa<AffineSymbolExpr>(rhs)) || | ||
(matchConstant(rhs, cst) && isa<AffineSymbolExpr>(lhs))) { | ||
return BoundSize{cst, /*scalable=*/true}; | ||
} | ||
return failure(); | ||
} | ||
|
||
char ScalableValueBoundsConstraintSet::ID = 0; | ||
|
||
FailureOr<ConstantOrScalableBound> | ||
ScalableValueBoundsConstraintSet::computeScalableBound( | ||
Value value, std::optional<int64_t> dim, unsigned vscaleMin, | ||
unsigned vscaleMax, presburger::BoundType boundType, bool closedUB, | ||
StopConditionFn stopCondition) { | ||
using namespace presburger; | ||
|
||
assert(vscaleMin <= vscaleMax); | ||
ScalableValueBoundsConstraintSet scalableCstr(value.getContext(), vscaleMin, | ||
vscaleMax); | ||
|
||
int64_t pos = scalableCstr.populateConstraintsSet(value, dim, stopCondition); | ||
|
||
// Project out all variables apart from vscale. | ||
// This should result in constraints in terms of vscale only. | ||
scalableCstr.projectOut( | ||
[&](ValueDim p) { return p.first != scalableCstr.getVscaleValue(); }); | ||
|
||
assert(scalableCstr.cstr.getNumDimAndSymbolVars() == | ||
scalableCstr.positionToValueDim.size() && | ||
"inconsistent mapping state"); | ||
|
||
// Check that the only symbols left are vscale. | ||
for (int64_t i = 0; i < scalableCstr.cstr.getNumDimAndSymbolVars(); ++i) { | ||
if (i == pos) | ||
continue; | ||
if (scalableCstr.positionToValueDim[i] != | ||
ValueDim(scalableCstr.getVscaleValue(), | ||
ValueBoundsConstraintSet::kIndexValue)) { | ||
return failure(); | ||
} | ||
} | ||
|
||
SmallVector<AffineMap, 1> lowerBound(1), upperBound(1); | ||
scalableCstr.cstr.getSliceBounds(pos, 1, value.getContext(), &lowerBound, | ||
&upperBound, closedUB); | ||
|
||
auto invalidBound = [](auto &bound) { | ||
return !bound[0] || bound[0].getNumResults() != 1; | ||
}; | ||
|
||
AffineMap bound = [&] { | ||
if (boundType == BoundType::EQ && !invalidBound(lowerBound) && | ||
lowerBound[0] == lowerBound[0]) { | ||
return lowerBound[0]; | ||
} else if (boundType == BoundType::LB && !invalidBound(lowerBound)) { | ||
return lowerBound[0]; | ||
} else if (boundType == BoundType::UB && !invalidBound(upperBound)) { | ||
return upperBound[0]; | ||
} | ||
return AffineMap{}; | ||
}(); | ||
|
||
if (!bound) | ||
return failure(); | ||
|
||
return ConstantOrScalableBound{bound}; | ||
} | ||
|
||
} // namespace mlir::vector |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
//===- ValueBoundsOpInterfaceImpl.cpp - Impl. of ValueBoundsOpInterface ---===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "mlir/Dialect/Vector/IR/ValueBoundsOpInterfaceImpl.h" | ||
|
||
#include "mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h" | ||
#include "mlir/Dialect/Vector/IR/VectorOps.h" | ||
#include "mlir/Interfaces/ValueBoundsOpInterface.h" | ||
|
||
using namespace mlir; | ||
|
||
namespace mlir::vector { | ||
namespace { | ||
|
||
struct VectorScaleOpInterface | ||
: public ValueBoundsOpInterface::ExternalModel<VectorScaleOpInterface, | ||
VectorScaleOp> { | ||
void populateBoundsForIndexValue(Operation *op, Value value, | ||
ValueBoundsConstraintSet &cstr) const { | ||
auto *scalableCstr = dyn_cast<ScalableValueBoundsConstraintSet>(&cstr); | ||
if (!scalableCstr) | ||
return; | ||
auto vscaleOp = cast<VectorScaleOp>(op); | ||
assert(value == vscaleOp.getResult() && "invalid value"); | ||
if (auto vscale = scalableCstr->getVscaleValue()) { | ||
// All copies of vscale are equivalent. | ||
scalableCstr->bound(value) == cstr.getExpr(vscale); | ||
} else { | ||
// We know vscale is confined to [vscaleMin, vscaleMax]. | ||
scalableCstr->bound(value) >= scalableCstr->getVscaleMin(); | ||
scalableCstr->bound(value) <= scalableCstr->getVscaleMax(); | ||
scalableCstr->setVscale(vscaleOp); | ||
} | ||
} | ||
}; | ||
|
||
} // namespace | ||
} // namespace mlir::vector | ||
|
||
void mlir::vector::registerValueBoundsOpInterfaceExternalModels( | ||
DialectRegistry ®istry) { | ||
registry.addExtension(+[](MLIRContext *ctx, vector::VectorDialect *dialect) { | ||
vector::VectorScaleOp::attachInterface<vector::VectorScaleOpInterface>( | ||
*ctx); | ||
}); | ||
} |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[nit] vscale ->
vector.vscale
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here I'm referring to the concept of
vscale
, not the vector dialect operation.