Skip to content

Commit 7b9fb1c

Browse files
[mlir][sparse] Update verifier for block sparsity and singleton (#69389)
Updates: 1. Verification of block sparsity. 2. Verification that a singleton level type may only follow a compressed or loose_compressed level, and that all level types after a singleton level must also be singleton. 3. Added a getBlockSize function. 4. Added an invalid-encoding test for an incorrect user-provided lvlToDim map.
1 parent fdac18c commit 7b9fb1c

File tree

4 files changed

+114
-22
lines changed

4 files changed

+114
-22
lines changed

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,21 @@ AffineMap inferLvlToDim(AffineMap dimToLvl, MLIRContext *context);
173173
/// Asserts on failure (so only use when known to succeed).
174174
AffineMap inverseBlockSparsity(AffineMap dimToLvl, MLIRContext *context);
175175

176+
/// Given the dimToLvl map, returns the block sizes in a vector.
177+
/// For instance, a 2x3 block will return [2, 3]. Unblocked dimension i
178+
/// will return 0, and i floordiv 1, i mod 1 will return 1. Therefore,
179+
/// the example below will return [0, 1].
180+
/// map = ( i, j ) ->
181+
/// ( i : dense,
182+
/// j floordiv 1 : compressed,
183+
/// j mod 1 : dense
184+
/// )
185+
/// Only valid block sparsity will be accepted.
186+
SmallVector<unsigned> getBlockSize(AffineMap dimToLvl);
187+
188+
/// Given the dimToLvl map, returns if it's block sparsity.
189+
bool isBlockSparsity(AffineMap dimToLvl);
190+
176191
//
177192
// Reordering.
178193
//

mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMap.cpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -356,10 +356,15 @@ AffineMap DimLvlMap::getDimToLvlMap(MLIRContext *context) const {
356356
AffineMap DimLvlMap::getLvlToDimMap(MLIRContext *context) const {
357357
SmallVector<AffineExpr> dimAffines;
358358
dimAffines.reserve(getDimRank());
359-
for (const auto &dimSpec : dimSpecs)
360-
dimAffines.push_back(dimSpec.getExpr().getAffineExpr());
359+
for (const auto &dimSpec : dimSpecs) {
360+
auto expr = dimSpec.getExpr().getAffineExpr();
361+
if (expr) {
362+
dimAffines.push_back(expr);
363+
}
364+
}
361365
auto map = AffineMap::get(getLvlRank(), getSymRank(), dimAffines, context);
362-
if (map.isIdentity()) return AffineMap();
366+
if (dimAffines.empty() || map.isIdentity())
367+
return AffineMap();
363368
return map;
364369
}
365370

mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp

Lines changed: 70 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -455,6 +455,7 @@ Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
455455
SmallVector<DimLevelType> lvlTypes;
456456
SmallVector<SparseTensorDimSliceAttr> dimSlices;
457457
AffineMap dimToLvl = {};
458+
AffineMap lvlToDim = {};
458459
unsigned posWidth = 0;
459460
unsigned crdWidth = 0;
460461
StringRef attrName;
@@ -568,6 +569,7 @@ Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
568569

569570
ERROR_IF(dimToLvl, "Cannot mix `dimToLvl` with `map`")
570571
dimToLvl = dlm.getDimToLvlMap(parser.getContext());
572+
lvlToDim = dlm.getLvlToDimMap(parser.getContext());
571573
break;
572574
}
573575
} // switch
@@ -582,8 +584,9 @@ Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
582584
#undef RETURN_ON_FAIL
583585

