Skip to content

Commit 4d06861

Browse files
committed
[mlir][sparse] add "sort" to the compress op codegen
This revision also adds convenience methods to test the dim level type/property (with the codegen being the first client). Reviewed By: bixia. Differential Revision: https://reviews.llvm.org/D134776
1 parent c598396 commit 4d06861

File tree

5 files changed

+203
-56
lines changed

5 files changed

+203
-56
lines changed

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,43 @@
2727

2828
namespace mlir {
2929
namespace sparse_tensor {
30+
3031
/// Convenience method to get a sparse encoding attribute from a type.
3132
/// Returns null-attribute for any type without an encoding.
3233
SparseTensorEncodingAttr getSparseTensorEncoding(Type type);
34+
35+
//
36+
// Dimension level types.
37+
//
38+
39+
bool isDenseDim(SparseTensorEncodingAttr::DimLevelType dltp);
40+
bool isCompressedDim(SparseTensorEncodingAttr::DimLevelType dltp);
41+
bool isSingletonDim(SparseTensorEncodingAttr::DimLevelType dltp);
42+
43+
/// Convenience method to test for dense dimension (0 <= d < rank).
44+
bool isDenseDim(RankedTensorType type, uint64_t d);
45+
46+
/// Convenience method to test for compressed dimension (0 <= d < rank).
47+
bool isCompressedDim(RankedTensorType type, uint64_t d);
48+
49+
/// Convenience method to test for singleton dimension (0 <= d < rank).
50+
bool isSingletonDim(RankedTensorType type, uint64_t d);
51+
52+
//
53+
// Dimension level properties.
54+
//
55+
56+
bool isOrderedDim(SparseTensorEncodingAttr::DimLevelType dltp);
57+
bool isUniqueDim(SparseTensorEncodingAttr::DimLevelType dltp);
58+
59+
/// Convenience method to test for ordered property in the
60+
/// given dimension (0 <= d < rank).
61+
bool isOrderedDim(RankedTensorType type, uint64_t d);
62+
63+
/// Convenience method to test for unique property in the
64+
/// given dimension (0 <= d < rank).
65+
bool isUniqueDim(RankedTensorType type, uint64_t d);
66+
3367
} // namespace sparse_tensor
3468
} // namespace mlir
3569

mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,13 +216,109 @@ LogicalResult SparseTensorEncodingAttr::verifyEncoding(
216216
return success();
217217
}
218218

219+
//===----------------------------------------------------------------------===//
220+
// Convenience Methods.
221+
//===----------------------------------------------------------------------===//
222+
219223
SparseTensorEncodingAttr
mlir::sparse_tensor::getSparseTensorEncoding(Type type) {
  // Only ranked tensor types can carry a sparse tensor encoding; any
  // other type (or a ranked tensor without one) yields a null attribute.
  if (auto ttp = type.dyn_cast<RankedTensorType>())
    return ttp.getEncoding().dyn_cast_or_null<SparseTensorEncodingAttr>();
  return nullptr;
}
225229

