Commit 8a7e69d

Author: Peiming Liu (committed)
[mlir][sparse] Refactoring: abstract sparse tensor memory scheme into a SparseTensorDescriptor class.
This patch abstracts the sparse tensor memory scheme into a SparseTensorDescriptor class. Previously, field accesses were performed in a relatively error-prone way; this patch hides the hairy details behind a SparseTensorDescriptor class so that users can access sparse tensor fields in a more cohesive way.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D138627
1 parent: 0d03ba6
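To make the intent concrete, here is a rough sketch of the access pattern the descriptor enables. It is illustrative only; the accessor name used on the descriptor (getPtrMemRef) is hypothetical and not taken from this diff.

// Hypothetical sketch: fields are retrieved by meaning instead of by a
// hand-computed flat index into the storage-field list.
static Value getPointersOfDim(SparseTensorDescriptor &desc, unsigned dim) {
  // The descriptor encapsulates how per-dimension level types map onto the
  // flattened field layout, so callers no longer do the index arithmetic.
  return desc.getPtrMemRef(dim); // accessor name is illustrative
}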

4 files changed: +564 additions, -288 deletions

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h

Lines changed: 15 additions & 0 deletions
@@ -75,6 +75,21 @@ inline bool isSingletonDim(RankedTensorType type, uint64_t d) {
   return isSingletonDLT(getDimLevelType(type, d));
 }
 
+/// Convenience function to test for dense dimension (0 <= d < rank).
+inline bool isDenseDim(SparseTensorEncodingAttr enc, uint64_t d) {
+  return isDenseDLT(getDimLevelType(enc, d));
+}
+
+/// Convenience function to test for compressed dimension (0 <= d < rank).
+inline bool isCompressedDim(SparseTensorEncodingAttr enc, uint64_t d) {
+  return isCompressedDLT(getDimLevelType(enc, d));
+}
+
+/// Convenience function to test for singleton dimension (0 <= d < rank).
+inline bool isSingletonDim(SparseTensorEncodingAttr enc, uint64_t d) {
+  return isSingletonDLT(getDimLevelType(enc, d));
+}
+
 //
 // Dimension level properties.
 //
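As a usage illustration (not part of this commit), the new encoding-based overloads let callers dispatch on per-dimension level types without going through a RankedTensorType. The helper name countStorageMemRefs below is hypothetical; the counting logic mirrors the field layout enumerated by foreachFieldInSparseTensor in the next file.

// Illustrative only: count the storage memrefs an encoding needs.
inline unsigned countStorageMemRefs(SparseTensorEncodingAttr enc) {
  unsigned count = 1; // the values memref is always present
  for (uint64_t d = 0, rank = enc.getDimLevelType().size(); d < rank; d++) {
    if (isCompressedDim(enc, d))
      count += 2; // pointers + indices memrefs
    else if (isSingletonDim(enc, d))
      count += 1; // indices memref only
    // dense dimensions contribute no per-dimension memrefs
  }
  return count;
}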

mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp

Lines changed: 109 additions & 0 deletions
@@ -90,6 +90,115 @@ static Value genIndexAndValueForDense(OpBuilder &builder, Location loc,
   return val;
 }
 
+void sparse_tensor::foreachFieldInSparseTensor(
+    const SparseTensorEncodingAttr enc,
+    llvm::function_ref<bool(unsigned, SparseTensorFieldKind, unsigned,
+                            DimLevelType)>
+        callback) {
+  assert(enc);
+
+#define RETURN_ON_FALSE(idx, kind, dim, dlt)                                   \
+  if (!(callback(idx, kind, dim, dlt)))                                        \
+    return;
+
+  RETURN_ON_FALSE(dimSizesIdx, SparseTensorFieldKind::DimSizes, -1u,
+                  DimLevelType::Undef);
+  RETURN_ON_FALSE(memSizesIdx, SparseTensorFieldKind::MemSizes, -1u,
+                  DimLevelType::Undef);
+
+  static_assert(dataFieldIdx == memSizesIdx + 1);
+  unsigned fieldIdx = dataFieldIdx;
+  // Per-dimension storage.
+  for (unsigned r = 0, rank = enc.getDimLevelType().size(); r < rank; r++) {
+    // Dimension level types apply in order to the reordered dimension.
+    // As a result, the compound type can be constructed directly in the given
+    // order.
+    auto dlt = getDimLevelType(enc, r);
+    if (isCompressedDLT(dlt)) {
+      RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::PtrMemRef, r, dlt);
+      RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::IdxMemRef, r, dlt);
+    } else if (isSingletonDLT(dlt)) {
+      RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::IdxMemRef, r, dlt);
+    } else {
+      assert(isDenseDLT(dlt)); // no fields
+    }
+  }
+  // The values array.
+  RETURN_ON_FALSE(fieldIdx++, SparseTensorFieldKind::ValMemRef, -1u,
+                  DimLevelType::Undef);
+
+#undef RETURN_ON_FALSE
+}
+
+void sparse_tensor::foreachFieldAndTypeInSparseTensor(
+    RankedTensorType rType,
+    llvm::function_ref<bool(Type, unsigned, SparseTensorFieldKind, unsigned,
+                            DimLevelType)>
+        callback) {
+  auto enc = getSparseTensorEncoding(rType);
+  assert(enc);
+  // Construct the basic types.
+  Type indexType = IndexType::get(enc.getContext());
+  Type idxType = enc.getIndexType();
+  Type ptrType = enc.getPointerType();
+  Type eltType = rType.getElementType();
+  unsigned rank = rType.getShape().size();
+  // memref<rank x index> dimSizes
+  Type dimSizeType = MemRefType::get({rank}, indexType);
+  // memref<n x index> memSizes
+  Type memSizeType =
+      MemRefType::get({getNumDataFieldsFromEncoding(enc)}, indexType);
+  // memref<? x ptr> pointers
+  Type ptrMemType = MemRefType::get({ShapedType::kDynamic}, ptrType);
+  // memref<? x idx> indices
+  Type idxMemType = MemRefType::get({ShapedType::kDynamic}, idxType);
+  // memref<? x eltType> values
+  Type valMemType = MemRefType::get({ShapedType::kDynamic}, eltType);
+
+  foreachFieldInSparseTensor(
+      enc,
+      [dimSizeType, memSizeType, ptrMemType, idxMemType, valMemType,
+       callback](unsigned fieldIdx, SparseTensorFieldKind fieldKind,
+                 unsigned dim, DimLevelType dlt) -> bool {
+        switch (fieldKind) {
+        case SparseTensorFieldKind::DimSizes:
+          return callback(dimSizeType, fieldIdx, fieldKind, dim, dlt);
+        case SparseTensorFieldKind::MemSizes:
+          return callback(memSizeType, fieldIdx, fieldKind, dim, dlt);
+        case SparseTensorFieldKind::PtrMemRef:
+          return callback(ptrMemType, fieldIdx, fieldKind, dim, dlt);
+        case SparseTensorFieldKind::IdxMemRef:
+          return callback(idxMemType, fieldIdx, fieldKind, dim, dlt);
+        case SparseTensorFieldKind::ValMemRef:
+          return callback(valMemType, fieldIdx, fieldKind, dim, dlt);
+        };
+      });
+}
+
+unsigned sparse_tensor::getNumFieldsFromEncoding(SparseTensorEncodingAttr enc) {
+  unsigned numFields = 0;
+  foreachFieldInSparseTensor(enc,
+                             [&numFields](unsigned, SparseTensorFieldKind,
+                                          unsigned, DimLevelType) -> bool {
+                               numFields++;
+                               return true;
+                             });
+  return numFields;
+}
+
+unsigned
+sparse_tensor::getNumDataFieldsFromEncoding(SparseTensorEncodingAttr enc) {
+  unsigned numFields = 0; // one value memref
+  foreachFieldInSparseTensor(enc,
+                             [&numFields](unsigned fidx, SparseTensorFieldKind,
+                                          unsigned, DimLevelType) -> bool {
+                               if (fidx >= dataFieldIdx)
+                                 numFields++;
+                               return true;
+                             });
+  assert(numFields == getNumFieldsFromEncoding(enc) - dataFieldIdx);
+  return numFields;
+}
 //===----------------------------------------------------------------------===//
 // Sparse tensor loop emitter class implementations
 //===----------------------------------------------------------------------===//
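As a usage sketch (illustrative, not taken from this patch), a caller such as a type converter that flattens a sparse tensor into its storage fields could collect the field types through the new iterator. The helper name collectSparseTensorFieldTypes is hypothetical.

// Illustrative only: gather the memref types backing a sparse tensor type.
static void collectSparseTensorFieldTypes(RankedTensorType rtp,
                                          llvm::SmallVectorImpl<Type> &fields) {
  sparse_tensor::foreachFieldAndTypeInSparseTensor(
      rtp, [&fields](Type fieldType, unsigned /*fieldIdx*/,
                     SparseTensorFieldKind /*fieldKind*/, unsigned /*dim*/,
                     DimLevelType /*dlt*/) -> bool {
        fields.push_back(fieldType);
        return true; // keep visiting all fields
      });
}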
