
[mlir][sparse] simplify some header code #70989

Merged: 1 commit, Nov 2, 2023
65 changes: 12 additions & 53 deletions mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -22,39 +22,24 @@

//===----------------------------------------------------------------------===//
//
// Type aliases to help code be more self-documenting. Unfortunately
// these are not type-checked, so they only provide documentation rather
// than doing anything to prevent mixups.
//
-// We must include these here (rather than in "SparseTensorType.h")
-// because they are used by methods declared in the tablegen files.
-//
//===----------------------------------------------------------------------===//

namespace mlir {
namespace sparse_tensor {

-/// The type of dimension identifiers, and dimension-ranks. We use the
-/// same type for both identifiers and ranks because the latter are used
-/// mainly for ordering-comparisons against the former (just like how the
-/// one-past-the-end iterators are used).
+/// The type of dimension identifiers and dimension-ranks.
using Dimension = uint64_t;

-/// The type of level identifiers, and level-ranks. We use the same
-/// type for both identifiers and ranks because the latter are used
-/// mainly for ordering-comparisons against the former (just like how
-/// the one-past-the-end iterators are used).
+/// The type of level identifiers and level-ranks.
using Level = uint64_t;

-/// The type for individual components of a compile-time shape. We avoid
-/// calling this "size" because we use the term "sizes" to indicate the
-/// actual run-time sizes, whereas this type also allows the value
-/// `ShapedType::kDynamic`.
-using DynSize = int64_t;
-
-/// The type for individual components of a compile-time shape which
-/// are known not to be `ShapedType::kDynamic`.
-using StaticSize = int64_t;
+/// The type for individual components of a compile-time shape,
+/// including the value `ShapedType::kDynamic` (for shapes).
+using Size = int64_t;

} // namespace sparse_tensor
} // namespace mlir
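Reviewer note on the consolidated alias: a `Size` value either holds a static extent or the `ShapedType::kDynamic` sentinel, so the old `DynSize`/`StaticSize` split carried no type-level guarantee anyway. A minimal sketch of the convention, with a hypothetical `staticExtent` helper that mirrors the `getStaticDimSize` method this patch removes:

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

#include "mlir/IR/BuiltinTypes.h"

using Size = int64_t;

// Hypothetical helper (not part of the patch): maps the kDynamic
// sentinel to std::nullopt, like the removed getStaticDimSize did.
static std::optional<Size> staticExtent(Size sh) {
  return mlir::ShapedType::isDynamic(sh) ? std::nullopt
                                         : std::optional<Size>(sh);
}

// A shape such as tensor<8x?xf64> yields the extents {8, kDynamic}.
static void demo() {
  assert(staticExtent(8) == std::optional<Size>(8));
  assert(!staticExtent(mlir::ShapedType::kDynamic).has_value());
}
```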
@@ -63,9 +48,6 @@ using StaticSize = int64_t;
// TableGen-defined classes
//===----------------------------------------------------------------------===//

-// We must include Enums.h.inc before AttrDefs.h.inc due to dependency between
-// StorageSpecifierKindAttr and StorageSpeciferKind Enum.

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrEnums.h.inc"

@@ -87,11 +69,6 @@ using StaticSize = int64_t;
namespace mlir {
namespace sparse_tensor {

-// NOTE: `Value::getType` doesn't check for null before trying to
-// dereference things. Therefore we check, because an assertion-failure
-// is easier to debug than a segfault. Presumably other `T::getType`
-// methods are similarly susceptible.

/// Convenience method to abbreviate casting `getType()`.
template <typename T>
inline RankedTensorType getRankedTensorType(T &&t) {
@@ -192,33 +169,15 @@ bool isBlockSparsity(AffineMap dimToLvl);
// Reordering.
//

-// This CPP guard is to disable deprecation warnings for the LLVM
-// build-bot, while making it easy to re-enable it for local development.
-#if 0
-#define DEPRECATED \
-  LLVM_DEPRECATED("The toOrigDim/toStoredDim functions are deprecated " \
-                  "because they only work for permutations; therefore any " \
-                  "code using them cannot support non-permutations.", \
-                  "")
-#else
-#define DEPRECATED
-#endif

/// [deprecated] Convenience method to translate the given level to the
-/// corresponding dimension. Requires: `0 <= l < lvlRank`.
-DEPRECATED Dimension toOrigDim(SparseTensorEncodingAttr enc, Level l);
-DEPRECATED Dimension toOrigDim(RankedTensorType type, Level l);
+/// corresponding dimension. Requires: `0 <= l < lvlRank`.
+Dimension toOrigDim(SparseTensorEncodingAttr enc, Level l);
+Dimension toOrigDim(RankedTensorType type, Level l);

/// [deprecated] Convenience method to translate the given dimension to
-/// the corresponding level. Requires: `0 <= d < dimRank`.
-DEPRECATED Level toStoredDim(SparseTensorEncodingAttr enc, Dimension d);
-DEPRECATED Level toStoredDim(RankedTensorType type, Dimension d);
-
-#undef DEPRECATED
-
-namespace detail {
-Type getIntegerOrIndexType(MLIRContext *ctx, unsigned bitwidth);
-} // namespace detail
+/// the corresponding level. Requires: `0 <= d < dimRank`.
+Level toStoredDim(SparseTensorEncodingAttr enc, Dimension d);
+Level toStoredDim(RankedTensorType type, Dimension d);

} // namespace sparse_tensor
} // namespace mlir
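For reviewers: these helpers remain permutation-only even with the deprecation macro gone. A usage sketch under an assumed CSC-style encoding (the `(d0, d1) -> (d1, d0)` map and the function below are illustrative, not from the patch):

```cpp
#include <cassert>

// Sketch only: assumes `enc` has the permutation dimToLvl map
// (d0, d1) -> (d1, d0), i.e. column-major (CSC-like) storage.
void illustrateTranslation(mlir::sparse_tensor::SparseTensorEncodingAttr enc) {
  using namespace mlir::sparse_tensor;
  assert(enc.isPermutation() && "translation only defined for permutations");
  // Level 0 stores original dimension 1, and vice versa.
  assert(toOrigDim(enc, 0) == 1);
  assert(toStoredDim(enc, 1) == 0);
  // On permutations the two translations are mutually inverse.
  for (Level l = 0, e = enc.getLvlRank(); l < e; ++l)
    assert(toStoredDim(enc, toOrigDim(enc, l)) == l);
}
```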
mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td
@@ -403,20 +403,6 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
/// always have the identity mapping).
bool isPermutation() const;

-//
-// posWidth/crdWidth methods.
-//
-
-/// Returns the type for position storage based on posWidth.
-/// Asserts that the encoding is non-null (since there's nowhere
-/// to get the `MLIRContext` from).
-Type getPosType() const;
-
-/// Returns the type for coordinate storage based on crdWidth.
-/// Asserts that the encoding is non-null (since there's nowhere
-/// to get the `MLIRContext` from).
-Type getCrdType() const;

//
// dimSlices methods.
//
@@ -571,5 +557,4 @@ def SparseTensorCrdTransDirectionAttr
"CrdTransDirection"> {
}


#endif // SPARSETENSOR_ATTRDEFS
38 changes: 15 additions & 23 deletions mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h
@@ -60,10 +60,7 @@ class SparseTensorType
: SparseTensorType(
RankedTensorType::get(stp.getShape(), stp.getElementType(), enc)) {}

-// Copy-assignment would be implicitly deleted (because our fields
-// are const), so we explicitly delete it for clarity.
SparseTensorType &operator=(const SparseTensorType &) = delete;
-// So we must explicitly define the copy-ctor to silence -Wdeprecated-copy.
SparseTensorType(const SparseTensorType &) = default;
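Background for the two comments deleted above (general C++, not MLIR-specific): a class whose data members are all `const` gets a deleted implicit copy-assignment, and once assignment is user-declared, relying on the implicit copy constructor is deprecated under `-Wdeprecated-copy`, hence the explicit `= default`. A generic sketch:

```cpp
// Generic illustration (hypothetical class, not from the patch).
struct Wrapper {
  const int field; // const member: implicit copy-assignment is ill-formed
  explicit Wrapper(int f) : field(f) {}
  Wrapper(const Wrapper &) = default;           // silences -Wdeprecated-copy
  Wrapper &operator=(const Wrapper &) = delete; // spell out the deletion
};
```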

//
@@ -243,10 +240,10 @@ class SparseTensorType
Level getLvlRank() const { return lvlRank; }

/// Returns the dimension-shape.
-ArrayRef<DynSize> getDimShape() const { return rtp.getShape(); }
+ArrayRef<Size> getDimShape() const { return rtp.getShape(); }

/// Returns the Level-shape.
-SmallVector<DynSize> getLvlShape() const {
+SmallVector<Size> getLvlShape() const {
return getEncoding().tranlateShape(getDimShape(),
CrdTransDirectionKind::dim2lvl);
}
@@ -260,19 +257,11 @@
/// Safely looks up the requested dimension-DynSize. If you intend
/// to check the result with `ShapedType::isDynamic`, then see the
/// `getStaticDimSize` method instead.
-DynSize getDynamicDimSize(Dimension d) const {
+Size getDynamicDimSize(Dimension d) const {
assert(d < getDimRank() && "Dimension is out of bounds");
return getDimShape()[d];
}

-/// Safely looks up the requested dimension-size, mapping dynamic
-/// sizes to `std::nullopt`.
-std::optional<StaticSize> getStaticDimSize(Dimension d) const {
-  const DynSize sh = getDynamicDimSize(d);
-  return ShapedType::isDynamic(sh) ? std::nullopt
-                                   : std::optional<StaticSize>(sh);
-}

/// Returns true if no dimension has dynamic size.
bool hasStaticDimShape() const { return rtp.hasStaticShape(); }

@@ -318,12 +307,16 @@ class SparseTensorType

/// Returns the coordinate-overhead MLIR type, defaulting to `IndexType`.
Type getCrdType() const {
-return detail::getIntegerOrIndexType(getContext(), getCrdWidth());
+if (getCrdWidth())
+  return IntegerType::get(getContext(), getCrdWidth());
+return IndexType::get(getContext());
}

/// Returns the position-overhead MLIR type, defaulting to `IndexType`.
Type getPosType() const {
-return detail::getIntegerOrIndexType(getContext(), getPosWidth());
+if (getPosWidth())
+  return IntegerType::get(getContext(), getPosWidth());
+return IndexType::get(getContext());
}
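With `detail::getIntegerOrIndexType` gone, the convention lives inline above: a zero `posWidth`/`crdWidth` selects the platform `index` type, while a nonzero width selects a fixed-width integer. Expressed as a standalone sketch (the free function and its name are illustrative only):

```cpp
// Illustrative free function mirroring the now-inlined logic.
static mlir::Type overheadType(mlir::MLIRContext *ctx, unsigned width) {
  if (width)
    return mlir::IntegerType::get(ctx, width); // e.g. width 32 -> i32
  return mlir::IndexType::get(ctx);            // width 0 -> index
}
```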

private:
@@ -336,14 +329,13 @@
const AffineMap lvlToDim;
};

-/// Convenience methods to abbreviate wrapping `getRankedTensorType`.
-template <typename T>
-inline SparseTensorType getSparseTensorType(T t) {
-  return SparseTensorType(getRankedTensorType(t));
+/// Convenience methods to obtain a SparseTensorType from a Value.
+inline SparseTensorType getSparseTensorType(Value val) {
+  return SparseTensorType(cast<RankedTensorType>(val.getType()));
}
-inline std::optional<SparseTensorType> tryGetSparseTensorType(Value v) {
-  if (isa<RankedTensorType>(v.getType()))
-    return getSparseTensorType(v);
+inline std::optional<SparseTensorType> tryGetSparseTensorType(Value val) {
+  if (auto rtp = dyn_cast<RankedTensorType>(val.getType()))
+    return SparseTensorType(rtp);
return std::nullopt;
}
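A plausible call-site pattern for the reworked helpers (caller code assumed, not part of the diff): use `tryGetSparseTensorType` when the value may not be a ranked tensor at all, and the unchecked variant only where that invariant already holds:

```cpp
#include <optional>

// Illustrative caller (assumed surrounding pass code).
void visitValue(mlir::Value val) {
  using mlir::sparse_tensor::SparseTensorType;
  using mlir::sparse_tensor::tryGetSparseTensorType;
  if (std::optional<SparseTensorType> stt = tryGetSparseTensorType(val)) {
    if (stt->hasEncoding()) {
      // ... sparse-specific handling ...
    }
    return;
  }
  // `val` is not a ranked tensor; fall back to the dense path.
}
```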

35 changes: 9 additions & 26 deletions mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -270,23 +270,6 @@ SparseTensorDimSliceAttr::verify(function_ref<InFlightDiagnostic()> emitError,
return success();
}

-Type mlir::sparse_tensor::detail::getIntegerOrIndexType(MLIRContext *ctx,
-                                                        unsigned bitwidth) {
-  if (bitwidth)
-    return IntegerType::get(ctx, bitwidth);
-  return IndexType::get(ctx);
-}
-
-Type SparseTensorEncodingAttr::getPosType() const {
-  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
-  return detail::getIntegerOrIndexType(getContext(), getPosWidth());
-}
-
-Type SparseTensorEncodingAttr::getCrdType() const {
-  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
-  return detail::getIntegerOrIndexType(getContext(), getCrdWidth());
-}

SparseTensorEncodingAttr
SparseTensorEncodingAttr::withDimToLvl(AffineMap dimToLvl) const {
assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
@@ -722,7 +705,7 @@ SparseTensorEncodingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
}

LogicalResult SparseTensorEncodingAttr::verifyEncoding(
ArrayRef<DynSize> dimShape, Type elementType,
ArrayRef<Size> dimShape, Type elementType,
function_ref<InFlightDiagnostic()> emitError) const {
// Check structural integrity. In particular, this ensures that the
// level-rank is coherent across all the fields.
@@ -1312,7 +1295,7 @@ OpFoldResult LvlOp::fold(FoldAdaptor adaptor) {

// TODO: we can remove this after SparseTensorEncoding always returns non-null
// dimToLvl map.
ArrayRef<DynSize> shape = stt.getDimShape();
ArrayRef<Size> shape = stt.getDimShape();
if (stt.isPermutation()) {
Dimension dim = toOrigDim(stt, lvl);
if (!ShapedType::isDynamic(shape[dim])) {
@@ -1378,8 +1361,8 @@ LogicalResult ReinterpretMapOp::verify() {
if (srcStt.getElementType() != dstStt.getElementType())
return emitError("Element type mismatch between source/dest tensors");

-SmallVector<DynSize> srcLvlShape = srcStt.getLvlShape();
-SmallVector<DynSize> dstLvlShape = dstStt.getLvlShape();
+SmallVector<Size> srcLvlShape = srcStt.getLvlShape();
+SmallVector<Size> dstLvlShape = dstStt.getLvlShape();
for (auto [srcLvlSz, dstLvlSz] : llvm::zip(srcLvlShape, dstLvlShape)) {
if (srcLvlSz != dstLvlSz) {
// Should we allow one side to be dynamic size, e.g., <?x?> should be
Expand Down Expand Up @@ -1616,13 +1599,13 @@ LogicalResult ConcatenateOp::verify() {
}

for (Dimension d = 0; d < dimRank; d++) {
-const DynSize dstSh = dstTp.getDimShape()[d];
+const Size dstSh = dstTp.getDimShape()[d];
if (d == concatDim) {
if (!ShapedType::isDynamic(dstSh)) {
// If we reach here, then all inputs have static shapes. So we
// can use `getDimShape()[d]` instead of `*getDynamicDimSize(d)`
// to avoid redundant assertions in the loop.
-StaticSize sumSz = 0;
+Size sumSz = 0;
for (const auto src : getInputs())
sumSz += getSparseTensorType(src).getDimShape()[d];
// If all dimension are statically known, the sum of all the input
@@ -1633,7 +1616,7 @@
"sum of all the concatenation dimensions of the input tensors.");
}
} else {
-DynSize prev = dstSh;
+Size prev = dstSh;
for (const auto src : getInputs()) {
const auto sh = getSparseTensorType(src).getDimShape()[d];
if (!ShapedType::isDynamic(prev) && sh != prev)
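A concrete instance of the two checks above (all numbers illustrative): concatenating `2x4` and `3x4` inputs along dimension 0 must produce a `5x4` result, and on every other dimension the inputs and the result must agree:

```cpp
#include <cassert>
#include <cstdint>
#include <initializer_list>

using Size = int64_t;

// Illustrative arithmetic for the fully static case of the checks above:
// inputs 2x4 and 3x4 concatenated along dim 0.
void demoConcatCheck() {
  const std::initializer_list<Size> concatExtents = {2, 3};
  Size sumSz = 0;
  for (Size inSz : concatExtents)
    sumSz += inSz;
  assert(sumSz == 5); // the result's dim-0 extent must equal the sum
  // Every non-concatenated dimension (here dim 1 == 4) must match
  // across all inputs and the result.
}
```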
@@ -1808,8 +1791,8 @@ LogicalResult SortOp::verify() {
// FIXME: update the types of variables used in expressions bassed as
// the `minSize` argument, to avoid implicit casting at the callsites
// of this lambda.
-const auto checkDim = [&](Value v, StaticSize minSize, const char *message) {
-  const DynSize sh = getMemRefType(v).getShape()[0];
+const auto checkDim = [&](Value v, Size minSize, const char *message) {
+  const Size sh = getMemRefType(v).getShape()[0];
if (!ShapedType::isDynamic(sh) && sh < minSize)
emitError(llvm::formatv("{0} got {1} < {2}", message, sh, minSize));
};
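The FIXME above concerns call sites of this lambda; an illustrative invocation (the operand names and the message string are assumed, not taken from the patch), where the second argument is an integer expression whose type may be narrower than `Size` and is implicitly widened at the call:

```cpp
// Illustrative call site (assumed names): minSize arrives as an
// expression that is implicitly converted to Size, hence the FIXME.
checkDim(xs, n * nx, "Expected dimension(xs) >= n * nx");
```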
23 changes: 11 additions & 12 deletions mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
@@ -51,8 +51,6 @@ OverheadType mlir::sparse_tensor::overheadTypeEncoding(Type tp) {
llvm_unreachable("Unknown overhead type");
}

-// TODO: should offer an overload of this that takes a `MLIRContext*`
-// instead of the builder, similar to `detail::getIntegerOrIndexType`.
Type mlir::sparse_tensor::getOverheadType(Builder &builder, OverheadType ot) {
switch (ot) {
case OverheadType::kIndex:
@@ -209,7 +207,7 @@ Value mlir::sparse_tensor::genIsNonzero(OpBuilder &builder, mlir::Location loc,

void mlir::sparse_tensor::genReshapeDstShape(
OpBuilder &builder, Location loc, SmallVectorImpl<Value> &dstShape,
-    ArrayRef<Value> srcShape, ArrayRef<StaticSize> staticDstShape,
+    ArrayRef<Value> srcShape, ArrayRef<Size> staticDstShape,
ArrayRef<ReassociationIndices> reassociation) {
// Collapse shape.
if (reassociation.size() < srcShape.size()) {
@@ -242,7 +240,7 @@ void mlir::sparse_tensor::genReshapeDstShape(
if (staticDstShape[j] == ShapedType::kDynamic) {
// The expanded dimension has dynamic size. We compute the dimension
// by dividing srcDim by the product of the static dimensions.
-StaticSize product = 1;
+Size product = 1;
for (unsigned k = start; k < start + map.size(); k++) {
if (staticDstShape[k] != ShapedType::kDynamic) {
product *= staticDstShape[k];
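Worked example for the division described above (values illustrative): expanding a source dimension of size 12 into a group with extents `{?, 3, 2}` resolves the dynamic extent to `12 / (3 * 2) = 2`:

```cpp
#include <cassert>
#include <cstdint>

// Illustrative values for the dynamic-extent computation above.
int64_t srcDim = 12;     // runtime size of the collapsed source dimension
int64_t product = 3 * 2; // product of the group's static extents
// The source size must be evenly divisible by the static product.
// assert(srcDim % product == 0);
int64_t dynamicExtent = srcDim / product; // == 2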
@@ -423,7 +421,7 @@ Operation *mlir::sparse_tensor::getTop(Operation *op) {
void sparse_tensor::foreachInSparseConstant(
OpBuilder &builder, Location loc, SparseElementsAttr attr, AffineMap order,
function_ref<void(ArrayRef<Value>, Value)> callback) {
-const Dimension dimRank = getSparseTensorType(attr).getDimRank();
+const Dimension dimRank =
+    SparseTensorType(getRankedTensorType(attr)).getDimRank();
const auto coordinates = attr.getIndices().getValues<IntegerAttr>();
const auto values = attr.getValues().getValues<Attribute>();

@@ -494,8 +493,8 @@ SmallVector<Value> sparse_tensor::loadAll(OpBuilder &builder, Location loc,
#ifndef NDEBUG
const auto memTp = cast<MemRefType>(mem.getType());
assert(memTp.getRank() == 1);
-const DynSize memSh = memTp.getDimSize(0);
-assert(ShapedType::isDynamic(memSh) || memSh >= static_cast<DynSize>(size));
+const Size memSh = memTp.getDimSize(0);
+assert(ShapedType::isDynamic(memSh) || memSh >= static_cast<Size>(size));
assert(offsetIdx == 0 || offsetIdx < size);
#endif // NDEBUG
SmallVector<Value> vs;
@@ -516,8 +515,8 @@ void sparse_tensor::storeAll(OpBuilder &builder, Location loc, Value mem,
const size_t vsize = vs.size();
const auto memTp = cast<MemRefType>(mem.getType());
assert(memTp.getRank() == 1);
-const DynSize memSh = memTp.getDimSize(0);
-assert(ShapedType::isDynamic(memSh) || memSh >= static_cast<DynSize>(vsize));
+const Size memSh = memTp.getDimSize(0);
+assert(ShapedType::isDynamic(memSh) || memSh >= static_cast<Size>(vsize));
assert(offsetIdx == 0 || offsetIdx < vsize);
#endif // NDEBUG
for (const auto &v : llvm::enumerate(vs)) {
@@ -546,11 +545,11 @@ Value sparse_tensor::reshapeValuesToLevels(OpBuilder &builder, Location loc,
// The memref ReshapeOp requires the sizes buffer to have a static
// shape.
const auto iTp = builder.getIndexType();
-const SmallVector<DynSize, 1> lvlSizesShape{static_cast<DynSize>(lvlRank)};
+const SmallVector<Size, 1> lvlSizesShape{static_cast<Size>(lvlRank)};
const auto lvlSizesTp = MemRefType::get(lvlSizesShape, iTp);
lvlCoords = builder.create<memref::CastOp>(loc, lvlSizesTp, lvlCoords);
// Finally, create the ReshapeOp.
-const SmallVector<DynSize> resShape(lvlRank, ShapedType::kDynamic);
+const SmallVector<Size> resShape(lvlRank, ShapedType::kDynamic);
const Type elemTp = getMemRefType(valuesBuffer).getElementType();
const auto resTp = MemRefType::get(resShape, elemTp);
return builder.create<memref::ReshapeOp>(loc, resTp, valuesBuffer, lvlCoords);
@@ -628,7 +627,7 @@ void sparse_tensor::fillDimShape(OpBuilder &builder, Location loc,
SmallVectorImpl<Value> &out) {
out.clear();
out.reserve(stt.getDimRank());
-for (const DynSize sh : stt.getDimShape()) {
+for (const Size sh : stt.getDimShape()) {
const auto s = ShapedType::isDynamic(sh) ? 0 : sh;
out.push_back(constantIndex(builder, loc, s));
}