230+
/// Returns true iff the given dim level type denotes a dense dimension.
bool mlir::sparse_tensor::isDenseDim(
    SparseTensorEncodingAttr::DimLevelType dltp) {
  return dltp == SparseTensorEncodingAttr::DimLevelType::Dense;
}
234+
235+
/// Returns true iff the given dim level type denotes a compressed
/// dimension, regardless of its ordered/unique properties.
bool mlir::sparse_tensor::isCompressedDim(
    SparseTensorEncodingAttr::DimLevelType dltp) {
  using DLT = SparseTensorEncodingAttr::DimLevelType;
  return dltp == DLT::Compressed || dltp == DLT::CompressedNu ||
         dltp == DLT::CompressedNo || dltp == DLT::CompressedNuNo;
}
247+
248+
/// Returns true iff the given dim level type denotes a singleton
/// dimension, regardless of its ordered/unique properties.
bool mlir::sparse_tensor::isSingletonDim(
    SparseTensorEncodingAttr::DimLevelType dltp) {
  using DLT = SparseTensorEncodingAttr::DimLevelType;
  return dltp == DLT::Singleton || dltp == DLT::SingletonNu ||
         dltp == DLT::SingletonNo || dltp == DLT::SingletonNuNo;
}
260+
261+
/// Tests whether dimension `d` of `type` is dense (0 <= d < rank).
bool mlir::sparse_tensor::isDenseDim(RankedTensorType type, uint64_t d) {
  assert(d < static_cast<uint64_t>(type.getRank()));
  auto enc = getSparseTensorEncoding(type);
  if (!enc)
    return true; // unannotated tensor is dense
  return isDenseDim(enc.getDimLevelType()[d]);
}
267+
268+
/// Tests whether dimension `d` of `type` is compressed (0 <= d < rank).
bool mlir::sparse_tensor::isCompressedDim(RankedTensorType type, uint64_t d) {
  assert(d < static_cast<uint64_t>(type.getRank()));
  auto enc = getSparseTensorEncoding(type);
  if (!enc)
    return false; // unannotated tensor is dense
  return isCompressedDim(enc.getDimLevelType()[d]);
}
274+
275+
/// Tests whether dimension `d` of `type` is a singleton (0 <= d < rank).
bool mlir::sparse_tensor::isSingletonDim(RankedTensorType type, uint64_t d) {
  assert(d < static_cast<uint64_t>(type.getRank()));
  auto enc = getSparseTensorEncoding(type);
  if (!enc)
    return false; // unannotated tensor is dense
  return isSingletonDim(enc.getDimLevelType()[d]);
}
281+
282+
/// Returns true iff the given dim level type has the ordered property.
bool mlir::sparse_tensor::isOrderedDim(
    SparseTensorEncodingAttr::DimLevelType dltp) {
  using DLT = SparseTensorEncodingAttr::DimLevelType;
  // Only the "No" variants drop the ordered property; everything else
  // (including dense) is ordered.
  return !(dltp == DLT::CompressedNo || dltp == DLT::CompressedNuNo ||
           dltp == DLT::SingletonNo || dltp == DLT::SingletonNuNo);
}
294+
295+
/// Returns true iff the given dim level type has the unique property.
bool mlir::sparse_tensor::isUniqueDim(
    SparseTensorEncodingAttr::DimLevelType dltp) {
  using DLT = SparseTensorEncodingAttr::DimLevelType;
  // Only the "Nu" variants drop the unique property; everything else
  // (including dense) is unique.
  return !(dltp == DLT::CompressedNu || dltp == DLT::CompressedNuNo ||
           dltp == DLT::SingletonNu || dltp == DLT::SingletonNuNo);
}
307+
308+
/// Tests the ordered property of dimension `d` of `type` (0 <= d < rank).
bool mlir::sparse_tensor::isOrderedDim(RankedTensorType type, uint64_t d) {
  assert(d < static_cast<uint64_t>(type.getRank()));
  auto enc = getSparseTensorEncoding(type);
  if (!enc)
    return true; // unannotated tensor is dense (and thus ordered)
  return isOrderedDim(enc.getDimLevelType()[d]);
}
314+
315+
/// Tests the unique property of dimension `d` of `type` (0 <= d < rank).
bool mlir::sparse_tensor::isUniqueDim(RankedTensorType type, uint64_t d) {
  assert(d < static_cast<uint64_t>(type.getRank()));
  auto enc = getSparseTensorEncoding(type);
  if (!enc)
    return true; // unannotated tensor is dense (and thus unique)
  return isUniqueDim(enc.getDimLevelType()[d]);
}
321+
226322
//===----------------------------------------------------------------------===//
227323
// TensorDialect Operations.
228324
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp

