@@ -66,7 +66,7 @@ void StorageLayout::foreachField(
66
66
callback) const {
67
67
const auto lvlTypes = enc.getLvlTypes ();
68
68
const Level lvlRank = enc.getLvlRank ();
69
- const Level cooStart = getCOOStart (enc);
69
+ const Level cooStart = SparseTensorType (enc). getCOOStart ( );
70
70
const Level end = cooStart == lvlRank ? cooStart : cooStart + 1 ;
71
71
FieldIndex fieldIdx = kDataFieldStartingIdx ;
72
72
// Per-level storage.
@@ -158,7 +158,7 @@ StorageLayout::getFieldIndexAndStride(SparseTensorFieldKind kind,
158
158
unsigned stride = 1 ;
159
159
if (kind == SparseTensorFieldKind::CrdMemRef) {
160
160
assert (lvl.has_value ());
161
- const Level cooStart = getCOOStart (enc);
161
+ const Level cooStart = SparseTensorType (enc). getCOOStart ( );
162
162
const Level lvlRank = enc.getLvlRank ();
163
163
if (lvl.value () >= cooStart && lvl.value () < lvlRank) {
164
164
lvl = cooStart;
@@ -710,6 +710,29 @@ LogicalResult SparseTensorEncodingAttr::verifyEncoding(
710
710
// SparseTensorType Methods.
711
711
// ===----------------------------------------------------------------------===//
712
712
713
+ bool mlir::sparse_tensor::SparseTensorType::isCOOType (Level startLvl,
714
+ bool isUnique) const {
715
+ if (!hasEncoding ())
716
+ return false ;
717
+ if (!isCompressedLvl (startLvl) && !isLooseCompressedLvl (startLvl))
718
+ return false ;
719
+ for (Level l = startLvl + 1 ; l < lvlRank; ++l)
720
+ if (!isSingletonLvl (l))
721
+ return false ;
722
+ // If isUnique is true, then make sure that the last level is unique,
723
+ // that is, when lvlRank == 1, the only compressed level is unique,
724
+ // and when lvlRank > 1, the last singleton is unique.
725
+ return !isUnique || isUniqueLvl (lvlRank - 1 );
726
+ }
727
+
728
+ Level mlir::sparse_tensor::SparseTensorType::getCOOStart () const {
729
+ if (hasEncoding () && lvlRank > 1 )
730
+ for (Level l = 0 ; l < lvlRank - 1 ; l++)
731
+ if (isCOOType (l, /* isUnique=*/ false ))
732
+ return l;
733
+ return lvlRank;
734
+ }
735
+
713
736
RankedTensorType
714
737
mlir::sparse_tensor::SparseTensorType::getCOOType (bool ordered) const {
715
738
SmallVector<LevelType> lvlTypes;
@@ -859,25 +882,6 @@ bool mlir::sparse_tensor::isBlockSparsity(AffineMap dimToLvl) {
859
882
return !coeffientMap.empty ();
860
883
}
861
884
862
- bool mlir::sparse_tensor::isCOOType (SparseTensorEncodingAttr enc,
863
- Level startLvl, bool isUnique) {
864
- if (!enc ||
865
- !(enc.isCompressedLvl (startLvl) || enc.isLooseCompressedLvl (startLvl)))
866
- return false ;
867
- const Level lvlRank = enc.getLvlRank ();
868
- for (Level l = startLvl + 1 ; l < lvlRank; ++l)
869
- if (!enc.isSingletonLvl (l))
870
- return false ;
871
- // If isUnique is true, then make sure that the last level is unique,
872
- // that is, lvlRank == 1 (unique the only compressed) and lvlRank > 1
873
- // (unique on the last singleton).
874
- return !isUnique || enc.isUniqueLvl (lvlRank - 1 );
875
- }
876
-
877
- bool mlir::sparse_tensor::isUniqueCOOType (Type tp) {
878
- return isCOOType (getSparseTensorEncoding (tp), 0 , /* isUnique=*/ true );
879
- }
880
-
881
885
bool mlir::sparse_tensor::hasAnyNonIdentityOperandsOrResults (Operation *op) {
882
886
auto hasNonIdentityMap = [](Value v) {
883
887
auto stt = tryGetSparseTensorType (v);
@@ -888,17 +892,6 @@ bool mlir::sparse_tensor::hasAnyNonIdentityOperandsOrResults(Operation *op) {
888
892
llvm::any_of (op->getResults (), hasNonIdentityMap);
889
893
}
890
894
891
- Level mlir::sparse_tensor::getCOOStart (SparseTensorEncodingAttr enc) {
892
- // We only consider COO region with at least two levels for the purpose
893
- // of AOS storage optimization.
894
- const Level lvlRank = enc.getLvlRank ();
895
- if (lvlRank > 1 )
896
- for (Level l = 0 ; l < lvlRank - 1 ; l++)
897
- if (isCOOType (enc, l, /* isUnique=*/ false ))
898
- return l;
899
- return lvlRank;
900
- }
901
-
902
895
Dimension mlir::sparse_tensor::toDim (SparseTensorEncodingAttr enc, Level l) {
903
896
if (enc) {
904
897
assert (enc.isPermutation () && " Non permutation map not supported" );
@@ -1013,7 +1006,7 @@ static LogicalResult verifyPackUnPack(Operation *op, bool requiresStaticShape,
1013
1006
return op->emitError (" the sparse-tensor must have the identity mapping" );
1014
1007
1015
1008
// Verifies the trailing COO.
1016
- Level cooStartLvl = getCOOStart ( stt.getEncoding () );
1009
+ Level cooStartLvl = stt.getCOOStart ( );
1017
1010
if (cooStartLvl < stt.getLvlRank ()) {
1018
1011
// We only supports trailing COO for now, must be the last input.
1019
1012
auto cooTp = llvm::cast<ShapedType>(lvlTps.back ());
@@ -1309,34 +1302,34 @@ OpFoldResult ReinterpretMapOp::fold(FoldAdaptor adaptor) {
1309
1302
}
1310
1303
1311
1304
LogicalResult ToPositionsOp::verify () {
1312
- auto e = getSparseTensorEncoding (getTensor (). getType ());
1305
+ auto stt = getSparseTensorType (getTensor ());
1313
1306
if (failed (lvlIsInBounds (getLevel (), getTensor ())))
1314
1307
return emitError (" requested level is out of bounds" );
1315
- if (failed (isMatchingWidth (getResult (), e .getPosWidth ())))
1308
+ if (failed (isMatchingWidth (getResult (), stt .getPosWidth ())))
1316
1309
return emitError (" unexpected type for positions" );
1317
1310
return success ();
1318
1311
}
1319
1312
1320
1313
LogicalResult ToCoordinatesOp::verify () {
1321
- auto e = getSparseTensorEncoding (getTensor (). getType ());
1314
+ auto stt = getSparseTensorType (getTensor ());
1322
1315
if (failed (lvlIsInBounds (getLevel (), getTensor ())))
1323
1316
return emitError (" requested level is out of bounds" );
1324
- if (failed (isMatchingWidth (getResult (), e .getCrdWidth ())))
1317
+ if (failed (isMatchingWidth (getResult (), stt .getCrdWidth ())))
1325
1318
return emitError (" unexpected type for coordinates" );
1326
1319
return success ();
1327
1320
}
1328
1321
1329
1322
LogicalResult ToCoordinatesBufferOp::verify () {
1330
- auto e = getSparseTensorEncoding (getTensor (). getType ());
1331
- if (getCOOStart (e ) >= e .getLvlRank ())
1323
+ auto stt = getSparseTensorType (getTensor ());
1324
+ if (stt. getCOOStart () >= stt .getLvlRank ())
1332
1325
return emitError (" expected sparse tensor with a COO region" );
1333
1326
return success ();
1334
1327
}
1335
1328
1336
1329
LogicalResult ToValuesOp::verify () {
1337
- auto ttp = getRankedTensorType (getTensor ());
1330
+ auto stt = getSparseTensorType (getTensor ());
1338
1331
auto mtp = getMemRefType (getResult ());
1339
- if (ttp .getElementType () != mtp.getElementType ())
1332
+ if (stt .getElementType () != mtp.getElementType ())
1340
1333
return emitError (" unexpected mismatch in element types" );
1341
1334
return success ();
1342
1335
}
@@ -1660,9 +1653,8 @@ LogicalResult ReorderCOOOp::verify() {
1660
1653
SparseTensorType srcStt = getSparseTensorType (getInputCoo ());
1661
1654
SparseTensorType dstStt = getSparseTensorType (getResultCoo ());
1662
1655
1663
- if (!isCOOType (srcStt.getEncoding (), 0 , /* isUnique=*/ true ) ||
1664
- !isCOOType (dstStt.getEncoding (), 0 , /* isUnique=*/ true ))
1665
- emitError (" Unexpected non-COO sparse tensors" );
1656
+ if (!srcStt.isCOOType () || !dstStt.isCOOType ())
1657
+ emitError (" Expected COO sparse tensors only" );
1666
1658
1667
1659
if (!srcStt.hasSameDimToLvl (dstStt))
1668
1660
emitError (" Unmatched dim2lvl map between input and result COO" );
0 commit comments