@@ -90,6 +90,115 @@ static Value genIndexAndValueForDense(OpBuilder &builder, Location loc,
90
90
return val;
91
91
}
92
92
93
+ void sparse_tensor::foreachFieldInSparseTensor (
94
+ const SparseTensorEncodingAttr enc,
95
+ llvm::function_ref<bool (unsigned , SparseTensorFieldKind, unsigned ,
96
+ DimLevelType)>
97
+ callback) {
98
+ assert (enc);
99
+
100
+ #define RETURN_ON_FALSE (idx, kind, dim, dlt ) \
101
+ if (!(callback (idx, kind, dim, dlt))) \
102
+ return ;
103
+
104
+ RETURN_ON_FALSE (dimSizesIdx, SparseTensorFieldKind::DimSizes, -1u ,
105
+ DimLevelType::Undef);
106
+ RETURN_ON_FALSE (memSizesIdx, SparseTensorFieldKind::MemSizes, -1u ,
107
+ DimLevelType::Undef);
108
+
109
+ static_assert (dataFieldIdx == memSizesIdx + 1 );
110
+ unsigned fieldIdx = dataFieldIdx;
111
+ // Per-dimension storage.
112
+ for (unsigned r = 0 , rank = enc.getDimLevelType ().size (); r < rank; r++) {
113
+ // Dimension level types apply in order to the reordered dimension.
114
+ // As a result, the compound type can be constructed directly in the given
115
+ // order.
116
+ auto dlt = getDimLevelType (enc, r);
117
+ if (isCompressedDLT (dlt)) {
118
+ RETURN_ON_FALSE (fieldIdx++, SparseTensorFieldKind::PtrMemRef, r, dlt);
119
+ RETURN_ON_FALSE (fieldIdx++, SparseTensorFieldKind::IdxMemRef, r, dlt);
120
+ } else if (isSingletonDLT (dlt)) {
121
+ RETURN_ON_FALSE (fieldIdx++, SparseTensorFieldKind::IdxMemRef, r, dlt);
122
+ } else {
123
+ assert (isDenseDLT (dlt)); // no fields
124
+ }
125
+ }
126
+ // The values array.
127
+ RETURN_ON_FALSE (fieldIdx++, SparseTensorFieldKind::ValMemRef, -1u ,
128
+ DimLevelType::Undef);
129
+
130
+ #undef RETURN_ON_FALSE
131
+ }
132
+
133
+ void sparse_tensor::foreachFieldAndTypeInSparseTensor (
134
+ RankedTensorType rType,
135
+ llvm::function_ref<bool (Type, unsigned , SparseTensorFieldKind, unsigned ,
136
+ DimLevelType)>
137
+ callback) {
138
+ auto enc = getSparseTensorEncoding (rType);
139
+ assert (enc);
140
+ // Construct the basic types.
141
+ Type indexType = IndexType::get (enc.getContext ());
142
+ Type idxType = enc.getIndexType ();
143
+ Type ptrType = enc.getPointerType ();
144
+ Type eltType = rType.getElementType ();
145
+ unsigned rank = rType.getShape ().size ();
146
+ // memref<rank x index> dimSizes
147
+ Type dimSizeType = MemRefType::get ({rank}, indexType);
148
+ // memref<n x index> memSizes
149
+ Type memSizeType =
150
+ MemRefType::get ({getNumDataFieldsFromEncoding (enc)}, indexType);
151
+ // memref<? x ptr> pointers
152
+ Type ptrMemType = MemRefType::get ({ShapedType::kDynamic }, ptrType);
153
+ // memref<? x idx> indices
154
+ Type idxMemType = MemRefType::get ({ShapedType::kDynamic }, idxType);
155
+ // memref<? x eltType> values
156
+ Type valMemType = MemRefType::get ({ShapedType::kDynamic }, eltType);
157
+
158
+ foreachFieldInSparseTensor (
159
+ enc,
160
+ [dimSizeType, memSizeType, ptrMemType, idxMemType, valMemType,
161
+ callback](unsigned fieldIdx, SparseTensorFieldKind fieldKind,
162
+ unsigned dim, DimLevelType dlt) -> bool {
163
+ switch (fieldKind) {
164
+ case SparseTensorFieldKind::DimSizes:
165
+ return callback (dimSizeType, fieldIdx, fieldKind, dim, dlt);
166
+ case SparseTensorFieldKind::MemSizes:
167
+ return callback (memSizeType, fieldIdx, fieldKind, dim, dlt);
168
+ case SparseTensorFieldKind::PtrMemRef:
169
+ return callback (ptrMemType, fieldIdx, fieldKind, dim, dlt);
170
+ case SparseTensorFieldKind::IdxMemRef:
171
+ return callback (idxMemType, fieldIdx, fieldKind, dim, dlt);
172
+ case SparseTensorFieldKind::ValMemRef:
173
+ return callback (valMemType, fieldIdx, fieldKind, dim, dlt);
174
+ };
175
+ });
176
+ }
177
+
178
+ unsigned sparse_tensor::getNumFieldsFromEncoding (SparseTensorEncodingAttr enc) {
179
+ unsigned numFields = 0 ;
180
+ foreachFieldInSparseTensor (enc,
181
+ [&numFields](unsigned , SparseTensorFieldKind,
182
+ unsigned , DimLevelType) -> bool {
183
+ numFields++;
184
+ return true ;
185
+ });
186
+ return numFields;
187
+ }
188
+
189
+ unsigned
190
+ sparse_tensor::getNumDataFieldsFromEncoding (SparseTensorEncodingAttr enc) {
191
+ unsigned numFields = 0 ; // one value memref
192
+ foreachFieldInSparseTensor (enc,
193
+ [&numFields](unsigned fidx, SparseTensorFieldKind,
194
+ unsigned , DimLevelType) -> bool {
195
+ if (fidx >= dataFieldIdx)
196
+ numFields++;
197
+ return true ;
198
+ });
199
+ assert (numFields == getNumFieldsFromEncoding (enc) - dataFieldIdx);
200
+ return numFields;
201
+ }
93
202
// ===----------------------------------------------------------------------===//
94
203
// Sparse tensor loop emitter class implementations
95
204
// ===----------------------------------------------------------------------===//
0 commit comments