Skip to content

Commit f16cb0e

Browse files
authored
[mlir][sparse] connect MapRef's lvl2dim with latest AffineMap computation (#69540)
This makes sure - GEN MAP dim=2 lvl=4 (d0, d1) -> (d0 floordiv 2, d1 floordiv 2, d0 mod 2, d1 mod 2) -- (d0, d1, d2, d3) -> (d0 * 2 + d2, d1 * 2 + d3) is indeed encoded as MAP-REF (dim=2, lvl=4) isperm=0 d2l = [ d0/2 d1/2 d0%2 d1%2 ] ld2 = [ l2+2*l0 l3+2*l1 ]
1 parent e103515 commit f16cb0e

File tree

1 file changed

+42
-20
lines changed

1 file changed

+42
-20
lines changed

mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp

Lines changed: 42 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -691,6 +691,7 @@ Value sparse_tensor::genMapBuffers(OpBuilder &builder, Location loc,
691691
// This code deals with permutations as well as non-permutations that
692692
// arise from rank changing blocking.
693693
const auto dimToLvl = stt.getDimToLvl();
694+
const auto lvlToDim = stt.getLvlToDim();
694695
SmallVector<Value> dim2lvlValues(lvlRank); // for each lvl, expr in dim vars
695696
SmallVector<Value> lvl2dimValues(dimRank); // for each dim, expr in lvl vars
696697
SmallVector<Value> lvlSizesValues(lvlRank);
@@ -705,34 +706,26 @@ Value sparse_tensor::genMapBuffers(OpBuilder &builder, Location loc,
705706
Dimension d = 0;
706707
uint64_t cf = 0, cm = 0;
707708
switch (exp.getKind()) {
708-
case AffineExprKind::DimId:
709+
case AffineExprKind::DimId: {
709710
d = exp.cast<AffineDimExpr>().getPosition();
710711
break;
711-
case AffineExprKind::FloorDiv:
712-
d = exp.cast<AffineBinaryOpExpr>()
713-
.getLHS()
714-
.cast<AffineDimExpr>()
715-
.getPosition();
716-
cf = exp.cast<AffineBinaryOpExpr>()
717-
.getRHS()
718-
.cast<AffineConstantExpr>()
719-
.getValue();
712+
}
713+
case AffineExprKind::FloorDiv: {
714+
auto floor = exp.cast<AffineBinaryOpExpr>();
715+
d = floor.getLHS().cast<AffineDimExpr>().getPosition();
716+
cf = floor.getRHS().cast<AffineConstantExpr>().getValue();
720717
break;
721-
case AffineExprKind::Mod:
722-
d = exp.cast<AffineBinaryOpExpr>()
723-
.getLHS()
724-
.cast<AffineDimExpr>()
725-
.getPosition();
726-
cm = exp.cast<AffineBinaryOpExpr>()
727-
.getRHS()
728-
.cast<AffineConstantExpr>()
729-
.getValue();
718+
}
719+
case AffineExprKind::Mod: {
720+
auto mod = exp.cast<AffineBinaryOpExpr>();
721+
d = mod.getLHS().cast<AffineDimExpr>().getPosition();
722+
cm = mod.getRHS().cast<AffineConstantExpr>().getValue();
730723
break;
724+
}
731725
default:
732726
llvm::report_fatal_error("unsupported dim2lvl in sparse tensor type");
733727
}
734728
dim2lvlValues[l] = constantIndex(builder, loc, encodeDim(d, cf, cm));
735-
lvl2dimValues[d] = constantIndex(builder, loc, l); // FIXME, use lvlToDim
736729
// Compute the level sizes.
737730
// (1) l = d : size(d)
738731
// (2) l = d / c : size(d) / c
@@ -751,6 +744,35 @@ Value sparse_tensor::genMapBuffers(OpBuilder &builder, Location loc,
751744
}
752745
lvlSizesValues[l] = lvlSz;
753746
}
747+
// Generate lvl2dim.
748+
assert(dimRank == lvlToDim.getNumResults());
749+
for (Dimension d = 0; d < dimRank; d++) {
750+
AffineExpr exp = lvlToDim.getResult(d);
751+
// We expect:
752+
// (1) d = l
753+
// (2) d = l' * c + l
754+
Level l = 0, ll = 0;
755+
uint64_t c = 0;
756+
switch (exp.getKind()) {
757+
case AffineExprKind::DimId: {
758+
l = exp.cast<AffineDimExpr>().getPosition();
759+
break;
760+
}
761+
case AffineExprKind::Add: {
762+
// Always mul on lhs, symbol/constant on rhs.
763+
auto add = exp.cast<AffineBinaryOpExpr>();
764+
assert(add.getLHS().getKind() == AffineExprKind::Mul);
765+
auto mul = add.getLHS().cast<AffineBinaryOpExpr>();
766+
ll = mul.getLHS().cast<AffineDimExpr>().getPosition();
767+
c = mul.getRHS().cast<AffineConstantExpr>().getValue();
768+
l = add.getRHS().cast<AffineDimExpr>().getPosition();
769+
break;
770+
}
771+
default:
772+
llvm::report_fatal_error("unsupported lvl2dim in sparse tensor type");
773+
}
774+
lvl2dimValues[d] = constantIndex(builder, loc, encodeLvl(l, c, ll));
775+
}
754776
// Return buffers.
755777
dim2lvlBuffer = allocaBuffer(builder, loc, dim2lvlValues);
756778
lvl2dimBuffer = allocaBuffer(builder, loc, lvl2dimValues);

0 commit comments

Comments
 (0)