@@ -639,25 +639,20 @@ Value sparse_tensor::createOrFoldSliceStrideOp(OpBuilder &builder, Location loc,
639
639
return builder.create <ToSliceStrideOp>(loc, tensor, APInt (64 , dim));
640
640
}
641
641
642
- void sparse_tensor::fillDimShape (OpBuilder &builder, Location loc,
643
- SparseTensorType stt,
644
- SmallVectorImpl<Value> &out) {
645
- out.clear ();
646
- out.reserve (stt.getDimRank ());
647
- for (const Size sz : stt.getDimShape ()) {
648
- const auto s = ShapedType::isDynamic (sz) ? 0 : sz;
649
- out.push_back (constantIndex (builder, loc, s));
650
- }
651
- }
652
-
653
642
Value sparse_tensor::genReader (OpBuilder &builder, Location loc,
654
643
SparseTensorType stt, Value tensor,
655
- /* out*/ SmallVectorImpl<Value> &dimShapesValues ,
644
+ /* out*/ SmallVectorImpl<Value> &dimSizesValues ,
656
645
/* out*/ Value &dimSizesBuffer) {
657
- // Construct the dimShapes buffer. The buffer contains the static size
658
- // per dimension, or otherwise a zero for a dynamic size.
659
- fillDimShape (builder, loc, stt, dimShapesValues);
660
- Value dimShapesBuffer = allocaBuffer (builder, loc, dimShapesValues);
646
+ // Construct the dimension **shapes** buffer. The buffer contains the static
647
+ // size per dimension, or otherwise a zero for a dynamic size.
648
+ Dimension dimRank = stt.getDimRank ();
649
+ dimSizesValues.clear ();
650
+ dimSizesValues.reserve (dimRank);
651
+ for (const Size sz : stt.getDimShape ()) {
652
+ const auto s = ShapedType::isDynamic (sz) ? 0 : sz;
653
+ dimSizesValues.push_back (constantIndex (builder, loc, s));
654
+ }
655
+ Value dimShapesBuffer = allocaBuffer (builder, loc, dimSizesValues);
661
656
// Create the `CheckedSparseTensorReader`. This reader performs a
662
657
// consistency check on the static sizes, but accepts any size
663
658
// of each dimension with a dynamic size.
@@ -679,29 +674,40 @@ Value sparse_tensor::genReader(OpBuilder &builder, Location loc,
679
674
createFuncCall (builder, loc, " getSparseTensorReaderDimSizes" , memTp,
680
675
reader, EmitCInterface::On)
681
676
.getResult (0 );
677
+ // Also convert the dim shapes values into dim sizes values, just in case
678
+ // subsequent clients need the values (DCE will remove unused).
679
+ for (Dimension d = 0 ; d < dimRank; d++) {
680
+ if (stt.isDynamicDim (d))
681
+ dimSizesValues[d] = builder.create <memref::LoadOp>(
682
+ loc, dimSizesBuffer, constantIndex (builder, loc, d));
683
+ }
682
684
}
683
685
return reader;
684
686
}
685
687
686
- Value sparse_tensor::genMapBuffers (OpBuilder &builder, Location loc,
687
- SparseTensorType stt,
688
- ArrayRef<Value> dimShapesValues ,
689
- Value dimSizesBuffer ,
690
- /* out*/ Value &dim2lvlBuffer,
691
- /* out*/ Value &lvl2dimBuffer) {
688
+ Value sparse_tensor::genMapBuffers (
689
+ OpBuilder &builder, Location loc, SparseTensorType stt,
690
+ ArrayRef<Value> dimSizesValues, Value dimSizesBuffer ,
691
+ /* out */ SmallVectorImpl< Value> &lvlSizesValues ,
692
+ /* out*/ Value &dim2lvlBuffer,
693
+ /* out*/ Value &lvl2dimBuffer) {
692
694
const Dimension dimRank = stt.getDimRank ();
693
695
const Level lvlRank = stt.getLvlRank ();
696
+ lvlSizesValues.clear ();
697
+ lvlSizesValues.reserve (lvlRank);
694
698
// For an identity mapping, the dim2lvl and lvl2dim mappings are
695
699
// identical as are dimSizes and lvlSizes, so buffers are reused
696
700
// as much as possible.
697
701
if (stt.isIdentity ()) {
698
702
assert (dimRank == lvlRank);
699
703
SmallVector<Value> iotaValues;
700
704
iotaValues.reserve (lvlRank);
701
- for (Level l = 0 ; l < lvlRank; l++)
705
+ for (Level l = 0 ; l < lvlRank; l++) {
702
706
iotaValues.push_back (constantIndex (builder, loc, l));
707
+ lvlSizesValues.push_back (dimSizesValues[l]);
708
+ }
703
709
dim2lvlBuffer = lvl2dimBuffer = allocaBuffer (builder, loc, iotaValues);
704
- return dimSizesBuffer;
710
+ return dimSizesBuffer; // now lvlSizesBuffer
705
711
}
706
712
// Otherwise, some code needs to be generated to set up the buffers.
707
713
// This code deals with permutations as well as non-permutations that
@@ -710,7 +716,6 @@ Value sparse_tensor::genMapBuffers(OpBuilder &builder, Location loc,
710
716
const auto lvlToDim = stt.getLvlToDim ();
711
717
SmallVector<Value> dim2lvlValues (lvlRank); // for each lvl, expr in dim vars
712
718
SmallVector<Value> lvl2dimValues (dimRank); // for each dim, expr in lvl vars
713
- SmallVector<Value> lvlSizesValues (lvlRank);
714
719
// Generate dim2lvl.
715
720
assert (lvlRank == dimToLvl.getNumResults ());
716
721
for (Level l = 0 ; l < lvlRank; l++) {
@@ -748,17 +753,14 @@ Value sparse_tensor::genMapBuffers(OpBuilder &builder, Location loc,
748
753
// (3) l = d % c : c
749
754
Value lvlSz;
750
755
if (cm == 0 ) {
751
- lvlSz = dimShapesValues[d];
752
- if (stt.isDynamicDim (d))
753
- lvlSz = builder.create <memref::LoadOp>(loc, dimSizesBuffer,
754
- constantIndex (builder, loc, d));
756
+ lvlSz = dimSizesValues[d];
755
757
if (cf != 0 )
756
758
lvlSz = builder.create <arith::DivUIOp>(loc, lvlSz,
757
759
constantIndex (builder, loc, cf));
758
760
} else {
759
761
lvlSz = constantIndex (builder, loc, cm);
760
762
}
761
- lvlSizesValues[l] = lvlSz;
763
+ lvlSizesValues. push_back ( lvlSz) ;
762
764
}
763
765
// Generate lvl2dim.
764
766
assert (dimRank == lvlToDim.getNumResults ());
@@ -792,5 +794,5 @@ Value sparse_tensor::genMapBuffers(OpBuilder &builder, Location loc,
792
794
// Return buffers.
793
795
dim2lvlBuffer = allocaBuffer (builder, loc, dim2lvlValues);
794
796
lvl2dimBuffer = allocaBuffer (builder, loc, lvl2dimValues);
795
- return allocaBuffer (builder, loc, lvlSizesValues);
797
+ return allocaBuffer (builder, loc, lvlSizesValues); // lvlSizesBuffer
796
798
}
0 commit comments