Skip to content

Commit 80b08b6

Browse files
committed
[mlir][sparse] add a cursor to sparse storage scheme
This prepares for a subsequent revision that will generalize the insertion code generation. Similar to the support lib, insertions become much easier to perform with some "cursor" bookkeeping. Note that we, in the long run, could perhaps avoid storing the "cursor" permanently and use some restricted-scope solution (alloca?) instead. However, that puts harder restrictions on insertion-chain operations, so for now we follow the more straightforward approach. Reviewed By: Peiming Differential Revision: https://reviews.llvm.org/D136800
1 parent 8469041 commit 80b08b6

File tree

4 files changed

+283
-222
lines changed

4 files changed

+283
-222
lines changed

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp

Lines changed: 33 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,11 @@ using namespace mlir::sparse_tensor;
3131

3232
namespace {
3333

34+
static constexpr uint64_t DimSizesIdx = 0;
35+
static constexpr uint64_t DimCursorIdx = 1;
36+
static constexpr uint64_t MemSizesIdx = 2;
37+
static constexpr uint64_t FieldsIdx = 3;
38+
3439
//===----------------------------------------------------------------------===//
3540
// Helper methods.
3641
//===----------------------------------------------------------------------===//
@@ -90,11 +95,17 @@ static Optional<Value> sizeFromTensorAtDim(OpBuilder &rewriter, Location loc,
9095
.getResult();
9196
}
9297

98+
/// Translates field index to memSizes index.
99+
static unsigned getMemSizesIndex(unsigned field) {
100+
assert(FieldsIdx <= field);
101+
return field - FieldsIdx;
102+
}
103+
93104
/// Returns field index of sparse tensor type for pointers/indices, when set.
94105
static unsigned getFieldIndex(Type type, unsigned ptrDim, unsigned idxDim) {
95106
assert(getSparseTensorEncoding(type));
96107
RankedTensorType rType = type.cast<RankedTensorType>();
97-
unsigned field = 2; // start past sizes
108+
unsigned field = FieldsIdx; // start past header
98109
unsigned ptr = 0;
99110
unsigned idx = 0;
100111
for (unsigned r = 0, rank = rType.getShape().size(); r < rank; r++) {
@@ -140,6 +151,7 @@ convertSparseTensorType(Type type, SmallVectorImpl<Type> &fields) {
140151
//
141152
// struct {
142153
// memref<rank x index> dimSizes ; size in each dimension
154+
// memref<rank x index> dimCursor ; cursor in each dimension
143155
// memref<n x index> memSizes ; sizes of ptrs/inds/values
144156
// ; per-dimension d:
145157
// ; if dense:
@@ -153,11 +165,11 @@ convertSparseTensorType(Type type, SmallVectorImpl<Type> &fields) {
153165
// };
154166
//
155167
unsigned rank = rType.getShape().size();
156-
// The dimSizes array.
157-
fields.push_back(MemRefType::get({rank}, indexType));
158-
// The memSizes array.
159168
unsigned lastField = getFieldIndex(type, -1u, -1u);
160-
fields.push_back(MemRefType::get({lastField - 2}, indexType));
169+
// The dimSizes array, dimCursor array, and memSizes array.
170+
fields.push_back(MemRefType::get({rank}, indexType));
171+
fields.push_back(MemRefType::get({rank}, indexType));
172+
fields.push_back(MemRefType::get({getMemSizesIndex(lastField)}, indexType));
161173
// Per-dimension storage.
162174
for (unsigned r = 0; r < rank; r++) {
163175
// Dimension level types apply in order to the reordered dimension.
@@ -179,7 +191,7 @@ convertSparseTensorType(Type type, SmallVectorImpl<Type> &fields) {
179191
return success();
180192
}
181193

182-
/// Create allocation operation.
194+
/// Creates allocation operation.
183195
static Value createAllocation(OpBuilder &builder, Location loc, Type type,
184196
Value sz) {
185197
auto memType = MemRefType::get({ShapedType::kDynamicSize}, type);
@@ -220,14 +232,16 @@ static void createAllocFields(OpBuilder &builder, Location loc, Type type,
220232
else
221233
sizes.push_back(constantIndex(builder, loc, shape[r]));
222234
}
223-
// The dimSizes array.
235+
// The dimSizes array, dimCursor array, and memSizes array.
236+
unsigned lastField = getFieldIndex(type, -1u, -1u);
224237
Value dimSizes =
225238
builder.create<memref::AllocOp>(loc, MemRefType::get({rank}, indexType));
226-
fields.push_back(dimSizes);
227-
// The sizes array.
228-
unsigned lastField = getFieldIndex(type, -1u, -1u);
239+
Value dimCursor =
240+
builder.create<memref::AllocOp>(loc, MemRefType::get({rank}, indexType));
229241
Value memSizes = builder.create<memref::AllocOp>(
230-
loc, MemRefType::get({lastField - 2}, indexType));
242+
loc, MemRefType::get({getMemSizesIndex(lastField)}, indexType));
243+
fields.push_back(dimSizes);
244+
fields.push_back(dimCursor);
231245
fields.push_back(memSizes);
232246
// Per-dimension storage.
233247
for (unsigned r = 0; r < rank; r++) {
@@ -277,23 +291,17 @@ static scf::ForOp createFor(OpBuilder &builder, Location loc, Value count,
277291
return forOp;
278292
}
279293

280-
/// Translates field index to memSizes index.
281-
static unsigned getMemSizesIndex(unsigned field) {
282-
assert(2 <= field);
283-
return field - 2;
284-
}
285-
286294
/// Creates a pushback op for given field and updates the fields array
287295
/// accordingly.
288296
static void createPushback(OpBuilder &builder, Location loc,
289297
SmallVectorImpl<Value> &fields, unsigned field,
290298
Value value) {
291-
assert(2 <= field && field < fields.size());
299+
assert(FieldsIdx <= field && field < fields.size());
292300
Type etp = fields[field].getType().cast<ShapedType>().getElementType();
293301
if (value.getType() != etp)
294302
value = builder.create<arith::IndexCastOp>(loc, etp, value);
295303
fields[field] = builder.create<PushBackOp>(
296-
loc, fields[field].getType(), fields[1], fields[field], value,
304+
loc, fields[field].getType(), fields[MemSizesIdx], fields[field], value,
297305
APInt(64, getMemSizesIndex(field)));
298306
}
299307

@@ -312,8 +320,8 @@ static void genInsert(OpBuilder &builder, Location loc, RankedTensorType rtp,
312320
return; // TODO: add codegen
313321
// push_back memSizes indices-0 index
314322
// push_back memSizes values value
315-
createPushback(builder, loc, fields, 3, indices[0]);
316-
createPushback(builder, loc, fields, 4, value);
323+
createPushback(builder, loc, fields, FieldsIdx + 1, indices[0]);
324+
createPushback(builder, loc, fields, FieldsIdx + 2, value);
317325
}
318326

319327
/// Generates insertion finalization code.
@@ -329,9 +337,9 @@ static void genEndInsert(OpBuilder &builder, Location loc, RankedTensorType rtp,
329337
// push_back memSizes pointers-0 memSizes[2]
330338
Value zero = constantIndex(builder, loc, 0);
331339
Value two = constantIndex(builder, loc, 2);
332-
Value size = builder.create<memref::LoadOp>(loc, fields[1], two);
333-
createPushback(builder, loc, fields, 2, zero);
334-
createPushback(builder, loc, fields, 2, size);
340+
Value size = builder.create<memref::LoadOp>(loc, fields[MemSizesIdx], two);
341+
createPushback(builder, loc, fields, FieldsIdx, zero);
342+
createPushback(builder, loc, fields, FieldsIdx, size);
335343
}
336344

337345
//===----------------------------------------------------------------------===//
@@ -759,7 +767,7 @@ class SparseNumberOfEntriesConverter
759767
unsigned lastField = fields.size() - 1;
760768
Value field =
761769
constantIndex(rewriter, op.getLoc(), getMemSizesIndex(lastField));
762-
rewriter.replaceOpWithNewOp<memref::LoadOp>(op, fields[1], field);
770+
rewriter.replaceOpWithNewOp<memref::LoadOp>(op, fields[MemSizesIdx], field);
763771
return success();
764772
}
765773
};

0 commit comments

Comments
 (0)