Skip to content

Commit 62fa12a

Browse files
author
Peiming Liu
authored
[mlir][sparse] support querying sparse buffer types from sparse tensor encodings. (llvm#88308)
1 parent 4e6d18f commit 62fa12a

File tree

3 files changed

+75
-19
lines changed

3 files changed

+75
-19
lines changed

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -401,6 +401,26 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
401401
/// the null encoding (since dense-tensors are always all-ordered).
402402
bool isAllOrdered() const;
403403

404+
//
405+
// Storage type methods.
406+
//
407+
408+
/// Returns the coordinate-overhead MLIR type, defaulting to `IndexType`.
409+
Type getCrdElemType() const;
410+
411+
/// Returns the position-overhead MLIR type, defaulting to `IndexType`.
412+
Type getPosElemType() const;
413+
414+
/// Returns the coordinate-memref MLIR type. An optional `tensorDimShape` may
415+
/// be supplied to refine the leading batch dimensions (if any).
416+
MemRefType getCrdMemRefType(
417+
std::optional<ArrayRef<int64_t>> tensorDimShape = std::nullopt) const;
418+
419+
/// Returns the position-memref MLIR type. An optional `tensorDimShape` may
420+
/// be supplied to refine the leading batch dimensions (if any).
421+
MemRefType getPosMemRefType(
422+
std::optional<ArrayRef<int64_t>> tensorDimShape = std::nullopt) const;
423+
404424
//
405425
// dimToLvl methods.
406426
//

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -328,18 +328,10 @@ class SparseTensorType {
328328
unsigned getPosWidth() const { return enc ? enc.getPosWidth() : 0; }
329329

330330
/// Returns the coordinate-overhead MLIR type, defaulting to `IndexType`.
331-
Type getCrdType() const {
332-
if (getCrdWidth())
333-
return IntegerType::get(getContext(), getCrdWidth());
334-
return IndexType::get(getContext());
335-
}
331+
Type getCrdType() const {
  // `SparseTensorEncodingAttr::getCrdElemType()` yields a null `Type` when
  // there is no encoding. Keep the documented `IndexType` default for that
  // case, exactly as the previous width-based implementation did.
  if (!enc)
    return IndexType::get(getContext());
  return enc.getCrdElemType();
}
336332

337333
/// Returns the position-overhead MLIR type, defaulting to `IndexType`.
338-
Type getPosType() const {
339-
if (getPosWidth())
340-
return IntegerType::get(getContext(), getPosWidth());
341-
return IndexType::get(getContext());
342-
}
334+
Type getPosType() const {
  // `SparseTensorEncodingAttr::getPosElemType()` yields a null `Type` when
  // there is no encoding. Keep the documented `IndexType` default for that
  // case, exactly as the previous width-based implementation did.
  if (!enc)
    return IndexType::get(getContext());
  return enc.getPosElemType();
}
343335

344336
/// Returns true iff this sparse tensor type has a trailing
345337
/// COO region starting at the given level. By default, it

mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp

Lines changed: 53 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,26 @@ static constexpr bool acceptBitWidth(unsigned bitWidth) {
6161
}
6262
}
6363

64+
/// Builds the memref shape used for the sparse storage buffers of `enc`:
/// one extent per leading batch level, followed by a single dynamic extent
/// for the sparse level itself.
///
/// When `dimShape` is absent, the batch extents cannot be determined from the
/// encoding alone and are all `ShapedType::kDynamic`. When it is present, the
/// dimension shape is translated to level space and its leading
/// `getBatchLvlRank()` extents are used for the batch portion.
///
/// `enc` must be non-null (asserted).
static SmallVector<Size>
getSparseFieldShape(const SparseTensorEncodingAttr enc,
                    std::optional<ArrayRef<int64_t>> dimShape) {
  assert(enc);
  const auto batchLvlRank = enc.getBatchLvlRank();

  // Start from the fully dynamic batch shape; refine it below if possible.
  SmallVector<int64_t> fieldShape(batchLvlRank, ShapedType::kDynamic);
  if (dimShape) {
    const SmallVector<int64_t> lvlShape =
        enc.translateShape(*dimShape, CrdTransDirectionKind::dim2lvl);
    const auto batchBegin = lvlShape.begin();
    fieldShape.assign(batchBegin, batchBegin + batchLvlRank);
  }

  // Trailing dynamic extent that holds the sparse level data.
  fieldShape.push_back(ShapedType::kDynamic);
  return fieldShape;
}
83+
6484
//===----------------------------------------------------------------------===//
6585
// SparseTensorDialect StorageLayout.
6686
//===----------------------------------------------------------------------===//
@@ -122,21 +142,17 @@ void sparse_tensor::foreachFieldAndTypeInSparseTensor(
122142
LevelType)>
123143
callback) {
124144
assert(stt.hasEncoding());
125-
// Construct the basic types.
126-
const Type crdType = stt.getCrdType();
127-
const Type posType = stt.getPosType();
128-
const Type eltType = stt.getElementType();
129145

130-
SmallVector<int64_t> memrefShape = stt.getBatchLvlShape();
131-
memrefShape.push_back(ShapedType::kDynamic);
146+
SmallVector<int64_t> memrefShape =
147+
getSparseFieldShape(stt.getEncoding(), stt.getDimShape());
132148

133149
const Type specType = StorageSpecifierType::get(stt.getEncoding());
134150
// memref<[batch] x ? x pos> positions
135-
const Type posMemType = MemRefType::get(memrefShape, posType);
151+
const Type posMemType = MemRefType::get(memrefShape, stt.getPosType());
136152
// memref<[batch] x ? x crd> coordinates
137-
const Type crdMemType = MemRefType::get(memrefShape, crdType);
153+
const Type crdMemType = MemRefType::get(memrefShape, stt.getCrdType());
138154
// memref<[batch] x ? x eltType> values
139-
const Type valMemType = MemRefType::get(memrefShape, eltType);
155+
const Type valMemType = MemRefType::get(memrefShape, stt.getElementType());
140156

141157
StorageLayout(stt).foreachField([specType, posMemType, crdMemType, valMemType,
142158
callback](FieldIndex fieldIdx,
@@ -354,6 +370,34 @@ bool SparseTensorEncodingAttr::isAllOrdered() const {
354370
return !getImpl() || llvm::all_of(getLvlTypes(), isOrderedLT);
355371
}
356372

373+
/// Returns the element type used for coordinates: an integer type of
/// `getCrdWidth()` bits when that width is non-zero, `IndexType` otherwise.
/// Returns a null `Type` when this attribute has no underlying
/// implementation.
Type SparseTensorEncodingAttr::getCrdElemType() const {
  if (!getImpl())
    return nullptr;
  if (const auto crdWidth = getCrdWidth())
    return IntegerType::get(getContext(), crdWidth);
  return IndexType::get(getContext());
}
380+
381+
/// Returns the element type used for positions: an integer type of
/// `getPosWidth()` bits when that width is non-zero, `IndexType` otherwise.
/// Returns a null `Type` when this attribute has no underlying
/// implementation.
Type SparseTensorEncodingAttr::getPosElemType() const {
  if (!getImpl())
    return nullptr;
  if (const auto posWidth = getPosWidth())
    return IntegerType::get(getContext(), posWidth);
  return IndexType::get(getContext());
}
388+
389+
/// Returns the memref type of the coordinate buffer: the shape produced by
/// `getSparseFieldShape` for this encoding (optionally refined by `dimShape`)
/// with `getCrdElemType()` as the element type.
MemRefType SparseTensorEncodingAttr::getCrdMemRefType(
    std::optional<ArrayRef<int64_t>> dimShape) const {
  const SmallVector<Size> fieldShape = getSparseFieldShape(*this, dimShape);
  const Type crdElemTy = getCrdElemType();
  return MemRefType::get(fieldShape, crdElemTy);
}
394+
395+
/// Returns the memref type of the position buffer: the shape produced by
/// `getSparseFieldShape` for this encoding (optionally refined by `dimShape`)
/// with `getPosElemType()` as the element type.
MemRefType SparseTensorEncodingAttr::getPosMemRefType(
    std::optional<ArrayRef<int64_t>> dimShape) const {
  const SmallVector<Size> fieldShape = getSparseFieldShape(*this, dimShape);
  const Type posElemTy = getPosElemType();
  return MemRefType::get(fieldShape, posElemTy);
}
400+
357401
bool SparseTensorEncodingAttr::isIdentity() const {
358402
return !getImpl() || !getDimToLvl() || getDimToLvl().isIdentity();
359403
}

0 commit comments

Comments
 (0)