@@ -31,6 +31,11 @@ using namespace mlir::sparse_tensor;
31
31
32
32
namespace {
33
33
34
+ static constexpr uint64_t DimSizesIdx = 0 ;
35
+ static constexpr uint64_t DimCursorIdx = 1 ;
36
+ static constexpr uint64_t MemSizesIdx = 2 ;
37
+ static constexpr uint64_t FieldsIdx = 3 ;
38
+
34
39
// ===----------------------------------------------------------------------===//
35
40
// Helper methods.
36
41
// ===----------------------------------------------------------------------===//
@@ -90,11 +95,17 @@ static Optional<Value> sizeFromTensorAtDim(OpBuilder &rewriter, Location loc,
90
95
.getResult ();
91
96
}
92
97
98
+ // / Translates field index to memSizes index.
99
+ static unsigned getMemSizesIndex (unsigned field) {
100
+ assert (FieldsIdx <= field);
101
+ return field - FieldsIdx;
102
+ }
103
+
93
104
// / Returns field index of sparse tensor type for pointers/indices, when set.
94
105
static unsigned getFieldIndex (Type type, unsigned ptrDim, unsigned idxDim) {
95
106
assert (getSparseTensorEncoding (type));
96
107
RankedTensorType rType = type.cast <RankedTensorType>();
97
- unsigned field = 2 ; // start past sizes
108
+ unsigned field = FieldsIdx ; // start past header
98
109
unsigned ptr = 0 ;
99
110
unsigned idx = 0 ;
100
111
for (unsigned r = 0 , rank = rType.getShape ().size (); r < rank; r++) {
@@ -140,6 +151,7 @@ convertSparseTensorType(Type type, SmallVectorImpl<Type> &fields) {
140
151
//
141
152
// struct {
142
153
// memref<rank x index> dimSizes ; size in each dimension
154
+ // memref<rank x index> dimCursor ; cursor in each dimension
143
155
// memref<n x index> memSizes ; sizes of ptrs/inds/values
144
156
// ; per-dimension d:
145
157
// ; if dense:
@@ -153,11 +165,11 @@ convertSparseTensorType(Type type, SmallVectorImpl<Type> &fields) {
153
165
// };
154
166
//
155
167
unsigned rank = rType.getShape ().size ();
156
- // The dimSizes array.
157
- fields.push_back (MemRefType::get ({rank}, indexType));
158
- // The memSizes array.
159
168
unsigned lastField = getFieldIndex (type, -1u , -1u );
160
- fields.push_back (MemRefType::get ({lastField - 2 }, indexType));
169
+ // The dimSizes array, dimCursor array, and memSizes array.
170
+ fields.push_back (MemRefType::get ({rank}, indexType));
171
+ fields.push_back (MemRefType::get ({rank}, indexType));
172
+ fields.push_back (MemRefType::get ({getMemSizesIndex (lastField)}, indexType));
161
173
// Per-dimension storage.
162
174
for (unsigned r = 0 ; r < rank; r++) {
163
175
// Dimension level types apply in order to the reordered dimension.
@@ -179,7 +191,7 @@ convertSparseTensorType(Type type, SmallVectorImpl<Type> &fields) {
179
191
return success ();
180
192
}
181
193
182
- // / Create allocation operation.
194
+ // / Creates allocation operation.
183
195
static Value createAllocation (OpBuilder &builder, Location loc, Type type,
184
196
Value sz) {
185
197
auto memType = MemRefType::get ({ShapedType::kDynamicSize }, type);
@@ -220,14 +232,16 @@ static void createAllocFields(OpBuilder &builder, Location loc, Type type,
220
232
else
221
233
sizes.push_back (constantIndex (builder, loc, shape[r]));
222
234
}
223
- // The dimSizes array.
235
+ // The dimSizes array, dimCursor array, and memSizes array.
236
+ unsigned lastField = getFieldIndex (type, -1u , -1u );
224
237
Value dimSizes =
225
238
builder.create <memref::AllocOp>(loc, MemRefType::get ({rank}, indexType));
226
- fields.push_back (dimSizes);
227
- // The sizes array.
228
- unsigned lastField = getFieldIndex (type, -1u , -1u );
239
+ Value dimCursor =
240
+ builder.create <memref::AllocOp>(loc, MemRefType::get ({rank}, indexType));
229
241
Value memSizes = builder.create <memref::AllocOp>(
230
- loc, MemRefType::get ({lastField - 2 }, indexType));
242
+ loc, MemRefType::get ({getMemSizesIndex (lastField)}, indexType));
243
+ fields.push_back (dimSizes);
244
+ fields.push_back (dimCursor);
231
245
fields.push_back (memSizes);
232
246
// Per-dimension storage.
233
247
for (unsigned r = 0 ; r < rank; r++) {
@@ -277,23 +291,17 @@ static scf::ForOp createFor(OpBuilder &builder, Location loc, Value count,
277
291
return forOp;
278
292
}
279
293
280
- // / Translates field index to memSizes index.
281
- static unsigned getMemSizesIndex (unsigned field) {
282
- assert (2 <= field);
283
- return field - 2 ;
284
- }
285
-
286
294
// / Creates a pushback op for given field and updates the fields array
287
295
// / accordingly.
288
296
static void createPushback (OpBuilder &builder, Location loc,
289
297
SmallVectorImpl<Value> &fields, unsigned field,
290
298
Value value) {
291
- assert (2 <= field && field < fields.size ());
299
+ assert (FieldsIdx <= field && field < fields.size ());
292
300
Type etp = fields[field].getType ().cast <ShapedType>().getElementType ();
293
301
if (value.getType () != etp)
294
302
value = builder.create <arith::IndexCastOp>(loc, etp, value);
295
303
fields[field] = builder.create <PushBackOp>(
296
- loc, fields[field].getType (), fields[1 ], fields[field], value,
304
+ loc, fields[field].getType (), fields[MemSizesIdx ], fields[field], value,
297
305
APInt (64 , getMemSizesIndex (field)));
298
306
}
299
307
@@ -312,8 +320,8 @@ static void genInsert(OpBuilder &builder, Location loc, RankedTensorType rtp,
312
320
return ; // TODO: add codegen
313
321
// push_back memSizes indices-0 index
314
322
// push_back memSizes values value
315
- createPushback (builder, loc, fields, 3 , indices[0 ]);
316
- createPushback (builder, loc, fields, 4 , value);
323
+ createPushback (builder, loc, fields, FieldsIdx + 1 , indices[0 ]);
324
+ createPushback (builder, loc, fields, FieldsIdx + 2 , value);
317
325
}
318
326
319
327
// / Generations insertion finalization code.
@@ -329,9 +337,9 @@ static void genEndInsert(OpBuilder &builder, Location loc, RankedTensorType rtp,
329
337
// push_back memSizes pointers-0 memSizes[2]
330
338
Value zero = constantIndex (builder, loc, 0 );
331
339
Value two = constantIndex (builder, loc, 2 );
332
- Value size = builder.create <memref::LoadOp>(loc, fields[1 ], two);
333
- createPushback (builder, loc, fields, 2 , zero);
334
- createPushback (builder, loc, fields, 2 , size);
340
+ Value size = builder.create <memref::LoadOp>(loc, fields[MemSizesIdx ], two);
341
+ createPushback (builder, loc, fields, FieldsIdx , zero);
342
+ createPushback (builder, loc, fields, FieldsIdx , size);
335
343
}
336
344
337
345
// ===----------------------------------------------------------------------===//
@@ -759,7 +767,7 @@ class SparseNumberOfEntriesConverter
759
767
unsigned lastField = fields.size () - 1 ;
760
768
Value field =
761
769
constantIndex (rewriter, op.getLoc (), getMemSizesIndex (lastField));
762
- rewriter.replaceOpWithNewOp <memref::LoadOp>(op, fields[1 ], field);
770
+ rewriter.replaceOpWithNewOp <memref::LoadOp>(op, fields[MemSizesIdx ], field);
763
771
return success ();
764
772
}
765
773
};
0 commit comments