Skip to content

[mlir][Vector] Add utility for computing scalable value bounds #83876

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
//===- ScalableValueBoundsConstraintSet.h - Scalable Value Bounds ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_VECTOR_IR_SCALABLEVALUEBOUNDSCONSTRAINTSET_H
#define MLIR_DIALECT_VECTOR_IR_SCALABLEVALUEBOUNDSCONSTRAINTSET_H

#include "mlir/Analysis/Presburger/IntegerRelation.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"

namespace mlir::vector {

namespace detail {

/// Parent class for the value bounds RTTIExtends. Uses protected inheritance to
/// hide all ValueBoundsConstraintSet methods by default (as some do not use the
/// ScalableValueBoundsConstraintSet, so may produce unexpected results).
struct ValueBoundsConstraintSet : protected ::mlir::ValueBoundsConstraintSet {
using ::mlir::ValueBoundsConstraintSet::ValueBoundsConstraintSet;
};
} // namespace detail

/// A version of `ValueBoundsConstraintSet` that can solve for scalable bounds.
struct ScalableValueBoundsConstraintSet
: public llvm::RTTIExtends<ScalableValueBoundsConstraintSet,
detail::ValueBoundsConstraintSet> {
ScalableValueBoundsConstraintSet(MLIRContext *context, unsigned vscaleMin,
unsigned vscaleMax)
: RTTIExtends(context), vscaleMin(vscaleMin), vscaleMax(vscaleMax){};

using RTTIExtends::bound;
using RTTIExtends::StopConditionFn;

/// A thin wrapper over an `AffineMap` which can represent a constant bound,
/// or a scalable bound (in terms of vscale). The `AffineMap` will always
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nit] vscale -> vector.vscale

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here I'm referring to the concept of vscale, not the vector dialect operation.

/// take at most one parameter, vscale, and returns a single result, which is
/// the bound of value.
struct ConstantOrScalableBound {
AffineMap map;

struct BoundSize {
int64_t baseSize{0};
bool scalable{false};
};

/// Get the (possibly) scalable size of the bound, returns failure if
/// the bound cannot be represented as a single quantity.
FailureOr<BoundSize> getSize() const;
};

/// Computes a (possibly) scalable bound for a given value. This is
/// similar to `ValueBoundsConstraintSet::computeConstantBound()`, but
/// uses knowledge of the range of vscale to compute either a constant
/// bound, an expression in terms of vscale, or failure if no bound can
/// be computed.
///
/// The resulting `AffineMap` will always take at most one parameter,
/// vscale, and return a single result, which is the bound of `value`.
///
/// Note: `vscaleMin` must be `<=` to `vscaleMax`. If `vscaleMin` ==
/// `vscaleMax`, the resulting bound (if found), will be constant.
static FailureOr<ConstantOrScalableBound>
computeScalableBound(Value value, std::optional<int64_t> dim,
unsigned vscaleMin, unsigned vscaleMax,
presburger::BoundType boundType, bool closedUB = true,
StopConditionFn stopCondition = nullptr);

/// Get the value of vscale. Returns `nullptr` vscale as not been encountered.
Value getVscaleValue() const { return vscale; }

/// Sets the value of vscale. Asserts if vscale has already been set.
void setVscale(vector::VectorScaleOp vscaleOp) {
assert(!vscale && "expected vscale to be unset");
vscale = vscaleOp.getResult();
}

/// The minimum possible value of vscale.
unsigned getVscaleMin() const { return vscaleMin; }

/// The maximum possible value of vscale.
unsigned getVscaleMax() const { return vscaleMax; }

static char ID;

private:
const unsigned vscaleMin;
const unsigned vscaleMax;

// This will be set when the first `vector.vscale` operation is found within
// the `ValueBoundsOpInterface` implementation then reused from there on.
Value vscale = nullptr;
};

using ConstantOrScalableBound =
ScalableValueBoundsConstraintSet::ConstantOrScalableBound;

} // namespace mlir::vector

#endif // MLIR_DIALECT_VECTOR_IR_SCALABLEVALUEBOUNDSCONSTRAINTSET_H
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
//===- ValueBoundsOpInterfaceImpl.h - Impl. of ValueBoundsOpInterface -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_VECTOR_IR_VALUEBOUNDSOPINTERFACEIMPL_H
#define MLIR_DIALECT_VECTOR_IR_VALUEBOUNDSOPINTERFACEIMPL_H