Lines changed: 29 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -103,37 +103,28 @@ static Optional<Value> sizeFromTensorAtDim(OpBuilder &rewriter, Location loc,
103103

104104
/// Returns field index of sparse tensor type for pointers/indices, when set.
105105
static unsigned getFieldIndex(Type type, unsigned ptrDim, unsigned idxDim) {
106-
auto enc = getSparseTensorEncoding(type);
107-
assert(enc);
106+
assert(getSparseTensorEncoding(type));
108107
RankedTensorType rType = type.cast<RankedTensorType>();
109108
unsigned field = 2; // start past sizes
110109
unsigned ptr = 0;
111110
unsigned idx = 0;
112111
for (unsigned r = 0, rank = rType.getShape().size(); r < rank; r++) {
113-
switch (enc.getDimLevelType()[r]) {
114-
case SparseTensorEncodingAttr::DimLevelType::Dense:
115-
break; // no fields
116-
case SparseTensorEncodingAttr::DimLevelType::Compressed:
117-
case SparseTensorEncodingAttr::DimLevelType::CompressedNu:
118-
case SparseTensorEncodingAttr::DimLevelType::CompressedNo:
119-
case SparseTensorEncodingAttr::DimLevelType::CompressedNuNo:
112+
if (isCompressedDim(rType, r)) {
120113
if (ptr++ == ptrDim)
121114
return field;
122115
field++;
123116
if (idx++ == idxDim)
124117
return field;
125118
field++;
126-
break;
127-
case SparseTensorEncodingAttr::DimLevelType::Singleton:
128-
case SparseTensorEncodingAttr::DimLevelType::SingletonNu:
129-
case SparseTensorEncodingAttr::DimLevelType::SingletonNo:
130-
case SparseTensorEncodingAttr::DimLevelType::SingletonNuNo:
119+
} else if (isSingletonDim(rType, r)) {
131120
if (idx++ == idxDim)
132121
return field;
133122
field++;
134-
break;
123+
} else {
124+
assert(isDenseDim(rType, r)); // no fields
135125
}
136126
}
127+
assert(ptrDim == -1u && idxDim == -1u);
137128
return field + 1; // return values field index
138129
}
139130

@@ -176,30 +167,21 @@ convertSparseTensorType(Type type, SmallVectorImpl<Type> &fields) {
176167
// The dimSizes array.
177168
fields.push_back(MemRefType::get({rank}, indexType));
178169
// The memSizes array.
179-
unsigned lastField = getFieldIndex(type, -1, -1);
170+
unsigned lastField = getFieldIndex(type, -1u, -1u);
180171
fields.push_back(MemRefType::get({lastField - 2}, indexType));
181172
// Per-dimension storage.
182173
for (unsigned r = 0; r < rank; r++) {
183174
// Dimension level types apply in order to the reordered dimension.
184175
// As a result, the compound type can be constructed directly in the given
185176
// order. Clients of this type know what field is what from the sparse
186177
// tensor type.
187-
switch (enc.getDimLevelType()[r]) {
188-
case SparseTensorEncodingAttr::DimLevelType::Dense:
189-
break; // no fields
190-
case SparseTensorEncodingAttr::DimLevelType::Compressed:
191-
case SparseTensorEncodingAttr::DimLevelType::CompressedNu:
192-
case SparseTensorEncodingAttr::DimLevelType::CompressedNo:
193-
case SparseTensorEncodingAttr::DimLevelType::CompressedNuNo:
178+
if (isCompressedDim(rType, r)) {
194179
fields.push_back(MemRefType::get({ShapedType::kDynamicSize}, ptrType));
195180
fields.push_back(MemRefType::get({ShapedType::kDynamicSize}, idxType));
196-
break;
197-
case SparseTensorEncodingAttr::DimLevelType::Singleton:
198-
case SparseTensorEncodingAttr::DimLevelType::SingletonNu:
199-
case SparseTensorEncodingAttr::DimLevelType::SingletonNo:
200-
case SparseTensorEncodingAttr::DimLevelType::SingletonNuNo:
181+
} else if (isSingletonDim(rType, r)) {
201182
fields.push_back(MemRefType::get({ShapedType::kDynamicSize}, idxType));
202-
break;
183+
} else {
184+
assert(isDenseDim(rType, r)); // no fields
203185
}
204186
}
205187
// The values array.
@@ -254,7 +236,7 @@ static void createAllocFields(OpBuilder &builder, Location loc, Type type,
254236
builder.create<memref::AllocOp>(loc, MemRefType::get({rank}, indexType));
255237
fields.push_back(dimSizes);
256238
// The sizes array.
257-
unsigned lastField = getFieldIndex(type, -1, -1);
239+
unsigned lastField = getFieldIndex(type, -1u, -1u);
258240
Value memSizes = builder.create<memref::AllocOp>(
259241
loc, MemRefType::get({lastField - 2}, indexType));
260242
fields.push_back(memSizes);
@@ -265,25 +247,16 @@ static void createAllocFields(OpBuilder &builder, Location loc, Type type,
265247
builder.create<memref::StoreOp>(loc, sizes[ro], dimSizes,
266248
constantIndex(builder, loc, r));
267249
linear = builder.create<arith::MulIOp>(loc, linear, sizes[ro]);
268-
// Allocate fiels.
269-
switch (enc.getDimLevelType()[r]) {
270-
case SparseTensorEncodingAttr::DimLevelType::Dense:
271-
break; // no fields
272-
case SparseTensorEncodingAttr::DimLevelType::Compressed:
273-
case SparseTensorEncodingAttr::DimLevelType::CompressedNu:
274-
case SparseTensorEncodingAttr::DimLevelType::CompressedNo:
275-
case SparseTensorEncodingAttr::DimLevelType::CompressedNuNo:
250+
// Allocate fields.
251+
if (isCompressedDim(rType, r)) {
276252
fields.push_back(createAllocation(builder, loc, ptrType, heuristic));
277253
fields.push_back(createAllocation(builder, loc, idxType, heuristic));
278254
allDense = false;
279-
break;
280-
case SparseTensorEncodingAttr::DimLevelType::Singleton:
281-
case SparseTensorEncodingAttr::DimLevelType::SingletonNu:
282-
case SparseTensorEncodingAttr::DimLevelType::SingletonNo:
283-
case SparseTensorEncodingAttr::DimLevelType::SingletonNuNo:
255+
} else if (isSingletonDim(rType, r)) {
284256
fields.push_back(createAllocation(builder, loc, idxType, heuristic));
285257
allDense = false;
286-
break;
258+
} else {
259+
assert(isDenseDim(rType, r)); // no fields
287260
}
288261
}
289262
// The values array. For all-dense, the full length is required.
@@ -507,7 +480,8 @@ class SparseExpandConverter : public OpConversionPattern<ExpandOp> {
507480
matchAndRewrite(ExpandOp op, OpAdaptor adaptor,
508481
ConversionPatternRewriter &rewriter) const override {
509482
Location loc = op->getLoc();
510-
ShapedType srcType = op.getTensor().getType().cast<ShapedType>();
483+
RankedTensorType srcType =
484+
op.getTensor().getType().cast<RankedTensorType>();
511485
Type eltType = srcType.getElementType();
512486
Type boolType = rewriter.getIntegerType(1);
513487
Type idxType = rewriter.getIndexType();
@@ -561,17 +535,18 @@ class SparseCompressConverter : public OpConversionPattern<CompressOp> {
561535
matchAndRewrite(CompressOp op, OpAdaptor adaptor,
562536
ConversionPatternRewriter &rewriter) const override {
563537
Location loc = op->getLoc();
564-
ShapedType srcType = op.getTensor().getType().cast<ShapedType>();
565-
Type eltType = srcType.getElementType();
538+
RankedTensorType dstType =
539+
op.getTensor().getType().cast<RankedTensorType>();
540+
Type eltType = dstType.getElementType();
566541
Value values = adaptor.getValues();
567542
Value filled = adaptor.getFilled();
568543
Value added = adaptor.getAdded();
569544
Value count = adaptor.getCount();
570-
571-
//
572-
// TODO: need to implement "std::sort(added, added + count);" for ordered
573-
//
574-
545+
// If the innermost dimension is ordered, we need to sort the indices
546+
// in the "added" array prior to applying the compression.
547+
unsigned rank = dstType.getShape().size();
548+
if (isOrderedDim(dstType, rank - 1))
549+
rewriter.create<SortOp>(loc, count, ValueRange{added}, ValueRange{});
575550
// While performing the insertions, we also need to reset the elements
576551
// of the values/filled-switch by only iterating over the set elements,
577552
// to ensure that the runtime complexity remains proportional to the
@@ -699,7 +674,7 @@ class SparseToPointersConverter
699674
static unsigned getIndexForOp(UnrealizedConversionCastOp /*tuple*/,
700675
ToPointersOp op) {
701676
uint64_t dim = op.getDimension().getZExtValue();
702-
return getFieldIndex(op.getTensor().getType(), /*ptrDim=*/dim, -1);
677+
return getFieldIndex(op.getTensor().getType(), /*ptrDim=*/dim, -1u);
703678
}
704679
};
705680

@@ -712,7 +687,7 @@ class SparseToIndicesConverter
712687
static unsigned getIndexForOp(UnrealizedConversionCastOp /*tuple*/,
713688
ToIndicesOp op) {
714689
uint64_t dim = op.getDimension().getZExtValue();
715-
return getFieldIndex(op.getTensor().getType(), -1, /*idxDim=*/dim);
690+
return getFieldIndex(op.getTensor().getType(), -1u, /*idxDim=*/dim);
716691
}
717692
};
718693

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,8 +156,9 @@ struct SparseTensorCodegenPass
156156
RewritePatternSet patterns(ctx);
157157
SparseTensorTypeToBufferConverter converter;
158158
ConversionTarget target(*ctx);
159-
// Everything in the sparse dialect must go!
159+
// Most ops in the sparse dialect must go!
160160
target.addIllegalDialect<SparseTensorDialect>();
161+
target.addLegalOp<SortOp>();
161162
// All dynamic rules below accept new function, call, return, and various
162163
// tensor and bufferization operations as legal output of the rewriting
163164
// provided that all sparse tensor types have been fully rewritten.

0 commit comments

Comments
 (0)