@@ -455,6 +455,7 @@ Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
   SmallVector<DimLevelType> lvlTypes;
   SmallVector<SparseTensorDimSliceAttr> dimSlices;
   AffineMap dimToLvl = {};
+  AffineMap lvlToDim = {};
   unsigned posWidth = 0;
   unsigned crdWidth = 0;
   StringRef attrName;
@@ -568,6 +569,7 @@ Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
       ERROR_IF(dimToLvl, "Cannot mix `dimToLvl` with `map`")
       dimToLvl = dlm.getDimToLvlMap(parser.getContext());
+      lvlToDim = dlm.getLvlToDimMap(parser.getContext());
       break;
     }
     } // switch
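Note: with the consolidated `map` syntax, the parsed dim-level map (`dlm`) can also carry an explicit level-to-dimension map, so the parser records it here; the fallback in the next hunk infers one only when nothing usable was supplied.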
@@ -582,8 +584,9 @@ Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
#undef RETURN_ON_FAIL

  // Construct struct-like storage for attribute.
-  // TODO: Fetch lvlToDim if user provides one
-  AffineMap lvlToDim = inferLvlToDim(dimToLvl, parser.getContext());
+  if (!lvlToDim || lvlToDim.isEmpty()) {
+    lvlToDim = inferLvlToDim(dimToLvl, parser.getContext());
+  }
  return parser.getChecked<SparseTensorEncodingAttr>(
      parser.getContext(), lvlTypes, dimToLvl, lvlToDim, posWidth, crdWidth,
      dimSlices);
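If no `lvlToDim` was supplied, it is inferred from `dimToLvl`. For the common permutation case the inference is simply the inverse permutation; below is a minimal standalone sketch of that idea in plain C++ over index vectors (the function name `inverseOfPermutation` and the vector representation are illustrative assumptions, not the MLIR `AffineMap` API):

```cpp
#include <cstddef>
#include <vector>

// Sketch: invert a permutation given as a vector of level positions,
// where dimToLvl[d] == l means dimension d is stored at level l.
std::vector<std::size_t>
inverseOfPermutation(const std::vector<std::size_t> &dimToLvl) {
  std::vector<std::size_t> lvlToDim(dimToLvl.size());
  for (std::size_t d = 0; d < dimToLvl.size(); ++d)
    lvlToDim[dimToLvl[d]] = d; // level dimToLvl[d] maps back to dimension d
  return lvlToDim;
}
// E.g. dimToLvl = {2, 0, 1} (d0->l2, d1->l0, d2->l1) inverts to {1, 2, 0}.
```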
@@ -663,6 +666,17 @@ SparseTensorEncodingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
    return emitError() << "unexpected position bitwidth: " << posWidth;
  if (!acceptBitWidth(crdWidth))
    return emitError() << "unexpected coordinate bitwidth: " << crdWidth;
+  if (auto it = std::find_if(lvlTypes.begin(), lvlTypes.end(), isSingletonDLT);
+      it != std::end(lvlTypes)) {
+    if (it == lvlTypes.begin() ||
+        (!isCompressedDLT(*(it - 1)) && !isLooseCompressedDLT(*(it - 1))))
+      return emitError() << "expected compressed or loose_compressed level "
+                            "before singleton level";
+    if (!std::all_of(it, lvlTypes.end(),
+                     [](DimLevelType i) { return isSingletonDLT(i); }))
+      return emitError() << "expected all singleton lvlTypes "
+                            "following a singleton level";
+  }
  // Before we can check that the level-rank is consistent/coherent
  // across all fields, we need to define it. The source-of-truth for
  // the `getLvlRank` method is the length of the level-types array,
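The new check enforces the ordering rule behind COO-style formats: the first singleton level must directly follow a compressed or loose_compressed level, and once a singleton level appears, every remaining level must also be singleton. A standalone sketch of the same rule, using strings in place of `DimLevelType` (the helper name `validSingletonOrdering` is illustrative, not part of the dialect):

```cpp
#include <algorithm>
#include <string>
#include <vector>

// Sketch of the verifier's singleton-ordering rule.
bool validSingletonOrdering(const std::vector<std::string> &lvlTypes) {
  auto isSingleton = [](const std::string &t) { return t == "singleton"; };
  auto it = std::find_if(lvlTypes.begin(), lvlTypes.end(), isSingleton);
  if (it == lvlTypes.end())
    return true; // no singleton levels: nothing to check
  if (it == lvlTypes.begin() ||
      (*(it - 1) != "compressed" && *(it - 1) != "loose_compressed"))
    return false; // singleton must follow compressed or loose_compressed
  return std::all_of(it, lvlTypes.end(), isSingleton);
}
// {"compressed", "singleton"} passes; {"dense", "singleton"} and
// {"compressed", "singleton", "dense"} are rejected.
```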
@@ -678,19 +692,12 @@ SparseTensorEncodingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
    return emitError()
           << "level-rank mismatch between dimToLvl and lvlTypes: "
           << dimToLvl.getNumResults() << " != " << lvlRank;
-  // TODO: The following is attempting to match the old error-conditions
-  // from prior to merging dimOrdering and higherOrdering into dimToLvl.
-  // That is, we currently require `dimToLvl` to be either a permutation
-  // (as when higherOrdering is the identity) or expansive (as per the
-  // constraints on higherOrdering). However, those constraints do
-  // not match the intended semantics of `dimToLvl`. As we improve the
-  // compiler to actually handle non-permutations, we need to update these
-  // checks to match what is actually supported. In particular, this is
-  // where we'll have to check that when `lvlToDim` is provided then it
-  // is indeed an inverse of `dimToLvl`, and when it isn't provided then
-  // it can be automatically inferred.
-  if (dimRank == lvlRank && !dimToLvl.isPermutation())
-    return emitError() << "expected a permutation affine map for dimToLvl";
+  auto inferRes = inferLvlToDim(dimToLvl, dimToLvl.getContext());
+  // Symbols can't be inferred but are acceptable.
+  if (!inferRes && dimToLvl.getNumSymbols() == 0)
+    return emitError() << "failed to infer lvlToDim from dimToLvl";
+  if (lvlToDim && (inferRes != lvlToDim))
+    return emitError() << "expected lvlToDim to be an inverse of dimToLvl";
  if (dimRank > lvlRank)
    return emitError() << "unexpected dimToLvl mapping from " << dimRank
                       << " to " << lvlRank;
@@ -758,8 +765,7 @@ AffineMap mlir::sparse_tensor::inferLvlToDim(AffineMap dimToLvl,
    lvlToDim = AffineMap();
  } else if (map.isPermutation()) {
    lvlToDim = inversePermutation(map);
-  } else {
-    // TODO: check if it's block sparsity
+  } else if (isBlockSparsity(map)) {
    lvlToDim = inverseBlockSparsity(map, context);
  }
  return lvlToDim;
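For the block-sparse branch, the inverse follows from the identity d = c * (d floordiv c) + (d mod c). For example, with a 2x3 block map (d0 floordiv 2, d1 floordiv 3, d0 mod 2, d1 mod 3) producing levels l0..l3, the inverse map should recover d0 = 2 * l0 + l2 and d1 = 3 * l1 + l3 (stated here as the mathematical inverse, not a quote of the implementation).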
@@ -818,6 +824,53 @@ AffineMap mlir::sparse_tensor::inverseBlockSparsity(AffineMap dimToLvl,
  return dimToLvl.get(dimToLvl.getNumResults(), 0, lvlExprs, context);
}

+SmallVector<unsigned> mlir::sparse_tensor::getBlockSize(AffineMap dimToLvl) {
+  assert(isBlockSparsity(dimToLvl) &&
+         "expected dimToLvl to be block sparsity for calling getBlockSize");
+  SmallVector<unsigned> blockSize;
+  for (auto result : dimToLvl.getResults()) {
+    if (auto binOp = result.dyn_cast<AffineBinaryOpExpr>()) {
+      if (result.getKind() == AffineExprKind::Mod) {
+        blockSize.push_back(
+            binOp.getRHS().dyn_cast<AffineConstantExpr>().getValue());
+      }
+    } else {
+      blockSize.push_back(0);
+    }
+  }
+  return blockSize;
+}
+
+bool mlir::sparse_tensor::isBlockSparsity(AffineMap dimToLvl) {
+  if (!dimToLvl)
+    return false;
+  std::map<unsigned, int64_t> coeffientMap;
+  for (auto result : dimToLvl.getResults()) {
+    if (auto binOp = result.dyn_cast<AffineBinaryOpExpr>()) {
+      auto pos = binOp.getLHS().dyn_cast<AffineDimExpr>().getPosition();
+      if (result.getKind() == AffineExprKind::FloorDiv) {
+        // Expect only one floordiv for each dimension.
+        if (coeffientMap.find(pos) != coeffientMap.end())
+          return false;
+        coeffientMap[pos] =
+            binOp.getRHS().dyn_cast<AffineConstantExpr>().getValue();
+      } else if (result.getKind() == AffineExprKind::Mod) {
+        // Expect floordiv before mod.
+        if (coeffientMap.find(pos) == coeffientMap.end())
+          return false;
+        // Expect mod to have the same coefficient as floordiv.
+        if (binOp.getRHS().dyn_cast<AffineConstantExpr>().getValue() !=
+            coeffientMap[pos]) {
+          return false;
+        }
+      } else {
+        return false;
+      }
+    }
+  }
+  return !coeffientMap.empty();
+}
+
bool mlir::sparse_tensor::isCOOType(SparseTensorEncodingAttr enc,
                                    Level startLvl, bool isUnique) {
  if (!enc ||
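The two new helpers recognize block-sparse `dimToLvl` maps: `isBlockSparsity` requires every `d floordiv c` to be paired with a later `d mod c` using the same constant (plain dimensions may appear alongside), and `getBlockSize` then reads the block sizes off the `mod` terms, emitting 0 for non-blocked levels. A minimal standalone sketch of the pairing rule over a simplified expression encoding (the `LvlExpr` struct and `looksLikeBlockSparsity` name are illustrative assumptions, not `mlir::AffineExpr`):

```cpp
#include <cstdint>
#include <map>
#include <vector>

// Simplified stand-in for the affine level expressions: a plain dimension,
// `dim floordiv c`, or `dim mod c`.
enum class Kind { Dim, FloorDiv, Mod };
struct LvlExpr {
  Kind kind;
  unsigned dim;
  std::int64_t c; // constant; unused for Kind::Dim
};

// Sketch of the floordiv/mod pairing check performed by isBlockSparsity.
bool looksLikeBlockSparsity(const std::vector<LvlExpr> &lvlExprs) {
  std::map<unsigned, std::int64_t> coefficient; // dim -> block size
  for (const LvlExpr &e : lvlExprs) {
    if (e.kind == Kind::FloorDiv) {
      if (coefficient.count(e.dim))
        return false; // only one floordiv per dimension
      coefficient[e.dim] = e.c;
    } else if (e.kind == Kind::Mod) {
      if (!coefficient.count(e.dim) || coefficient[e.dim] != e.c)
        return false; // floordiv must come first, with the same constant
    }
  }
  return !coefficient.empty();
}
// A BSR-style map (d0 floordiv 2, d1 floordiv 3, d0 mod 2, d1 mod 3) passes,
// with block size {2, 3}; a plain permutation has no floordiv/mod pairs and
// is rejected.
```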