namespace mlir {
class DialectRegistry;

namespace vector {
void registerValueBoundsOpInterfaceExternalModels(DialectRegistry &registry);
} // namespace vector
} // namespace mlir

#endif // MLIR_DIALECT_VECTOR_IR_VALUEBOUNDSOPINTERFACEIMPL_H
2 changes: 2 additions & 0 deletions mlir/include/mlir/InitAllDialects.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/PDLExtension/PDLExtension.h"
#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/Dialect/Vector/IR/ValueBoundsOpInterfaceImpl.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Vector/Transforms/SubsetOpInterfaceImpl.h"
Expand Down Expand Up @@ -174,6 +175,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
tosa::registerShardingInterfaceExternalModels(registry);
vector::registerBufferizableOpInterfaceExternalModels(registry);
vector::registerSubsetOpInterfaceExternalModels(registry);
vector::registerValueBoundsOpInterfaceExternalModels(registry);
NVVM::registerNVVMTargetInterfaceExternalModels(registry);
ROCDL::registerROCDLTargetInterfaceExternalModels(registry);
spirv::registerSPIRVTargetInterfaceExternalModels(registry);
Expand Down
16 changes: 15 additions & 1 deletion mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/ExtensibleRTTI.h"

#include <queue>

Expand Down Expand Up @@ -63,7 +64,8 @@ using ValueDimList = SmallVector<std::pair<Value, std::optional<int64_t>>>;
///
/// Note: Any modification of existing IR invalides the data stored in this
/// class. Adding new operations is allowed.
class ValueBoundsConstraintSet {
class ValueBoundsConstraintSet
: public llvm::RTTIExtends<ValueBoundsConstraintSet, llvm::RTTIRoot> {
protected:
/// Helper class that builds a bound for a shaped value dimension or
/// index-typed value.
Expand Down Expand Up @@ -107,6 +109,8 @@ class ValueBoundsConstraintSet {
};

public:
static char ID;

/// The stop condition when traversing the backward slice of a shaped value/
/// index-type value. The traversal continues until the stop condition
/// evaluates to "true" for a value.
Expand Down Expand Up @@ -265,6 +269,16 @@ class ValueBoundsConstraintSet {

ValueBoundsConstraintSet(MLIRContext *ctx);

/// Populates the constraint set for a value/map without actually computing
/// the bound. Returns the position for the value/map (via the return value
/// and `posOut` output parameter).
int64_t populateConstraintsSet(Value value,
std::optional<int64_t> dim = std::nullopt,
StopConditionFn stopCondition = nullptr);
int64_t populateConstraintsSet(AffineMap map, ValueDimList mapOperands,
StopConditionFn stopCondition = nullptr,
int64_t *posOut = nullptr);

/// Iteratively process all elements on the worklist until an index-typed
/// value or shaped value meets `stopCondition`. Such values are not processed
/// any further.
Expand Down
2 changes: 2 additions & 0 deletions mlir/lib/Dialect/Vector/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
add_mlir_dialect_library(MLIRVectorDialect
VectorOps.cpp
ValueBoundsOpInterfaceImpl.cpp
ScalableValueBoundsConstraintSet.cpp

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Vector/IR
Expand Down
103 changes: 103 additions & 0 deletions mlir/lib/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
//===- ScalableValueBoundsConstraintSet.cpp - Scalable Value Bounds -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h"

#include "mlir/Dialect/Vector/IR/VectorOps.h"

namespace mlir::vector {

FailureOr<ConstantOrScalableBound::BoundSize>
ConstantOrScalableBound::getSize() const {
if (map.isSingleConstant())
return BoundSize{map.getSingleConstantResult(), /*scalable=*/false};
if (map.getNumResults() != 1 || map.getNumInputs() != 1)
return failure();
auto binop = dyn_cast<AffineBinaryOpExpr>(map.getResult(0));
if (!binop || binop.getKind() != AffineExprKind::Mul)
return failure();
auto matchConstant = [&](AffineExpr expr, int64_t &constant) -> bool {
if (auto cst = dyn_cast<AffineConstantExpr>(expr)) {
constant = cst.getValue();
return true;
}
return false;
};
// Match `s0 * cst` or `cst * s0`:
int64_t cst = 0;
auto lhs = binop.getLHS();
auto rhs = binop.getRHS();
if ((matchConstant(lhs, cst) && isa<AffineSymbolExpr>(rhs)) ||
(matchConstant(rhs, cst) && isa<AffineSymbolExpr>(lhs))) {
return BoundSize{cst, /*scalable=*/true};
}
return failure();
}

char ScalableValueBoundsConstraintSet::ID = 0;

FailureOr<ConstantOrScalableBound>
ScalableValueBoundsConstraintSet::computeScalableBound(
Value value, std::optional<int64_t> dim, unsigned vscaleMin,
unsigned vscaleMax, presburger::BoundType boundType, bool closedUB,
StopConditionFn stopCondition) {
using namespace presburger;

assert(vscaleMin <= vscaleMax);
ScalableValueBoundsConstraintSet scalableCstr(value.getContext(), vscaleMin,
vscaleMax);

int64_t pos = scalableCstr.populateConstraintsSet(value, dim, stopCondition);

// Project out all variables apart from vscale.
// This should result in constraints in terms of vscale only.
scalableCstr.projectOut(
[&](ValueDim p) { return p.first != scalableCstr.getVscaleValue(); });

assert(scalableCstr.cstr.getNumDimAndSymbolVars() ==
scalableCstr.positionToValueDim.size() &&
"inconsistent mapping state");

// Check that the only symbols left are vscale.
for (int64_t i = 0; i < scalableCstr.cstr.getNumDimAndSymbolVars(); ++i) {
if (i == pos)
continue;
if (scalableCstr.positionToValueDim[i] !=
ValueDim(scalableCstr.getVscaleValue(),
ValueBoundsConstraintSet::kIndexValue)) {
return failure();
}
}

SmallVector<AffineMap, 1> lowerBound(1), upperBound(1);
scalableCstr.cstr.getSliceBounds(pos, 1, value.getContext(), &lowerBound,
&upperBound, closedUB);

auto invalidBound = [](auto &bound) {
return !bound[0] || bound[0].getNumResults() != 1;
};

AffineMap bound = [&] {
if (boundType == BoundType::EQ && !invalidBound(lowerBound) &&
lowerBound[0] == lowerBound[0]) {
return lowerBound[0];
} else if (boundType == BoundType::LB && !invalidBound(lowerBound)) {
return lowerBound[0];
} else if (boundType == BoundType::UB && !invalidBound(upperBound)) {
return upperBound[0];
}
return AffineMap{};
}();

if (!bound)
return failure();

return ConstantOrScalableBound{bound};
}

} // namespace mlir::vector
51 changes: 51 additions & 0 deletions mlir/lib/Dialect/Vector/IR/ValueBoundsOpInterfaceImpl.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
//===- ValueBoundsOpInterfaceImpl.cpp - Impl. of ValueBoundsOpInterface ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Vector/IR/ValueBoundsOpInterfaceImpl.h"

#include "mlir/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"

using namespace mlir;

namespace mlir::vector {
namespace {

struct VectorScaleOpInterface
: public ValueBoundsOpInterface::ExternalModel<VectorScaleOpInterface,
VectorScaleOp> {
void populateBoundsForIndexValue(Operation *op, Value value,
ValueBoundsConstraintSet &cstr) const {
auto *scalableCstr = dyn_cast<ScalableValueBoundsConstraintSet>(&cstr);
if (!scalableCstr)
return;
auto vscaleOp = cast<VectorScaleOp>(op);
assert(value == vscaleOp.getResult() && "invalid value");
if (auto vscale = scalableCstr->getVscaleValue()) {
// All copies of vscale are equivalent.
scalableCstr->bound(value) == cstr.getExpr(vscale);
} else {
// We know vscale is confined to [vscaleMin, vscaleMax].
scalableCstr->bound(value) >= scalableCstr->getVscaleMin();
scalableCstr->bound(value) <= scalableCstr->getVscaleMax();
scalableCstr->setVscale(vscaleOp);
}
}
};

} // namespace
} // namespace mlir::vector

void mlir::vector::registerValueBoundsOpInterfaceExternalModels(
DialectRegistry &registry) {
registry.addExtension(+[](MLIRContext *ctx, vector::VectorDialect *dialect) {
vector::VectorScaleOp::attachInterface<vector::VectorScaleOpInterface>(
*ctx);
});
}
Loading