Skip to content

Commit 5b72950

Browse files
authored
[mlir][sparse] move all COO related methods into SparseTensorType (#73881)
This centralizes all COO methods, and provides a cleaner API. Note that the "enc" only constructor is a temporary workaround the need for COO methods inside the "enc" only storage specifier.
1 parent ea5b1ef commit 5b72950

File tree

8 files changed

+65
-72
lines changed

8 files changed

+65
-72
lines changed

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -89,19 +89,6 @@ inline MemRefType getMemRefType(T &&t) {
8989
/// Returns null-attribute for any type without an encoding.
9090
SparseTensorEncodingAttr getSparseTensorEncoding(Type type);
9191

92-
/// Returns true iff the given sparse tensor encoding attribute has a trailing
93-
/// COO region starting at the given level.
94-
bool isCOOType(SparseTensorEncodingAttr enc, Level startLvl, bool isUnique);
95-
96-
/// Returns true iff the given type is a COO type where the last level
97-
/// is unique.
98-
bool isUniqueCOOType(Type tp);
99-
100-
/// Returns the starting level for a trailing COO region that spans
101-
/// at least two levels. If no such COO region is found, then returns
102-
/// the level-rank.
103-
Level getCOOStart(SparseTensorEncodingAttr enc);
104-
10592
/// Returns true iff MLIR operand has any sparse operand.
10693
inline bool hasAnySparseOperand(Operation *op) {
10794
return llvm::any_of(op->getOperands().getTypes(), [](Type t) {

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,12 @@ class SparseTensorType {
6060
: SparseTensorType(
6161
RankedTensorType::get(stp.getShape(), stp.getElementType(), enc)) {}
6262

63+
// TODO: remove?
64+
SparseTensorType(SparseTensorEncodingAttr enc)
65+
: SparseTensorType(RankedTensorType::get(
66+
SmallVector<Size>(enc.getDimRank(), ShapedType::kDynamic),
67+
Float32Type::get(enc.getContext()), enc)) {}
68+
6369
SparseTensorType &operator=(const SparseTensorType &) = delete;
6470
SparseTensorType(const SparseTensorType &) = default;
6571

@@ -234,9 +240,9 @@ class SparseTensorType {
234240
CrdTransDirectionKind::dim2lvl);
235241
}
236242

243+
/// Returns the type with an identity mapping.
237244
RankedTensorType getDemappedType() const {
238-
auto lvlShape = getLvlShape();
239-
return RankedTensorType::get(lvlShape, rtp.getElementType(),
245+
return RankedTensorType::get(getLvlShape(), getElementType(),
240246
enc.withoutDimToLvl());
241247
}
242248

@@ -311,6 +317,16 @@ class SparseTensorType {
311317
return IndexType::get(getContext());
312318
}
313319

320+
/// Returns true iff this sparse tensor type has a trailing
321+
/// COO region starting at the given level. By default, it
322+
/// tests for a unique COO type at top level.
323+
bool isCOOType(Level startLvl = 0, bool isUnique = true) const;
324+
325+
/// Returns the starting level of this sparse tensor type for a
326+
/// trailing COO region that spans **at least** two levels. If
327+
/// no such COO region is found, then returns the level-rank.
328+
Level getCOOStart() const;
329+
314330
/// Returns [un]ordered COO type for this sparse tensor type.
315331
RankedTensorType getCOOType(bool ordered) const;
316332

mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp

Lines changed: 36 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ void StorageLayout::foreachField(
6666
callback) const {
6767
const auto lvlTypes = enc.getLvlTypes();
6868
const Level lvlRank = enc.getLvlRank();
69-
const Level cooStart = getCOOStart(enc);
69+
const Level cooStart = SparseTensorType(enc).getCOOStart();
7070
const Level end = cooStart == lvlRank ? cooStart : cooStart + 1;
7171
FieldIndex fieldIdx = kDataFieldStartingIdx;
7272
// Per-level storage.
@@ -158,7 +158,7 @@ StorageLayout::getFieldIndexAndStride(SparseTensorFieldKind kind,
158158
unsigned stride = 1;
159159
if (kind == SparseTensorFieldKind::CrdMemRef) {
160160
assert(lvl.has_value());
161-
const Level cooStart = getCOOStart(enc);
161+
const Level cooStart = SparseTensorType(enc).getCOOStart();
162162
const Level lvlRank = enc.getLvlRank();
163163
if (lvl.value() >= cooStart && lvl.value() < lvlRank) {
164164
lvl = cooStart;
@@ -710,6 +710,29 @@ LogicalResult SparseTensorEncodingAttr::verifyEncoding(
710710
// SparseTensorType Methods.
711711
//===----------------------------------------------------------------------===//
712712

713+
bool mlir::sparse_tensor::SparseTensorType::isCOOType(Level startLvl,
714+
bool isUnique) const {
715+
if (!hasEncoding())
716+
return false;
717+
if (!isCompressedLvl(startLvl) && !isLooseCompressedLvl(startLvl))
718+
return false;
719+
for (Level l = startLvl + 1; l < lvlRank; ++l)
720+
if (!isSingletonLvl(l))
721+
return false;
722+
// If isUnique is true, then make sure that the last level is unique,
723+
// that is, when lvlRank == 1, the only compressed level is unique,
724+
// and when lvlRank > 1, the last singleton is unique.
725+
return !isUnique || isUniqueLvl(lvlRank - 1);
726+
}
727+
728+
Level mlir::sparse_tensor::SparseTensorType::getCOOStart() const {
729+
if (hasEncoding() && lvlRank > 1)
730+
for (Level l = 0; l < lvlRank - 1; l++)
731+
if (isCOOType(l, /*isUnique=*/false))
732+
return l;
733+
return lvlRank;
734+
}
735+
713736
RankedTensorType
714737
mlir::sparse_tensor::SparseTensorType::getCOOType(bool ordered) const {
715738
SmallVector<LevelType> lvlTypes;
@@ -859,25 +882,6 @@ bool mlir::sparse_tensor::isBlockSparsity(AffineMap dimToLvl) {
859882
return !coeffientMap.empty();
860883
}
861884

862-
bool mlir::sparse_tensor::isCOOType(SparseTensorEncodingAttr enc,
863-
Level startLvl, bool isUnique) {
864-
if (!enc ||
865-
!(enc.isCompressedLvl(startLvl) || enc.isLooseCompressedLvl(startLvl)))
866-
return false;
867-
const Level lvlRank = enc.getLvlRank();
868-
for (Level l = startLvl + 1; l < lvlRank; ++l)
869-
if (!enc.isSingletonLvl(l))
870-
return false;
871-
// If isUnique is true, then make sure that the last level is unique,
872-
// that is, lvlRank == 1 (unique the only compressed) and lvlRank > 1
873-
// (unique on the last singleton).
874-
return !isUnique || enc.isUniqueLvl(lvlRank - 1);
875-
}
876-
877-
bool mlir::sparse_tensor::isUniqueCOOType(Type tp) {
878-
return isCOOType(getSparseTensorEncoding(tp), 0, /*isUnique=*/true);
879-
}
880-
881885
bool mlir::sparse_tensor::hasAnyNonIdentityOperandsOrResults(Operation *op) {
882886
auto hasNonIdentityMap = [](Value v) {
883887
auto stt = tryGetSparseTensorType(v);
@@ -888,17 +892,6 @@ bool mlir::sparse_tensor::hasAnyNonIdentityOperandsOrResults(Operation *op) {
888892
llvm::any_of(op->getResults(), hasNonIdentityMap);
889893
}
890894

891-
Level mlir::sparse_tensor::getCOOStart(SparseTensorEncodingAttr enc) {
892-
// We only consider COO region with at least two levels for the purpose
893-
// of AOS storage optimization.
894-
const Level lvlRank = enc.getLvlRank();
895-
if (lvlRank > 1)
896-
for (Level l = 0; l < lvlRank - 1; l++)
897-
if (isCOOType(enc, l, /*isUnique=*/false))
898-
return l;
899-
return lvlRank;
900-
}
901-
902895
Dimension mlir::sparse_tensor::toDim(SparseTensorEncodingAttr enc, Level l) {
903896
if (enc) {
904897
assert(enc.isPermutation() && "Non permutation map not supported");
@@ -1013,7 +1006,7 @@ static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape,
10131006
return op->emitError("the sparse-tensor must have the identity mapping");
10141007

10151008
// Verifies the trailing COO.
1016-
Level cooStartLvl = getCOOStart(stt.getEncoding());
1009+
Level cooStartLvl = stt.getCOOStart();
10171010
if (cooStartLvl < stt.getLvlRank()) {
10181011
// We only supports trailing COO for now, must be the last input.
10191012
auto cooTp = llvm::cast<ShapedType>(lvlTps.back());
@@ -1309,34 +1302,34 @@ OpFoldResult ReinterpretMapOp::fold(FoldAdaptor adaptor) {
13091302
}
13101303

13111304
LogicalResult ToPositionsOp::verify() {
1312-
auto e = getSparseTensorEncoding(getTensor().getType());
1305+
auto stt = getSparseTensorType(getTensor());
13131306
if (failed(lvlIsInBounds(getLevel(), getTensor())))
13141307
return emitError("requested level is out of bounds");
1315-
if (failed(isMatchingWidth(getResult(), e.getPosWidth())))
1308+
if (failed(isMatchingWidth(getResult(), stt.getPosWidth())))
13161309
return emitError("unexpected type for positions");
13171310
return success();
13181311
}
13191312

13201313
LogicalResult ToCoordinatesOp::verify() {
1321-
auto e = getSparseTensorEncoding(getTensor().getType());
1314+
auto stt = getSparseTensorType(getTensor());
13221315
if (failed(lvlIsInBounds(getLevel(), getTensor())))
13231316
return emitError("requested level is out of bounds");
1324-
if (failed(isMatchingWidth(getResult(), e.getCrdWidth())))
1317+
if (failed(isMatchingWidth(getResult(), stt.getCrdWidth())))
13251318
return emitError("unexpected type for coordinates");
13261319
return success();
13271320
}
13281321

13291322
LogicalResult ToCoordinatesBufferOp::verify() {
1330-
auto e = getSparseTensorEncoding(getTensor().getType());
1331-
if (getCOOStart(e) >= e.getLvlRank())
1323+
auto stt = getSparseTensorType(getTensor());
1324+
if (stt.getCOOStart() >= stt.getLvlRank())
13321325
return emitError("expected sparse tensor with a COO region");
13331326
return success();
13341327
}
13351328

13361329
LogicalResult ToValuesOp::verify() {
1337-
auto ttp = getRankedTensorType(getTensor());
1330+
auto stt = getSparseTensorType(getTensor());
13381331
auto mtp = getMemRefType(getResult());
1339-
if (ttp.getElementType() != mtp.getElementType())
1332+
if (stt.getElementType() != mtp.getElementType())
13401333
return emitError("unexpected mismatch in element types");
13411334
return success();
13421335
}
@@ -1660,9 +1653,8 @@ LogicalResult ReorderCOOOp::verify() {
16601653
SparseTensorType srcStt = getSparseTensorType(getInputCoo());
16611654
SparseTensorType dstStt = getSparseTensorType(getResultCoo());
16621655

1663-
if (!isCOOType(srcStt.getEncoding(), 0, /*isUnique=*/true) ||
1664-
!isCOOType(dstStt.getEncoding(), 0, /*isUnique=*/true))
1665-
emitError("Unexpected non-COO sparse tensors");
1656+
if (!srcStt.isCOOType() || !dstStt.isCOOType())
1657+
emitError("Expected COO sparse tensors only");
16661658

16671659
if (!srcStt.hasSameDimToLvl(dstStt))
16681660
emitError("Unmatched dim2lvl map between input and result COO");

mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -412,8 +412,7 @@ void LoopEmitter::initializeLoopEmit(
412412
auto stt = getSparseTensorType(tensor);
413413
const Level lvlRank = stt.getLvlRank();
414414
const auto shape = rtp.getShape();
415-
const auto enc = getSparseTensorEncoding(rtp);
416-
const Level cooStart = enc ? getCOOStart(enc) : lvlRank;
415+
const Level cooStart = stt.getCOOStart();
417416

418417
SmallVector<Value> lvlSzs;
419418
for (Level l = 0; l < stt.getLvlRank(); l++) {
@@ -457,8 +456,8 @@ void LoopEmitter::initializeLoopEmit(
457456
// values.
458457
// Delegates extra output initialization to clients.
459458
bool isOutput = isOutputTensor(t);
460-
Type elementType = rtp.getElementType();
461-
if (!enc) {
459+
Type elementType = stt.getElementType();
460+
if (!stt.hasEncoding()) {
462461
// Non-annotated dense tensors.
463462
BaseMemRefType denseTp = MemRefType::get(shape, elementType);
464463

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ static void createAllocFields(OpBuilder &builder, Location loc,
194194
valHeuristic =
195195
builder.create<arith::MulIOp>(loc, valHeuristic, lvlSizesValues[lvl]);
196196
} else if (sizeHint) {
197-
if (getCOOStart(stt.getEncoding()) == 0) {
197+
if (stt.getCOOStart() == 0) {
198198
posHeuristic = constantIndex(builder, loc, 2);
199199
crdHeuristic = builder.create<arith::MulIOp>(
200200
loc, constantIndex(builder, loc, lvlRank), sizeHint); // AOS
@@ -657,8 +657,7 @@ struct SparseReorderCOOConverter : public OpConversionPattern<ReorderCOOOp> {
657657

658658
// Should have been verified.
659659
assert(dstStt.isAllOrdered() && !srcStt.isAllOrdered() &&
660-
isUniqueCOOType(srcStt.getRankedTensorType()) &&
661-
isUniqueCOOType(dstStt.getRankedTensorType()));
660+
dstStt.isCOOType() && srcStt.isCOOType());
662661
assert(dstStt.hasSameDimToLvl(srcStt));
663662

664663
// We don't need a mutable descriptor here as we perform sorting in-place.
@@ -1317,7 +1316,7 @@ struct SparseAssembleOpConverter : public OpConversionPattern<AssembleOp> {
13171316
Value posBack = c0; // index to the last value in the position array
13181317
Value memSize = c1; // memory size for current array
13191318

1320-
Level trailCOOStart = getCOOStart(stt.getEncoding());
1319+
Level trailCOOStart = stt.getCOOStart();
13211320
Level trailCOORank = stt.getLvlRank() - trailCOOStart;
13221321
// Sets up SparseTensorSpecifier.
13231322
for (Level lvl = 0, lvlRank = stt.getLvlRank(); lvl < lvlRank; lvl++) {
@@ -1454,7 +1453,7 @@ struct SparseNewConverter : public OpConversionPattern<NewOp> {
14541453
const auto dstTp = getSparseTensorType(op.getResult());
14551454
// Creating COO with NewOp is handled by direct IR codegen. All other cases
14561455
// are handled by rewriting.
1457-
if (!dstTp.hasEncoding() || getCOOStart(dstTp.getEncoding()) != 0)
1456+
if (!dstTp.hasEncoding() || dstTp.getCOOStart() != 0)
14581457
return failure();
14591458

14601459
// Implement as follows:

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorDescriptor.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ void SparseTensorSpecifier::setSpecifierField(OpBuilder &builder, Location loc,
103103

104104
Value sparse_tensor::SparseTensorDescriptor::getCrdMemRefOrView(
105105
OpBuilder &builder, Location loc, Level lvl) const {
106-
const Level cooStart = getCOOStart(rType.getEncoding());
106+
const Level cooStart = rType.getCOOStart();
107107
if (lvl < cooStart)
108108
return getMemRefField(SparseTensorFieldKind::CrdMemRef, lvl);
109109

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorDescriptor.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ class SparseTensorDescriptorImpl {
137137
}
138138

139139
Value getAOSMemRef() const {
140-
const Level cooStart = getCOOStart(rType.getEncoding());
140+
const Level cooStart = rType.getCOOStart();
141141
assert(cooStart < rType.getLvlRank());
142142
return getMemRefField(SparseTensorFieldKind::CrdMemRef, cooStart);
143143
}

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1180,8 +1180,7 @@ struct NewRewriter : public OpRewritePattern<NewOp> {
11801180
PatternRewriter &rewriter) const override {
11811181
Location loc = op.getLoc();
11821182
auto stt = getSparseTensorType(op.getResult());
1183-
auto enc = stt.getEncoding();
1184-
if (!stt.hasEncoding() || getCOOStart(enc) == 0)
1183+
if (!stt.hasEncoding() || stt.getCOOStart() == 0)
11851184
return failure();
11861185

11871186
// Implement the NewOp as follows:
@@ -1192,6 +1191,7 @@ struct NewRewriter : public OpRewritePattern<NewOp> {
11921191
RankedTensorType cooTp = stt.getCOOType(/*ordered=*/true);
11931192
Value cooTensor = rewriter.create<NewOp>(loc, cooTp, op.getSource());
11941193
Value convert = cooTensor;
1194+
auto enc = stt.getEncoding();
11951195
if (!stt.isPermutation()) { // demap coo, demap dstTp
11961196
auto coo = getSparseTensorType(cooTensor).getEncoding().withoutDimToLvl();
11971197
convert = rewriter.create<ReinterpretMapOp>(loc, coo, convert);

0 commit comments

Comments
 (0)