@@ -61,6 +61,26 @@ static constexpr bool acceptBitWidth(unsigned bitWidth) {
61
61
}
62
62
}
63
63
64
+ static SmallVector<Size>
65
+ getSparseFieldShape (const SparseTensorEncodingAttr enc,
66
+ std::optional<ArrayRef<int64_t >> dimShape) {
67
+ assert (enc);
68
+ // With only encoding, we can not determine the static shape for leading
69
+ // batch levels, we therefore return a dynamic shape memref instead.
70
+ SmallVector<int64_t > memrefShape (enc.getBatchLvlRank (), ShapedType::kDynamic );
71
+ if (dimShape.has_value ()) {
72
+ // If the actual tensor shape is provided, we can then refine the leading
73
+ // batch dimension.
74
+ SmallVector<int64_t > lvlShape =
75
+ enc.translateShape (*dimShape, CrdTransDirectionKind::dim2lvl);
76
+ memrefShape.assign (lvlShape.begin (),
77
+ lvlShape.begin () + enc.getBatchLvlRank ());
78
+ }
79
+ // Another dynamic dimension to store the sparse level.
80
+ memrefShape.push_back (ShapedType::kDynamic );
81
+ return memrefShape;
82
+ }
83
+
64
84
// ===----------------------------------------------------------------------===//
65
85
// SparseTensorDialect StorageLayout.
66
86
// ===----------------------------------------------------------------------===//
@@ -122,21 +142,17 @@ void sparse_tensor::foreachFieldAndTypeInSparseTensor(
122
142
LevelType)>
123
143
callback) {
124
144
assert (stt.hasEncoding ());
125
- // Construct the basic types.
126
- const Type crdType = stt.getCrdType ();
127
- const Type posType = stt.getPosType ();
128
- const Type eltType = stt.getElementType ();
129
145
130
- SmallVector<int64_t > memrefShape = stt. getBatchLvlShape ();
131
- memrefShape. push_back (ShapedType:: kDynamic );
146
+ SmallVector<int64_t > memrefShape =
147
+ getSparseFieldShape (stt. getEncoding (), stt. getDimShape () );
132
148
133
149
const Type specType = StorageSpecifierType::get (stt.getEncoding ());
134
150
// memref<[batch] x ? x pos> positions
135
- const Type posMemType = MemRefType::get (memrefShape, posType );
151
+ const Type posMemType = MemRefType::get (memrefShape, stt. getPosType () );
136
152
// memref<[batch] x ? x crd> coordinates
137
- const Type crdMemType = MemRefType::get (memrefShape, crdType );
153
+ const Type crdMemType = MemRefType::get (memrefShape, stt. getCrdType () );
138
154
// memref<[batch] x ? x eltType> values
139
- const Type valMemType = MemRefType::get (memrefShape, eltType );
155
+ const Type valMemType = MemRefType::get (memrefShape, stt. getElementType () );
140
156
141
157
StorageLayout (stt).foreachField ([specType, posMemType, crdMemType, valMemType,
142
158
callback](FieldIndex fieldIdx,
@@ -354,6 +370,34 @@ bool SparseTensorEncodingAttr::isAllOrdered() const {
354
370
return !getImpl () || llvm::all_of (getLvlTypes (), isOrderedLT);
355
371
}
356
372
373
+ Type SparseTensorEncodingAttr::getCrdElemType () const {
374
+ if (!getImpl ())
375
+ return nullptr ;
376
+ if (getCrdWidth ())
377
+ return IntegerType::get (getContext (), getCrdWidth ());
378
+ return IndexType::get (getContext ());
379
+ }
380
+
381
+ Type SparseTensorEncodingAttr::getPosElemType () const {
382
+ if (!getImpl ())
383
+ return nullptr ;
384
+ if (getPosWidth ())
385
+ return IntegerType::get (getContext (), getPosWidth ());
386
+ return IndexType::get (getContext ());
387
+ }
388
+
389
+ MemRefType SparseTensorEncodingAttr::getCrdMemRefType (
390
+ std::optional<ArrayRef<int64_t >> dimShape) const {
391
+ SmallVector<Size> shape = getSparseFieldShape (*this , dimShape);
392
+ return MemRefType::get (shape, getCrdElemType ());
393
+ }
394
+
395
+ MemRefType SparseTensorEncodingAttr::getPosMemRefType (
396
+ std::optional<ArrayRef<int64_t >> dimShape) const {
397
+ SmallVector<Size> shape = getSparseFieldShape (*this , dimShape);
398
+ return MemRefType::get (shape, getPosElemType ());
399
+ }
400
+
357
401
bool SparseTensorEncodingAttr::isIdentity () const {
358
402
return !getImpl () || !getDimToLvl () || getDimToLvl ().isIdentity ();
359
403
}
0 commit comments