@@ -103,37 +103,28 @@ static Optional<Value> sizeFromTensorAtDim(OpBuilder &rewriter, Location loc,
103
103
104
104
// / Returns field index of sparse tensor type for pointers/indices, when set.
105
105
static unsigned getFieldIndex (Type type, unsigned ptrDim, unsigned idxDim) {
106
- auto enc = getSparseTensorEncoding (type);
107
- assert (enc);
106
+ assert (getSparseTensorEncoding (type));
108
107
RankedTensorType rType = type.cast <RankedTensorType>();
109
108
unsigned field = 2 ; // start past sizes
110
109
unsigned ptr = 0 ;
111
110
unsigned idx = 0 ;
112
111
for (unsigned r = 0 , rank = rType.getShape ().size (); r < rank; r++) {
113
- switch (enc.getDimLevelType ()[r]) {
114
- case SparseTensorEncodingAttr::DimLevelType::Dense:
115
- break ; // no fields
116
- case SparseTensorEncodingAttr::DimLevelType::Compressed:
117
- case SparseTensorEncodingAttr::DimLevelType::CompressedNu:
118
- case SparseTensorEncodingAttr::DimLevelType::CompressedNo:
119
- case SparseTensorEncodingAttr::DimLevelType::CompressedNuNo:
112
+ if (isCompressedDim (rType, r)) {
120
113
if (ptr++ == ptrDim)
121
114
return field;
122
115
field++;
123
116
if (idx++ == idxDim)
124
117
return field;
125
118
field++;
126
- break ;
127
- case SparseTensorEncodingAttr::DimLevelType::Singleton:
128
- case SparseTensorEncodingAttr::DimLevelType::SingletonNu:
129
- case SparseTensorEncodingAttr::DimLevelType::SingletonNo:
130
- case SparseTensorEncodingAttr::DimLevelType::SingletonNuNo:
119
+ } else if (isSingletonDim (rType, r)) {
131
120
if (idx++ == idxDim)
132
121
return field;
133
122
field++;
134
- break ;
123
+ } else {
124
+ assert (isDenseDim (rType, r)); // no fields
135
125
}
136
126
}
127
+ assert (ptrDim == -1u && idxDim == -1u );
137
128
return field + 1 ; // return values field index
138
129
}
139
130
@@ -176,30 +167,21 @@ convertSparseTensorType(Type type, SmallVectorImpl<Type> &fields) {
176
167
// The dimSizes array.
177
168
fields.push_back (MemRefType::get ({rank}, indexType));
178
169
// The memSizes array.
179
- unsigned lastField = getFieldIndex (type, -1 , -1 );
170
+ unsigned lastField = getFieldIndex (type, -1u , -1u );
180
171
fields.push_back (MemRefType::get ({lastField - 2 }, indexType));
181
172
// Per-dimension storage.
182
173
for (unsigned r = 0 ; r < rank; r++) {
183
174
// Dimension level types apply in order to the reordered dimension.
184
175
// As a result, the compound type can be constructed directly in the given
185
176
// order. Clients of this type know what field is what from the sparse
186
177
// tensor type.
187
- switch (enc.getDimLevelType ()[r]) {
188
- case SparseTensorEncodingAttr::DimLevelType::Dense:
189
- break ; // no fields
190
- case SparseTensorEncodingAttr::DimLevelType::Compressed:
191
- case SparseTensorEncodingAttr::DimLevelType::CompressedNu:
192
- case SparseTensorEncodingAttr::DimLevelType::CompressedNo:
193
- case SparseTensorEncodingAttr::DimLevelType::CompressedNuNo:
178
+ if (isCompressedDim (rType, r)) {
194
179
fields.push_back (MemRefType::get ({ShapedType::kDynamicSize }, ptrType));
195
180
fields.push_back (MemRefType::get ({ShapedType::kDynamicSize }, idxType));
196
- break ;
197
- case SparseTensorEncodingAttr::DimLevelType::Singleton:
198
- case SparseTensorEncodingAttr::DimLevelType::SingletonNu:
199
- case SparseTensorEncodingAttr::DimLevelType::SingletonNo:
200
- case SparseTensorEncodingAttr::DimLevelType::SingletonNuNo:
181
+ } else if (isSingletonDim (rType, r)) {
201
182
fields.push_back (MemRefType::get ({ShapedType::kDynamicSize }, idxType));
202
- break ;
183
+ } else {
184
+ assert (isDenseDim (rType, r)); // no fields
203
185
}
204
186
}
205
187
// The values array.
@@ -254,7 +236,7 @@ static void createAllocFields(OpBuilder &builder, Location loc, Type type,
254
236
builder.create <memref::AllocOp>(loc, MemRefType::get ({rank}, indexType));
255
237
fields.push_back (dimSizes);
256
238
// The sizes array.
257
- unsigned lastField = getFieldIndex (type, -1 , -1 );
239
+ unsigned lastField = getFieldIndex (type, -1u , -1u );
258
240
Value memSizes = builder.create <memref::AllocOp>(
259
241
loc, MemRefType::get ({lastField - 2 }, indexType));
260
242
fields.push_back (memSizes);
@@ -265,25 +247,16 @@ static void createAllocFields(OpBuilder &builder, Location loc, Type type,
265
247
builder.create <memref::StoreOp>(loc, sizes[ro], dimSizes,
266
248
constantIndex (builder, loc, r));
267
249
linear = builder.create <arith::MulIOp>(loc, linear, sizes[ro]);
268
- // Allocate fiels.
269
- switch (enc.getDimLevelType ()[r]) {
270
- case SparseTensorEncodingAttr::DimLevelType::Dense:
271
- break ; // no fields
272
- case SparseTensorEncodingAttr::DimLevelType::Compressed:
273
- case SparseTensorEncodingAttr::DimLevelType::CompressedNu:
274
- case SparseTensorEncodingAttr::DimLevelType::CompressedNo:
275
- case SparseTensorEncodingAttr::DimLevelType::CompressedNuNo:
250
+ // Allocate fields.
251
+ if (isCompressedDim (rType, r)) {
276
252
fields.push_back (createAllocation (builder, loc, ptrType, heuristic));
277
253
fields.push_back (createAllocation (builder, loc, idxType, heuristic));
278
254
allDense = false ;
279
- break ;
280
- case SparseTensorEncodingAttr::DimLevelType::Singleton:
281
- case SparseTensorEncodingAttr::DimLevelType::SingletonNu:
282
- case SparseTensorEncodingAttr::DimLevelType::SingletonNo:
283
- case SparseTensorEncodingAttr::DimLevelType::SingletonNuNo:
255
+ } else if (isSingletonDim (rType, r)) {
284
256
fields.push_back (createAllocation (builder, loc, idxType, heuristic));
285
257
allDense = false ;
286
- break ;
258
+ } else {
259
+ assert (isDenseDim (rType, r)); // no fields
287
260
}
288
261
}
289
262
// The values array. For all-dense, the full length is required.
@@ -507,7 +480,8 @@ class SparseExpandConverter : public OpConversionPattern<ExpandOp> {
507
480
matchAndRewrite (ExpandOp op, OpAdaptor adaptor,
508
481
ConversionPatternRewriter &rewriter) const override {
509
482
Location loc = op->getLoc ();
510
- ShapedType srcType = op.getTensor ().getType ().cast <ShapedType>();
483
+ RankedTensorType srcType =
484
+ op.getTensor ().getType ().cast <RankedTensorType>();
511
485
Type eltType = srcType.getElementType ();
512
486
Type boolType = rewriter.getIntegerType (1 );
513
487
Type idxType = rewriter.getIndexType ();
@@ -561,17 +535,18 @@ class SparseCompressConverter : public OpConversionPattern<CompressOp> {
561
535
matchAndRewrite (CompressOp op, OpAdaptor adaptor,
562
536
ConversionPatternRewriter &rewriter) const override {
563
537
Location loc = op->getLoc ();
564
- ShapedType srcType = op.getTensor ().getType ().cast <ShapedType>();
565
- Type eltType = srcType.getElementType ();
538
+ RankedTensorType dstType =
539
+ op.getTensor ().getType ().cast <RankedTensorType>();
540
+ Type eltType = dstType.getElementType ();
566
541
Value values = adaptor.getValues ();
567
542
Value filled = adaptor.getFilled ();
568
543
Value added = adaptor.getAdded ();
569
544
Value count = adaptor.getCount ();
570
-
571
- //
572
- // TODO: need to implement "std::sort(added, added + count);" for ordered
573
- //
574
-
545
+ // If the innermost dimension is ordered, we need to sort the indices
546
+ // in the "added" array prior to applying the compression.
547
+ unsigned rank = dstType. getShape (). size ();
548
+ if ( isOrderedDim (dstType, rank - 1 ))
549
+ rewriter. create <SortOp>(loc, count, ValueRange{added}, ValueRange{});
575
550
// While performing the insertions, we also need to reset the elements
576
551
// of the values/filled-switch by only iterating over the set elements,
577
552
// to ensure that the runtime complexity remains proportional to the
@@ -699,7 +674,7 @@ class SparseToPointersConverter
699
674
static unsigned getIndexForOp (UnrealizedConversionCastOp /* tuple*/ ,
700
675
ToPointersOp op) {
701
676
uint64_t dim = op.getDimension ().getZExtValue ();
702
- return getFieldIndex (op.getTensor ().getType (), /* ptrDim=*/ dim, -1 );
677
+ return getFieldIndex (op.getTensor ().getType (), /* ptrDim=*/ dim, -1u );
703
678
}
704
679
};
705
680
@@ -712,7 +687,7 @@ class SparseToIndicesConverter
712
687
static unsigned getIndexForOp (UnrealizedConversionCastOp /* tuple*/ ,
713
688
ToIndicesOp op) {
714
689
uint64_t dim = op.getDimension ().getZExtValue ();
715
- return getFieldIndex (op.getTensor ().getType (), -1 , /* idxDim=*/ dim);
690
+ return getFieldIndex (op.getTensor ().getType (), -1u , /* idxDim=*/ dim);
716
691
}
717
692
};
718
693
0 commit comments