
Commit 22212ca

[mlir][sparse] simplify some header code (#70989)
This is the first revision in a small series of changes that removes duplication between direct encoding methods and the sparse tensor type wrapper methods (in favor of the latter abstraction, since it provides more safety). The goal is to end up with "just" SparseTensorType.
1 parent 4c41e7c commit 22212ca
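In a nutshell, call sites stop querying the encoding attribute directly and go through the SparseTensorType wrapper instead. A minimal before/after sketch of the direction of the change (the variables enc and rtp are illustrative, not taken from the commit):

    // Before: overhead types were queried on the encoding attribute itself.
    Type posTp = enc.getPosType();                   // removed by this commit
    // After: wrap the RankedTensorType first, then ask the (safer) wrapper.
    Type posTp = SparseTensorType(rtp).getPosType();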

9 files changed: +60 -142 lines changed


mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h

Lines changed: 12 additions & 53 deletions
@@ -22,39 +22,24 @@
 
 //===----------------------------------------------------------------------===//
 //
-// Type aliases to help code be more self-documenting.  Unfortunately
+// Type aliases to help code be more self-documenting. Unfortunately
 // these are not type-checked, so they only provide documentation rather
 // than doing anything to prevent mixups.
 //
-// We must include these here (rather than in "SparseTensorType.h")
-// because they are used by methods declared in the tablegen files.
-//
 //===----------------------------------------------------------------------===//
 
 namespace mlir {
 namespace sparse_tensor {
 
-/// The type of dimension identifiers, and dimension-ranks. We use the
-/// same type for both identifiers and ranks because the latter are used
-/// mainly for ordering-comparisons against the former (just like how the
-/// one-past-the-end iterators are used).
+/// The type of dimension identifiers and dimension-ranks.
 using Dimension = uint64_t;
 
-/// The type of level identifiers, and level-ranks. We use the same
-/// type for both identifiers and ranks because the latter are used
-/// mainly for ordering-comparisons against the former (just like how
-/// the one-past-the-end iterators are used).
+/// The type of level identifiers and level-ranks.
using Level = uint64_t;
 
-/// The type for individual components of a compile-time shape. We avoid
-/// calling this "size" because we use the term "sizes" to indicate the
-/// actual run-time sizes, whereas this type also allows the value
-/// `ShapedType::kDynamic`.
-using DynSize = int64_t;
-
-/// The type for individual components of a compile-time shape which
-/// are known not to be `ShapedType::kDynamic`.
-using StaticSize = int64_t;
+/// The type for individual components of a compile-time shape,
+/// including the value `ShapedType::kDynamic` (for shapes).
+using Size = int64_t;
 
 } // namespace sparse_tensor
 } // namespace mlir
@@ -63,9 +48,6 @@ using StaticSize = int64_t;
 // TableGen-defined classes
 //===----------------------------------------------------------------------===//
 
-// We must include Enums.h.inc before AttrDefs.h.inc due to dependency between
-// StorageSpecifierKindAttr and StorageSpeciferKind Enum.
-
 #define GET_ATTRDEF_CLASSES
 #include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrEnums.h.inc"
 
@@ -87,11 +69,6 @@ using StaticSize = int64_t;
 namespace mlir {
 namespace sparse_tensor {
 
-// NOTE: `Value::getType` doesn't check for null before trying to
-// dereference things. Therefore we check, because an assertion-failure
-// is easier to debug than a segfault. Presumably other `T::getType`
-// methods are similarly susceptible.
-
 /// Convenience method to abbreviate casting `getType()`.
 template <typename T>
 inline RankedTensorType getRankedTensorType(T &&t) {
@@ -192,33 +169,15 @@ bool isBlockSparsity(AffineMap dimToLvl);
 // Reordering.
 //
 
-// This CPP guard is to disable deprecation warnings for the LLVM
-// build-bot, while making it easy to re-enable it for local development.
-#if 0
-#define DEPRECATED                                                             \
-  LLVM_DEPRECATED("The toOrigDim/toStoredDim functions are deprecated "       \
-                  "because they only work for permutations; therefore any "   \
-                  "code using them cannot support non-permutations.",         \
-                  "")
-#else
-#define DEPRECATED
-#endif
-
 /// [deprecated] Convenience method to translate the given level to the
-/// corresponding dimension.  Requires: `0 <= l < lvlRank`.
-DEPRECATED Dimension toOrigDim(SparseTensorEncodingAttr enc, Level l);
-DEPRECATED Dimension toOrigDim(RankedTensorType type, Level l);
+/// corresponding dimension. Requires: `0 <= l < lvlRank`.
+Dimension toOrigDim(SparseTensorEncodingAttr enc, Level l);
+Dimension toOrigDim(RankedTensorType type, Level l);
 
 /// [deprecated] Convenience method to translate the given dimension to
-/// the corresponding level.  Requires: `0 <= d < dimRank`.
-DEPRECATED Level toStoredDim(SparseTensorEncodingAttr enc, Dimension d);
-DEPRECATED Level toStoredDim(RankedTensorType type, Dimension d);
-
-#undef DEPRECATED
-
-namespace detail {
-Type getIntegerOrIndexType(MLIRContext *ctx, unsigned bitwidth);
-} // namespace detail
+/// the corresponding level. Requires: `0 <= d < dimRank`.
+Level toStoredDim(SparseTensorEncodingAttr enc, Dimension d);
+Level toStoredDim(RankedTensorType type, Dimension d);
 
 } // namespace sparse_tensor
 } // namespace mlir
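Side note on the new Size alias above: DynSize and StaticSize were both int64_t, so merging them loses no information; call sites now distinguish static from dynamic extents via ShapedType::isDynamic. A hedged sketch of the resulting idiom (visitDimShape is a hypothetical helper, not part of the commit):

    // Walk a dimension-shape in which each Size may be ShapedType::kDynamic.
    void visitDimShape(ArrayRef<Size> dimShape) {
      for (const Size sh : dimShape) {
        if (ShapedType::isDynamic(sh))
          continue; // dynamic extent: size is known only at runtime
        // static extent: `sh` holds the concrete size
      }
    }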

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td

Lines changed: 0 additions & 15 deletions
@@ -403,20 +403,6 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
     /// always have the identity mapping).
     bool isPermutation() const;
 
-    //
-    // posWidth/crdWidth methods.
-    //
-
-    /// Returns the type for position storage based on posWidth.
-    /// Asserts that the encoding is non-null (since there's nowhere
-    /// to get the `MLIRContext` from).
-    Type getPosType() const;
-
-    /// Returns the type for coordinate storage based on crdWidth.
-    /// Asserts that the encoding is non-null (since there's nowhere
-    /// to get the `MLIRContext` from).
-    Type getCrdType() const;
-
     //
     // dimSlices methods.
     //
@@ -571,5 +557,4 @@ def SparseTensorCrdTransDirectionAttr
       "CrdTransDirection"> {
 }
 
-
 #endif // SPARSETENSOR_ATTRDEFS

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h

Lines changed: 15 additions & 23 deletions
@@ -60,10 +60,7 @@ class SparseTensorType {
       : SparseTensorType(
             RankedTensorType::get(stp.getShape(), stp.getElementType(), enc)) {}
 
-  // Copy-assignment would be implicitly deleted (because our fields
-  // are const), so we explicitly delete it for clarity.
   SparseTensorType &operator=(const SparseTensorType &) = delete;
-  // So we must explicitly define the copy-ctor to silence -Wdeprecated-copy.
   SparseTensorType(const SparseTensorType &) = default;
 
   //
@@ -243,10 +240,10 @@ class SparseTensorType {
   Level getLvlRank() const { return lvlRank; }
 
   /// Returns the dimension-shape.
-  ArrayRef<DynSize> getDimShape() const { return rtp.getShape(); }
+  ArrayRef<Size> getDimShape() const { return rtp.getShape(); }
 
   /// Returns the Level-shape.
-  SmallVector<DynSize> getLvlShape() const {
+  SmallVector<Size> getLvlShape() const {
     return getEncoding().tranlateShape(getDimShape(),
                                        CrdTransDirectionKind::dim2lvl);
   }
@@ -260,19 +257,11 @@ class SparseTensorType {
   /// Safely looks up the requested dimension-DynSize. If you intend
   /// to check the result with `ShapedType::isDynamic`, then see the
   /// `getStaticDimSize` method instead.
-  DynSize getDynamicDimSize(Dimension d) const {
+  Size getDynamicDimSize(Dimension d) const {
     assert(d < getDimRank() && "Dimension is out of bounds");
     return getDimShape()[d];
   }
 
-  /// Safely looks up the requested dimension-size, mapping dynamic
-  /// sizes to `std::nullopt`.
-  std::optional<StaticSize> getStaticDimSize(Dimension d) const {
-    const DynSize sh = getDynamicDimSize(d);
-    return ShapedType::isDynamic(sh) ? std::nullopt
-                                     : std::optional<StaticSize>(sh);
-  }
-
   /// Returns true if no dimension has dynamic size.
   bool hasStaticDimShape() const { return rtp.hasStaticShape(); }
 
@@ -318,12 +307,16 @@ class SparseTensorType {
 
   /// Returns the coordinate-overhead MLIR type, defaulting to `IndexType`.
   Type getCrdType() const {
-    return detail::getIntegerOrIndexType(getContext(), getCrdWidth());
+    if (getCrdWidth())
+      return IntegerType::get(getContext(), getCrdWidth());
+    return IndexType::get(getContext());
   }
 
   /// Returns the position-overhead MLIR type, defaulting to `IndexType`.
   Type getPosType() const {
-    return detail::getIntegerOrIndexType(getContext(), getPosWidth());
+    if (getPosWidth())
+      return IntegerType::get(getContext(), getPosWidth());
+    return IndexType::get(getContext());
   }
 
 private:
@@ -336,14 +329,13 @@ class SparseTensorType {
   const AffineMap lvlToDim;
 };
 
-/// Convenience methods to abbreviate wrapping `getRankedTensorType`.
-template <typename T>
-inline SparseTensorType getSparseTensorType(T t) {
-  return SparseTensorType(getRankedTensorType(t));
+/// Convenience methods to obtain a SparseTensorType from a Value.
+inline SparseTensorType getSparseTensorType(Value val) {
+  return SparseTensorType(cast<RankedTensorType>(val.getType()));
 }
-inline std::optional<SparseTensorType> tryGetSparseTensorType(Value v) {
-  if (isa<RankedTensorType>(v.getType()))
-    return getSparseTensorType(v);
+inline std::optional<SparseTensorType> tryGetSparseTensorType(Value val) {
+  if (auto rtp = dyn_cast<RankedTensorType>(val.getType()))
+    return SparseTensorType(rtp);
   return std::nullopt;
 }
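Hypothetical call sites for the tightened helpers above (v is any mlir::Value; the comments reflect the behavior shown in this diff):

    // Safe probe: yields std::nullopt whenever `v` is not a ranked tensor.
    if (std::optional<SparseTensorType> stt = tryGetSparseTensorType(v)) {
      Type posTp = stt->getPosType(); // IntegerType of posWidth, or IndexType if 0
      Type crdTp = stt->getCrdType(); // likewise for crdWidth
    }
    // Unchecked form: the cast asserts if `v` is not a RankedTensorType.
    SparseTensorType stt = getSparseTensorType(v);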

mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp

Lines changed: 9 additions & 26 deletions
@@ -270,23 +270,6 @@ SparseTensorDimSliceAttr::verify(function_ref<InFlightDiagnostic()> emitError,
   return success();
 }
 
-Type mlir::sparse_tensor::detail::getIntegerOrIndexType(MLIRContext *ctx,
-                                                        unsigned bitwidth) {
-  if (bitwidth)
-    return IntegerType::get(ctx, bitwidth);
-  return IndexType::get(ctx);
-}
-
-Type SparseTensorEncodingAttr::getPosType() const {
-  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
-  return detail::getIntegerOrIndexType(getContext(), getPosWidth());
-}
-
-Type SparseTensorEncodingAttr::getCrdType() const {
-  assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
-  return detail::getIntegerOrIndexType(getContext(), getCrdWidth());
-}
-
 SparseTensorEncodingAttr
 SparseTensorEncodingAttr::withDimToLvl(AffineMap dimToLvl) const {
   assert(getImpl() && "Uninitialized SparseTensorEncodingAttr");
@@ -722,7 +705,7 @@ SparseTensorEncodingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
 }
 
 LogicalResult SparseTensorEncodingAttr::verifyEncoding(
-    ArrayRef<DynSize> dimShape, Type elementType,
+    ArrayRef<Size> dimShape, Type elementType,
     function_ref<InFlightDiagnostic()> emitError) const {
   // Check structural integrity. In particular, this ensures that the
   // level-rank is coherent across all the fields.
@@ -1312,7 +1295,7 @@ OpFoldResult LvlOp::fold(FoldAdaptor adaptor) {
 
   // TODO: we can remove this after SparseTensorEncoding always returns non-null
   // dimToLvl map.
-  ArrayRef<DynSize> shape = stt.getDimShape();
+  ArrayRef<Size> shape = stt.getDimShape();
   if (stt.isPermutation()) {
     Dimension dim = toOrigDim(stt, lvl);
     if (!ShapedType::isDynamic(shape[dim])) {
@@ -1378,8 +1361,8 @@ LogicalResult ReinterpretMapOp::verify() {
   if (srcStt.getElementType() != dstStt.getElementType())
     return emitError("Element type mismatch between source/dest tensors");
 
-  SmallVector<DynSize> srcLvlShape = srcStt.getLvlShape();
-  SmallVector<DynSize> dstLvlShape = dstStt.getLvlShape();
+  SmallVector<Size> srcLvlShape = srcStt.getLvlShape();
+  SmallVector<Size> dstLvlShape = dstStt.getLvlShape();
   for (auto [srcLvlSz, dstLvlSz] : llvm::zip(srcLvlShape, dstLvlShape)) {
     if (srcLvlSz != dstLvlSz) {
       // Should we allow one side to be dynamic size, e.g., <?x?> should be
@@ -1616,13 +1599,13 @@ LogicalResult ConcatenateOp::verify() {
   }
 
   for (Dimension d = 0; d < dimRank; d++) {
-    const DynSize dstSh = dstTp.getDimShape()[d];
+    const Size dstSh = dstTp.getDimShape()[d];
     if (d == concatDim) {
       if (!ShapedType::isDynamic(dstSh)) {
         // If we reach here, then all inputs have static shapes. So we
         // can use `getDimShape()[d]` instead of `*getDynamicDimSize(d)`
         // to avoid redundant assertions in the loop.
-        StaticSize sumSz = 0;
+        Size sumSz = 0;
         for (const auto src : getInputs())
           sumSz += getSparseTensorType(src).getDimShape()[d];
         // If all dimension are statically known, the sum of all the input
@@ -1633,7 +1616,7 @@ LogicalResult ConcatenateOp::verify() {
           "sum of all the concatenation dimensions of the input tensors.");
       }
     } else {
-      DynSize prev = dstSh;
+      Size prev = dstSh;
       for (const auto src : getInputs()) {
        const auto sh = getSparseTensorType(src).getDimShape()[d];
         if (!ShapedType::isDynamic(prev) && sh != prev)
@@ -1808,8 +1791,8 @@ LogicalResult SortOp::verify() {
   // FIXME: update the types of variables used in expressions bassed as
   // the `minSize` argument, to avoid implicit casting at the callsites
   // of this lambda.
-  const auto checkDim = [&](Value v, StaticSize minSize, const char *message) {
-    const DynSize sh = getMemRefType(v).getShape()[0];
+  const auto checkDim = [&](Value v, Size minSize, const char *message) {
+    const Size sh = getMemRefType(v).getShape()[0];
     if (!ShapedType::isDynamic(sh) && sh < minSize)
       emitError(llvm::formatv("{0} got {1} < {2}", message, sh, minSize));
   };

mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp

Lines changed: 11 additions & 12 deletions
@@ -51,8 +51,6 @@ OverheadType mlir::sparse_tensor::overheadTypeEncoding(Type tp) {
   llvm_unreachable("Unknown overhead type");
 }
 
-// TODO: should offer an overload of this that takes a `MLIRContext*`
-// instead of the builder, similar to `detail::getIntegerOrIndexType`.
 Type mlir::sparse_tensor::getOverheadType(Builder &builder, OverheadType ot) {
   switch (ot) {
   case OverheadType::kIndex:
@@ -209,7 +207,7 @@ Value mlir::sparse_tensor::genIsNonzero(OpBuilder &builder, mlir::Location loc,
 
 void mlir::sparse_tensor::genReshapeDstShape(
     OpBuilder &builder, Location loc, SmallVectorImpl<Value> &dstShape,
-    ArrayRef<Value> srcShape, ArrayRef<StaticSize> staticDstShape,
+    ArrayRef<Value> srcShape, ArrayRef<Size> staticDstShape,
     ArrayRef<ReassociationIndices> reassociation) {
   // Collapse shape.
   if (reassociation.size() < srcShape.size()) {
@@ -242,7 +240,7 @@ void mlir::sparse_tensor::genReshapeDstShape(
       if (staticDstShape[j] == ShapedType::kDynamic) {
         // The expanded dimension has dynamic size. We compute the dimension
         // by dividing srcDim by the product of the static dimensions.
-        StaticSize product = 1;
+        Size product = 1;
         for (unsigned k = start; k < start + map.size(); k++) {
           if (staticDstShape[k] != ShapedType::kDynamic) {
             product *= staticDstShape[k];
@@ -423,7 +421,8 @@ Operation *mlir::sparse_tensor::getTop(Operation *op) {
 void sparse_tensor::foreachInSparseConstant(
     OpBuilder &builder, Location loc, SparseElementsAttr attr, AffineMap order,
     function_ref<void(ArrayRef<Value>, Value)> callback) {
-  const Dimension dimRank = getSparseTensorType(attr).getDimRank();
+  const Dimension dimRank =
+      SparseTensorType(getRankedTensorType(attr)).getDimRank();
   const auto coordinates = attr.getIndices().getValues<IntegerAttr>();
   const auto values = attr.getValues().getValues<Attribute>();
 
@@ -494,8 +493,8 @@ SmallVector<Value> sparse_tensor::loadAll(OpBuilder &builder, Location loc,
 #ifndef NDEBUG
   const auto memTp = cast<MemRefType>(mem.getType());
   assert(memTp.getRank() == 1);
-  const DynSize memSh = memTp.getDimSize(0);
-  assert(ShapedType::isDynamic(memSh) || memSh >= static_cast<DynSize>(size));
+  const Size memSh = memTp.getDimSize(0);
+  assert(ShapedType::isDynamic(memSh) || memSh >= static_cast<Size>(size));
   assert(offsetIdx == 0 || offsetIdx < size);
 #endif // NDEBUG
   SmallVector<Value> vs;
@@ -516,8 +515,8 @@ void sparse_tensor::storeAll(OpBuilder &builder, Location loc, Value mem,
   const size_t vsize = vs.size();
   const auto memTp = cast<MemRefType>(mem.getType());
   assert(memTp.getRank() == 1);
-  const DynSize memSh = memTp.getDimSize(0);
-  assert(ShapedType::isDynamic(memSh) || memSh >= static_cast<DynSize>(vsize));
+  const Size memSh = memTp.getDimSize(0);
+  assert(ShapedType::isDynamic(memSh) || memSh >= static_cast<Size>(vsize));
   assert(offsetIdx == 0 || offsetIdx < vsize);
 #endif // NDEBUG
   for (const auto &v : llvm::enumerate(vs)) {
@@ -546,11 +545,11 @@ Value sparse_tensor::reshapeValuesToLevels(OpBuilder &builder, Location loc,
   // The memref ReshapeOp requires the sizes buffer to have a static
   // shape.
   const auto iTp = builder.getIndexType();
-  const SmallVector<DynSize, 1> lvlSizesShape{static_cast<DynSize>(lvlRank)};
+  const SmallVector<Size, 1> lvlSizesShape{static_cast<Size>(lvlRank)};
   const auto lvlSizesTp = MemRefType::get(lvlSizesShape, iTp);
   lvlCoords = builder.create<memref::CastOp>(loc, lvlSizesTp, lvlCoords);
   // Finally, create the ReshapeOp.
-  const SmallVector<DynSize> resShape(lvlRank, ShapedType::kDynamic);
+  const SmallVector<Size> resShape(lvlRank, ShapedType::kDynamic);
   const Type elemTp = getMemRefType(valuesBuffer).getElementType();
   const auto resTp = MemRefType::get(resShape, elemTp);
   return builder.create<memref::ReshapeOp>(loc, resTp, valuesBuffer, lvlCoords);
@@ -628,7 +627,7 @@ void sparse_tensor::fillDimShape(OpBuilder &builder, Location loc,
                                  SmallVectorImpl<Value> &out) {
   out.clear();
   out.reserve(stt.getDimRank());
-  for (const DynSize sh : stt.getDimShape()) {
+  for (const Size sh : stt.getDimShape()) {
     const auto s = ShapedType::isDynamic(sh) ? 0 : sh;
     out.push_back(constantIndex(builder, loc, s));
   }
