@@ -691,6 +691,7 @@ Value sparse_tensor::genMapBuffers(OpBuilder &builder, Location loc,
691
691
// This code deals with permutations as well as non-permutations that
692
692
// arise from rank changing blocking.
693
693
const auto dimToLvl = stt.getDimToLvl ();
694
+ const auto lvlToDim = stt.getLvlToDim ();
694
695
SmallVector<Value> dim2lvlValues (lvlRank); // for each lvl, expr in dim vars
695
696
SmallVector<Value> lvl2dimValues (dimRank); // for each dim, expr in lvl vars
696
697
SmallVector<Value> lvlSizesValues (lvlRank);
@@ -705,34 +706,26 @@ Value sparse_tensor::genMapBuffers(OpBuilder &builder, Location loc,
705
706
Dimension d = 0 ;
706
707
uint64_t cf = 0 , cm = 0 ;
707
708
switch (exp.getKind ()) {
708
- case AffineExprKind::DimId:
709
+ case AffineExprKind::DimId: {
709
710
d = exp.cast <AffineDimExpr>().getPosition ();
710
711
break ;
711
- case AffineExprKind::FloorDiv:
712
- d = exp.cast <AffineBinaryOpExpr>()
713
- .getLHS ()
714
- .cast <AffineDimExpr>()
715
- .getPosition ();
716
- cf = exp.cast <AffineBinaryOpExpr>()
717
- .getRHS ()
718
- .cast <AffineConstantExpr>()
719
- .getValue ();
712
+ }
713
+ case AffineExprKind::FloorDiv: {
714
+ auto floor = exp.cast <AffineBinaryOpExpr>();
715
+ d = floor.getLHS ().cast <AffineDimExpr>().getPosition ();
716
+ cf = floor.getRHS ().cast <AffineConstantExpr>().getValue ();
720
717
break ;
721
- case AffineExprKind::Mod:
722
- d = exp.cast <AffineBinaryOpExpr>()
723
- .getLHS ()
724
- .cast <AffineDimExpr>()
725
- .getPosition ();
726
- cm = exp.cast <AffineBinaryOpExpr>()
727
- .getRHS ()
728
- .cast <AffineConstantExpr>()
729
- .getValue ();
718
+ }
719
+ case AffineExprKind::Mod: {
720
+ auto mod = exp.cast <AffineBinaryOpExpr>();
721
+ d = mod.getLHS ().cast <AffineDimExpr>().getPosition ();
722
+ cm = mod.getRHS ().cast <AffineConstantExpr>().getValue ();
730
723
break ;
724
+ }
731
725
default :
732
726
llvm::report_fatal_error (" unsupported dim2lvl in sparse tensor type" );
733
727
}
734
728
dim2lvlValues[l] = constantIndex (builder, loc, encodeDim (d, cf, cm));
735
- lvl2dimValues[d] = constantIndex (builder, loc, l); // FIXME, use lvlToDim
736
729
// Compute the level sizes.
737
730
// (1) l = d : size(d)
738
731
// (2) l = d / c : size(d) / c
@@ -751,6 +744,35 @@ Value sparse_tensor::genMapBuffers(OpBuilder &builder, Location loc,
751
744
}
752
745
lvlSizesValues[l] = lvlSz;
753
746
}
747
+ // Generate lvl2dim.
748
+ assert (dimRank == lvlToDim.getNumResults ());
749
+ for (Dimension d = 0 ; d < dimRank; d++) {
750
+ AffineExpr exp = lvlToDim.getResult (d);
751
+ // We expect:
752
+ // (1) d = l
753
+ // (2) d = l' * c + l
754
+ Level l = 0 , ll = 0 ;
755
+ uint64_t c = 0 ;
756
+ switch (exp.getKind ()) {
757
+ case AffineExprKind::DimId: {
758
+ l = exp.cast <AffineDimExpr>().getPosition ();
759
+ break ;
760
+ }
761
+ case AffineExprKind::Add: {
762
+ // Always mul on lhs, symbol/constant on rhs.
763
+ auto add = exp.cast <AffineBinaryOpExpr>();
764
+ assert (add.getLHS ().getKind () == AffineExprKind::Mul);
765
+ auto mul = add.getLHS ().cast <AffineBinaryOpExpr>();
766
+ ll = mul.getLHS ().cast <AffineDimExpr>().getPosition ();
767
+ c = mul.getRHS ().cast <AffineConstantExpr>().getValue ();
768
+ l = add.getRHS ().cast <AffineDimExpr>().getPosition ();
769
+ break ;
770
+ }
771
+ default :
772
+ llvm::report_fatal_error (" unsupported lvl2dim in sparse tensor type" );
773
+ }
774
+ lvl2dimValues[d] = constantIndex (builder, loc, encodeLvl (l, c, ll));
775
+ }
754
776
// Return buffers.
755
777
dim2lvlBuffer = allocaBuffer (builder, loc, dim2lvlValues);
756
778
lvl2dimBuffer = allocaBuffer (builder, loc, lvl2dimValues);
0 commit comments