584586
// Construct struct-like storage for attribute.
585-
// TODO: Fetch lvlToDim if user provides one
586-
AffineMap lvlToDim = inferLvlToDim(dimToLvl, parser.getContext());
587+
if (!lvlToDim || lvlToDim.isEmpty()) {
588+
lvlToDim = inferLvlToDim(dimToLvl, parser.getContext());
589+
}
587590
return parser.getChecked<SparseTensorEncodingAttr>(
588591
parser.getContext(), lvlTypes, dimToLvl, lvlToDim, posWidth, crdWidth,
589592
dimSlices);
@@ -663,6 +666,17 @@ SparseTensorEncodingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
663666
return emitError() << "unexpected position bitwidth: " << posWidth;
664667
if (!acceptBitWidth(crdWidth))
665668
return emitError() << "unexpected coordinate bitwidth: " << crdWidth;
669+
if (auto it = std::find_if(lvlTypes.begin(), lvlTypes.end(), isSingletonDLT);
670+
it != std::end(lvlTypes)) {
671+
if (it == lvlTypes.begin() ||
672+
(!isCompressedDLT(*(it - 1)) && !isLooseCompressedDLT(*(it - 1))))
673+
return emitError() << "expected compressed or loose_compressed level "
674+
"before singleton level";
675+
if (!std::all_of(it, lvlTypes.end(),
676+
[](DimLevelType i) { return isSingletonDLT(i); }))
677+
return emitError() << "expected all singleton lvlTypes "
678+
"following a singleton level";
679+
}
666680
// Before we can check that the level-rank is consistent/coherent
667681
// across all fields, we need to define it. The source-of-truth for
668682
// the `getLvlRank` method is the length of the level-types array,
@@ -678,19 +692,12 @@ SparseTensorEncodingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
678692
return emitError()
679693
<< "level-rank mismatch between dimToLvl and lvlTypes: "
680694
<< dimToLvl.getNumResults() << " != " << lvlRank;
681-
// TODO: The following is attempting to match the old error-conditions
682-
// from prior to merging dimOrdering and higherOrdering into dimToLvl.
683-
// That is, we currently require `dimToLvl` to be either a permutation
684-
// (as when higherOrdering is the identity) or expansive (as per the
685-
// constraints on higherOrdering). However, those constraints do
686-
// not match the intended semantics of `dimToLvl`. As we improve the
687-
// compiler to actually handle non-permutations, we need to update these
688-
// checks to match what is actually supported. In particular, this is
689-
// where we'll have to check that when `lvlToDim` is provided then it
690-
// is indeed an inverse of `dimToLvl`, and when it isn't provided then
691-
// it can be automatically inferred.
692-
if (dimRank == lvlRank && !dimToLvl.isPermutation())
693-
return emitError() << "expected a permutation affine map for dimToLvl";
695+
auto inferRes = inferLvlToDim(dimToLvl, dimToLvl.getContext());
696+
// Symbols can't be inferred but are acceptable.
697+
if (!inferRes && dimToLvl.getNumSymbols() == 0)
698+
return emitError() << "failed to infer lvlToDim from dimToLvl";
699+
if (lvlToDim && (inferRes != lvlToDim))
700+
return emitError() << "expected lvlToDim to be an inverse of dimToLvl";
694701
if (dimRank > lvlRank)
695702
return emitError() << "unexpected dimToLvl mapping from " << dimRank
696703
<< " to " << lvlRank;
@@ -758,8 +765,7 @@ AffineMap mlir::sparse_tensor::inferLvlToDim(AffineMap dimToLvl,
758765
lvlToDim = AffineMap();
759766
} else if (map.isPermutation()) {
760767
lvlToDim = inversePermutation(map);
761-
} else {
762-
// TODO: check if it's block sparsity
768+
} else if (isBlockSparsity(map)) {
763769
lvlToDim = inverseBlockSparsity(map, context);
764770
}
765771
return lvlToDim;
@@ -818,6 +824,53 @@ AffineMap mlir::sparse_tensor::inverseBlockSparsity(AffineMap dimToLvl,
818824
return dimToLvl.get(dimToLvl.getNumResults(), 0, lvlExprs, context);
819825
}
820826

