Skip to content

Commit 4528808

Browse files
authored
[mlir][sparse] move toCOOType into SparseTensorType class (llvm#73708)
Migrates dangling convenience method into proper SparseTensorType class. Also cleans up some details (picking right dim2lvl/lvl2dim). Removes more dead code.
1 parent 351c3ee commit 4528808

File tree

5 files changed

+40
-72
lines changed

5 files changed

+40
-72
lines changed

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -102,10 +102,6 @@ bool isUniqueCOOType(Type tp);
102102
/// the level-rank.
103103
Level getCOOStart(SparseTensorEncodingAttr enc);
104104

105-
/// Helper to setup a COO type.
106-
RankedTensorType getCOOFromTypeWithOrdering(RankedTensorType src,
107-
AffineMap ordering, bool ordered);
108-
109105
/// Returns true iff MLIR operand has any sparse operand.
110106
inline bool hasAnySparseOperand(Operation *op) {
111107
return llvm::any_of(op->getOperands().getTypes(), [](Type t) {

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h

Lines changed: 9 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -64,18 +64,14 @@ class SparseTensorType {
6464
SparseTensorType(const SparseTensorType &) = default;
6565

6666
//
67-
// Factory methods.
67+
// Factory methods to construct a new `SparseTensorType`
68+
// with the same dimension-shape and element type.
6869
//
6970

70-
/// Constructs a new `SparseTensorType` with the same dimension-shape
71-
/// and element type, but with the encoding replaced by the given encoding.
7271
SparseTensorType withEncoding(SparseTensorEncodingAttr newEnc) const {
7372
return SparseTensorType(rtp, newEnc);
7473
}
7574

76-
/// Constructs a new `SparseTensorType` with the same dimension-shape
77-
/// and element type, but with the encoding replaced by
78-
/// `getEncoding().withDimToLvl(dimToLvl)`.
7975
SparseTensorType withDimToLvl(AffineMap dimToLvl) const {
8076
return withEncoding(enc.withDimToLvl(dimToLvl));
8177
}
@@ -88,23 +84,14 @@ class SparseTensorType {
8884
return withDimToLvl(dimToLvlSTT.getEncoding());
8985
}
9086

91-
/// Constructs a new `SparseTensorType` with the same dimension-shape
92-
/// and element type, but with the encoding replaced by
93-
/// `getEncoding().withoutDimToLvl()`.
9487
SparseTensorType withoutDimToLvl() const {
9588
return withEncoding(enc.withoutDimToLvl());
9689
}
9790

98-
/// Constructs a new `SparseTensorType` with the same dimension-shape
99-
/// and element type, but with the encoding replaced by
100-
/// `getEncoding().withBitWidths(posWidth, crdWidth)`.
10191
SparseTensorType withBitWidths(unsigned posWidth, unsigned crdWidth) const {
10292
return withEncoding(enc.withBitWidths(posWidth, crdWidth));
10393
}
10494

105-
/// Constructs a new `SparseTensorType` with the same dimension-shape
106-
/// and element type, but with the encoding replaced by
107-
/// `getEncoding().withoutBitWidths()`.
10895
SparseTensorType withoutBitWidths() const {
10996
return withEncoding(enc.withoutBitWidths());
11097
}
@@ -118,10 +105,6 @@ class SparseTensorType {
118105
return withEncoding(enc.withoutDimSlices());
119106
}
120107

121-
//
122-
// Other methods.
123-
//
124-
125108
/// Allow implicit conversion to `RankedTensorType`, `ShapedType`,
126109
/// and `Type`. These are implicit to help alleviate the impedance
127110
/// mismatch for code that has not been converted to use `SparseTensorType`
@@ -170,7 +153,6 @@ class SparseTensorType {
170153

171154
Type getElementType() const { return rtp.getElementType(); }
172155

173-
/// Returns the encoding (or the null-attribute for dense-tensors).
174156
SparseTensorEncodingAttr getEncoding() const { return enc; }
175157

176158
//
@@ -204,6 +186,10 @@ class SparseTensorType {
204186
/// (This is always true for dense-tensors.)
205187
bool isIdentity() const { return enc.isIdentity(); }
206188

189+
//
190+
// Other methods.
191+
//
192+
207193
/// Returns the dimToLvl mapping (or the null-map for the identity).
208194
/// If you intend to compare the results of this method for equality,
209195
/// see `hasSameDimToLvl` instead.
@@ -325,6 +311,9 @@ class SparseTensorType {
325311
return IndexType::get(getContext());
326312
}
327313

314+
/// Returns [un]ordered COO type for this sparse tensor type.
315+
RankedTensorType getCOOType(bool ordered) const;
316+
328317
private:
329318
// These two must be const, to ensure coherence of the memoized fields.
330319
const RankedTensorType rtp;

mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp

Lines changed: 27 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ using namespace mlir;
3636
using namespace mlir::sparse_tensor;
3737

3838
//===----------------------------------------------------------------------===//
39-
// Local convenience methods.
39+
// Local Convenience Methods.
4040
//===----------------------------------------------------------------------===//
4141

4242
static constexpr bool acceptBitWidth(unsigned bitWidth) {
@@ -711,7 +711,32 @@ LogicalResult SparseTensorEncodingAttr::verifyEncoding(
711711
}
712712

713713
//===----------------------------------------------------------------------===//
714-
// Convenience methods.
714+
// SparseTensorType Methods.
715+
//===----------------------------------------------------------------------===//
716+
717+
RankedTensorType
718+
mlir::sparse_tensor::SparseTensorType::getCOOType(bool ordered) const {
719+
SmallVector<LevelType> lvlTypes;
720+
lvlTypes.reserve(lvlRank);
721+
// A non-unique compressed level at beginning (unless this is
722+
// also the last level, then it is unique).
723+
lvlTypes.push_back(
724+
*buildLevelType(LevelFormat::Compressed, ordered, lvlRank == 1));
725+
if (lvlRank > 1) {
726+
// Followed by n-2 non-unique singleton levels.
727+
std::fill_n(std::back_inserter(lvlTypes), lvlRank - 2,
728+
*buildLevelType(LevelFormat::Singleton, ordered, false));
729+
// Ends by a unique singleton level.
730+
lvlTypes.push_back(*buildLevelType(LevelFormat::Singleton, ordered, true));
731+
}
732+
auto enc = SparseTensorEncodingAttr::get(getContext(), lvlTypes,
733+
getDimToLvl(), getLvlToDim(),
734+
getPosWidth(), getCrdWidth());
735+
return RankedTensorType::get(getDimShape(), getElementType(), enc);
736+
}
737+
738+
//===----------------------------------------------------------------------===//
739+
// Convenience Methods.
715740
//===----------------------------------------------------------------------===//
716741

717742
SparseTensorEncodingAttr
@@ -878,39 +903,6 @@ Level mlir::sparse_tensor::getCOOStart(SparseTensorEncodingAttr enc) {
878903
return lvlRank;
879904
}
880905

881-
// Helper to setup a COO type.
882-
RankedTensorType sparse_tensor::getCOOFromTypeWithOrdering(RankedTensorType rtt,
883-
AffineMap lvlPerm,
884-
bool ordered) {
885-
const SparseTensorType src(rtt);
886-
const Level lvlRank = src.getLvlRank();
887-
SmallVector<LevelType> lvlTypes;
888-
lvlTypes.reserve(lvlRank);
889-
890-
// An unordered and non-unique compressed level at beginning.
891-
// If this is also the last level, then it is unique.
892-
lvlTypes.push_back(
893-
*buildLevelType(LevelFormat::Compressed, ordered, lvlRank == 1));
894-
if (lvlRank > 1) {
895-
// TODO: it is actually ordered at the level for ordered input.
896-
// Followed by unordered non-unique n-2 singleton levels.
897-
std::fill_n(std::back_inserter(lvlTypes), lvlRank - 2,
898-
*buildLevelType(LevelFormat::Singleton, ordered, false));
899-
// Ends by a unique singleton level unless the lvlRank is 1.
900-
lvlTypes.push_back(*buildLevelType(LevelFormat::Singleton, ordered, true));
901-
}
902-
903-
// TODO: Maybe pick the bitwidth based on input/output tensors (probably the
904-
// largest one among them) in the original operation instead of using the
905-
// default value.
906-
unsigned posWidth = src.getPosWidth();
907-
unsigned crdWidth = src.getCrdWidth();
908-
AffineMap invPerm = src.getLvlToDim();
909-
auto enc = SparseTensorEncodingAttr::get(src.getContext(), lvlTypes, lvlPerm,
910-
invPerm, posWidth, crdWidth);
911-
return RankedTensorType::get(src.getDimShape(), src.getElementType(), enc);
912-
}
913-
914906
Dimension mlir::sparse_tensor::toDim(SparseTensorEncodingAttr enc, Level l) {
915907
if (enc) {
916908
assert(enc.isPermutation() && "Non permutation map not supported");

mlir/lib/Dialect/SparseTensor/IR/SparseTensorInterfaces.cpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,7 @@ sparse_tensor::detail::stageWithSortImpl(StageWithSortSparseOp op,
2525
Location loc = op.getLoc();
2626
Type finalTp = op->getOpResult(0).getType();
2727
SparseTensorType dstStt(finalTp.cast<RankedTensorType>());
28-
29-
Type srcCOOTp = getCOOFromTypeWithOrdering(
30-
dstStt.getRankedTensorType(), dstStt.getDimToLvl(), /*ordered=*/false);
28+
Type srcCOOTp = dstStt.getCOOType(/*ordered=*/false);
3129

3230
// Clones the original operation but changing the output to an unordered COO.
3331
Operation *cloned = rewriter.clone(*op.getOperation());
@@ -37,8 +35,7 @@ sparse_tensor::detail::stageWithSortImpl(StageWithSortSparseOp op,
3735
Value srcCOO = cloned->getOpResult(0);
3836

3937
// -> sort
40-
Type dstCOOTp = getCOOFromTypeWithOrdering(
41-
dstStt.getRankedTensorType(), dstStt.getDimToLvl(), /*ordered=*/true);
38+
Type dstCOOTp = dstStt.getCOOType(/*ordered=*/true);
4239
Value dstCOO = rewriter.create<ReorderCOOOp>(
4340
loc, dstCOOTp, srcCOO, SparseTensorSortKind::HybridQuickSort);
4441

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -132,15 +132,9 @@ static void sizesForTensor(OpBuilder &builder, SmallVectorImpl<Value> &sizes,
132132
}
133133
}
134134

135-
// TODO: The dim level property of the COO type relies on input tensors, the
136-
// shape relies on the output tensor
137-
static RankedTensorType getCOOType(const SparseTensorType &stt, bool ordered) {
138-
return getCOOFromTypeWithOrdering(stt, stt.getDimToLvl(), ordered);
139-
}
140-
141135
static RankedTensorType getBufferType(const SparseTensorType &stt,
142136
bool needTmpCOO) {
143-
return needTmpCOO ? getCOOType(stt, /*ordered=*/false)
137+
return needTmpCOO ? stt.getCOOType(/*ordered=*/false)
144138
: stt.getRankedTensorType();
145139
}
146140

@@ -1195,7 +1189,7 @@ struct NewRewriter : public OpRewritePattern<NewOp> {
11951189
// %t = sparse_tensor.convert %orderedCoo
11961190
// with enveloping reinterpreted_map ops for non-permutations.
11971191
RankedTensorType dstTp = stt.getRankedTensorType();
1198-
RankedTensorType cooTp = getCOOType(dstTp, /*ordered=*/true);
1192+
RankedTensorType cooTp = stt.getCOOType(/*ordered=*/true);
11991193
Value cooTensor = rewriter.create<NewOp>(loc, cooTp, op.getSource());
12001194
Value convert = cooTensor;
12011195
if (!stt.isPermutation()) { // demap coo, demap dstTp

0 commit comments

Comments
 (0)