Skip to content

Commit 13af97a

Browse files
author
Peiming Liu
authored
[mlir][sparse] allow multiple COO segments in sparse encodings. (#91786)
**NOTE**: we still have implementation holes when handling multiple COO segments in the encoding. But the format should be considered to be legal.
1 parent 77a59c3 commit 13af97a

File tree

2 files changed

+21
-5
lines changed

2 files changed

+21
-5
lines changed

mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -788,24 +788,29 @@ LogicalResult SparseTensorEncodingAttr::verify(
788788
return emitError() << "unexpected position bitwidth: " << posWidth;
789789
if (!acceptBitWidth(crdWidth))
790790
return emitError() << "unexpected coordinate bitwidth: " << crdWidth;
791-
if (auto it = std::find_if(lvlTypes.begin(), lvlTypes.end(), isSingletonLT);
792-
it != std::end(lvlTypes)) {
791+
792+
// Verify every COO segment.
793+
auto *it = std::find_if(lvlTypes.begin(), lvlTypes.end(), isSingletonLT);
794+
while (it != lvlTypes.end()) {
793795
if (it == lvlTypes.begin() ||
794-
(!isCompressedLT(*(it - 1)) && !isLooseCompressedLT(*(it - 1))))
796+
!(it - 1)->isa<LevelFormat::Compressed, LevelFormat::LooseCompressed>())
795797
return emitError() << "expected compressed or loose_compressed level "
796798
"before singleton level";
797-
if (!std::all_of(it, lvlTypes.end(),
799+
800+
auto *curCOOEnd = std::find_if_not(it, lvlTypes.end(), isSingletonLT);
801+
if (!std::all_of(it, curCOOEnd,
798802
[](LevelType i) { return isSingletonLT(i); }))
799803
return emitError() << "expected all singleton lvlTypes "
800804
"following a singleton level";
801805
// We can potentially support mixed SoA/AoS singleton levels.
802-
if (!std::all_of(it, lvlTypes.end(), [it](LevelType i) {
806+
if (!std::all_of(it, curCOOEnd, [it](LevelType i) {
803807
return it->isa<LevelPropNonDefault::SoA>() ==
804808
i.isa<LevelPropNonDefault::SoA>();
805809
})) {
806810
return emitError() << "expected all singleton lvlTypes stored in the "
807811
"same memory layout (SoA vs AoS).";
808812
}
813+
it = std::find_if(curCOOEnd, lvlTypes.end(), isSingletonLT);
809814
}
810815

811816
auto lastBatch = std::find_if(lvlTypes.rbegin(), lvlTypes.rend(), isBatchLT);

mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,17 @@ func.func private @sparse_coo(tensor<?x?xf32, #COO>)
156156

157157
// -----
158158

159+
#COO_DENSE = #sparse_tensor.encoding<{
160+
map = (d0, d1, d2) -> (d0 : compressed(nonunique), d1 : singleton, d2: dense)
161+
}>
162+
163+
// CHECK-DAG: #[[$COO:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : compressed(nonunique), d1 : singleton, d2 : dense) }>
164+
// CHECK-LABEL: func private @sparse_coo_trailing_dense(
165+
// CHECK-SAME: tensor<?x?x1xf32, #[[$COO]]>)
166+
func.func private @sparse_coo_trailing_dense(tensor<?x?x1xf32, #COO_DENSE>)
167+
168+
// -----
169+
159170
#BCOO = #sparse_tensor.encoding<{
160171
map = (d0, d1, d2) -> (d0 : dense, d1 : loose_compressed(nonunique), d2 : singleton)
161172
}>

0 commit comments

Comments
 (0)