827+
SmallVector<unsigned> mlir::sparse_tensor::getBlockSize(AffineMap dimToLvl) {
828+
assert(isBlockSparsity(dimToLvl) &&
829+
"expected dimToLvl to be block sparsity for calling getBlockSize");
830+
SmallVector<unsigned> blockSize;
831+
for (auto result : dimToLvl.getResults()) {
832+
if (auto binOp = result.dyn_cast<AffineBinaryOpExpr>()) {
833+
if (result.getKind() == AffineExprKind::Mod) {
834+
blockSize.push_back(
835+
binOp.getRHS().dyn_cast<AffineConstantExpr>().getValue());
836+
}
837+
} else {
838+
blockSize.push_back(0);
839+
}
840+
}
841+
return blockSize;
842+
}
843+
844+
bool mlir::sparse_tensor::isBlockSparsity(AffineMap dimToLvl) {
845+
if (!dimToLvl)
846+
return false;
847+
std::map<unsigned, int64_t> coeffientMap;
848+
for (auto result : dimToLvl.getResults()) {
849+
if (auto binOp = result.dyn_cast<AffineBinaryOpExpr>()) {
850+
auto pos = binOp.getLHS().dyn_cast<AffineDimExpr>().getPosition();
851+
if (result.getKind() == AffineExprKind::FloorDiv) {
852+
// Expect only one floordiv for each dimension.
853+
if (coeffientMap.find(pos) != coeffientMap.end())
854+
return false;
855+
coeffientMap[pos] =
856+
binOp.getRHS().dyn_cast<AffineConstantExpr>().getValue();
857+
} else if (result.getKind() == AffineExprKind::Mod) {
858+
// Expect floordiv before mod.
859+
if (coeffientMap.find(pos) == coeffientMap.end())
860+
return false;
861+
// Expect mod to have the same coefficient as floordiv.
862+
if (binOp.getRHS().dyn_cast<AffineConstantExpr>().getValue() !=
863+
coeffientMap[pos]) {
864+
return false;
865+
}
866+
} else {
867+
return false;
868+
}
869+
}
870+
}
871+
return !coeffientMap.empty();
872+
}
873+
821874
bool mlir::sparse_tensor::isCOOType(SparseTensorEncodingAttr enc,
822875
Level startLvl, bool isUnique) {
823876
if (!enc ||

mlir/test/Dialect/SparseTensor/invalid_encoding.mlir

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ func.func private @tensor_sizes_mismatch(%arg0: tensor<8xi32, #a>) -> ()
6060

6161
// -----
6262

63-
// expected-error@+1 {{unexpected dimToLvl mapping from 2 to 1}}
63+
// expected-error@+1 {{failed to infer lvlToDim from dimToLvl}}
6464
#a = #sparse_tensor.encoding<{map = (d0, d1) -> (d0 : dense)}>
6565
func.func private @tensor_sizes_mismatch(%arg0: tensor<8xi32, #a>) -> ()
6666

@@ -119,7 +119,7 @@ func.func private @tensor_dimtolvl_mismatch(%arg0: tensor<8xi32, #a>) -> ()
119119

120120
// -----
121121

122-
// expected-error@+1 {{expected a permutation affine map for dimToLvl}}
122+
// expected-error@+1 {{failed to infer lvlToDim from dimToLvl}}
123123
#a = #sparse_tensor.encoding<{map = (d0, d1) -> (d0 : dense, d0 : compressed)}>
124124
func.func private @tensor_no_permutation(%arg0: tensor<16x32xf32, #a>) -> ()
125125

@@ -251,3 +251,22 @@ func.func private @too_few_lvl_decl(%arg0: tensor<?x?xf64, #TooFewLvlDecl>) {
251251
func.func private @wrong_order_lvl_decl(%arg0: tensor<?x?xf64, #WrongOrderLvlDecl>) {
252252
return
253253
}
254+
255+
// -----
256+
257+
// expected-error@+1 {{expected lvlToDim to be an inverse of dimToLvl}}
258+
#BSR_explicit = #sparse_tensor.encoding<{
259+
map =
260+
{il, jl, ii, jj}
261+
( i = il * 3 + ii,
262+
j = jl * 2 + jj
263+
) ->
264+
( il = i floordiv 2 : dense,
265+
jl = j floordiv 3 : compressed,
266+
ii = i mod 2 : dense,
267+
jj = j mod 3 : dense
268+
)
269+
}>
270+
func.func private @BSR_explicit(%arg0: tensor<?x?xf64, #BSR_explicit>) {
271+
return
272+
}

0 commit comments

Comments
 (0)