@@ -788,24 +788,29 @@ LogicalResult SparseTensorEncodingAttr::verify(
788
788
return emitError () << " unexpected position bitwidth: " << posWidth;
789
789
if (!acceptBitWidth (crdWidth))
790
790
return emitError () << " unexpected coordinate bitwidth: " << crdWidth;
791
- if (auto it = std::find_if (lvlTypes.begin (), lvlTypes.end (), isSingletonLT);
792
- it != std::end (lvlTypes)) {
791
+
792
+ // Verify every COO segment.
793
+ auto *it = std::find_if (lvlTypes.begin (), lvlTypes.end (), isSingletonLT);
794
+ while (it != lvlTypes.end ()) {
793
795
if (it == lvlTypes.begin () ||
794
- (! isCompressedLT (*( it - 1 )) && ! isLooseCompressedLT (*(it - 1 )) ))
796
+ !( it - 1 )-> isa <LevelFormat::Compressed, LevelFormat::LooseCompressed>( ))
795
797
return emitError () << " expected compressed or loose_compressed level "
796
798
" before singleton level" ;
797
- if (!std::all_of (it, lvlTypes.end (),
799
+
800
+ auto *curCOOEnd = std::find_if_not (it, lvlTypes.end (), isSingletonLT);
801
+ if (!std::all_of (it, curCOOEnd,
798
802
[](LevelType i) { return isSingletonLT (i); }))
799
803
return emitError () << " expected all singleton lvlTypes "
800
804
" following a singleton level" ;
801
805
// We can potentially support mixed SoA/AoS singleton levels.
802
- if (!std::all_of (it, lvlTypes. end () , [it](LevelType i) {
806
+ if (!std::all_of (it, curCOOEnd , [it](LevelType i) {
803
807
return it->isa <LevelPropNonDefault::SoA>() ==
804
808
i.isa <LevelPropNonDefault::SoA>();
805
809
})) {
806
810
return emitError () << " expected all singleton lvlTypes stored in the "
807
811
" same memory layout (SoA vs AoS)." ;
808
812
}
813
+ it = std::find_if (curCOOEnd, lvlTypes.end (), isSingletonLT);
809
814
}
810
815
811
816
auto lastBatch = std::find_if (lvlTypes.rbegin (), lvlTypes.rend (), isBatchLT);
0 commit comments