Skip to content

Commit 3d3e46c

Browse files
authored
[mlir][sparse] make test for block sparsity more robust (llvm#74798)
For BSR and convolutions, we encounter (d0, d1, d2, d3) -> ((d0 + d2) floordiv 2, (d1 + d3) floordiv 2, (d0 + d2) mod 2, (d1 + d3) mod 2) which crashed the current test. Note that an actual test and working code is still to follow (since we need to fix a few other things first)
1 parent 944e031 commit 3d3e46c

File tree

1 file changed

+12
-8
lines changed

1 file changed

+12
-8
lines changed

mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -858,22 +858,26 @@ bool mlir::sparse_tensor::isBlockSparsity(AffineMap dimToLvl) {
858858
std::map<unsigned, int64_t> coeffientMap;
859859
for (auto result : dimToLvl.getResults()) {
860860
if (auto binOp = dyn_cast<AffineBinaryOpExpr>(result)) {
861-
auto pos = dyn_cast<AffineDimExpr>(binOp.getLHS()).getPosition();
862-
if (result.getKind() == AffineExprKind::FloorDiv) {
861+
// Check for "dim op const".
862+
auto dimOp = dyn_cast<AffineDimExpr>(binOp.getLHS());
863+
auto conOp = dyn_cast<AffineConstantExpr>(binOp.getRHS());
864+
if (!dimOp || !conOp)
865+
return false;
866+
// Inspect "dim / const" or "dim % const".
867+
auto pos = dimOp.getPosition();
868+
if (binOp.getKind() == AffineExprKind::FloorDiv) {
863869
// Expect only one floordiv for each dimension.
864870
if (coeffientMap.find(pos) != coeffientMap.end())
865871
return false;
866-
coeffientMap[pos] =
867-
dyn_cast<AffineConstantExpr>(binOp.getRHS()).getValue();
868-
} else if (result.getKind() == AffineExprKind::Mod) {
872+
// Record coefficient of the floordiv.
873+
coeffientMap[pos] = conOp.getValue();
874+
} else if (binOp.getKind() == AffineExprKind::Mod) {
869875
// Expect floordiv before mod.
870876
if (coeffientMap.find(pos) == coeffientMap.end())
871877
return false;
872878
// Expect mod to have the same coefficient as floordiv.
873-
if (dyn_cast<AffineConstantExpr>(binOp.getRHS()).getValue() !=
874-
coeffientMap[pos]) {
879+
if (conOp.getValue() != coeffientMap[pos])
875880
return false;
876-
}
877881
} else {
878882
return false;
879883
}

0 commit comments

Comments
 (0)