@@ -829,47 +829,40 @@ struct ReshapeRewriter : public OpRewritePattern<ReshapeOp> {
829
829
}
830
830
};
831
831
832
+ // A trivial wrapper to help generate different operations for dense/sparse
833
+ // tensors.
832
834
struct TensorLike {
833
835
TensorLike (OpBuilder &builder, Location loc, RankedTensorType rtt,
834
- ValueRange sizes)
835
- : isSparse(rtt.getEncoding() != nullptr ) {
836
+ ValueRange sizes) {
836
837
SmallVector<Value> dynSzs;
837
838
getDynamicSizes (rtt, sizes, dynSzs);
838
839
839
- if (isSparse)
840
- val = builder.create <AllocTensorOp>(loc, rtt, dynSzs);
841
- else
842
- val = allocDenseTensor (builder, loc, rtt, sizes);
843
- };
844
-
845
- void insertOrStore (OpBuilder &builder, Location loc, Value v,
846
- ValueRange crds) {
847
- if (isSparse)
848
- val = builder.create <InsertOp>(loc, v, val, crds);
849
- else
850
- builder.create <memref::StoreOp>(loc, v, val, crds);
840
+ val = builder.create <AllocTensorOp>(loc, rtt, dynSzs);
841
+ if (!isSparse ()) {
842
+ Value c0 = constantZero (builder, loc, rtt.getElementType ());
843
+ val = builder.create <linalg::FillOp>(loc, c0, val).getResult (0 );
844
+ }
851
845
}
852
846
853
- Value getSSA () const {
854
- // We don't need to maintain the SSA chain for a memref value.
855
- return isSparse ? val : nullptr ;
847
+ void insert (OpBuilder &builder, Location loc, Value v, ValueRange crds) {
848
+ // TODO: Unify these two.
849
+ if (isSparse ())
850
+ val = builder.create <sparse_tensor::InsertOp>(loc, v, val, crds);
851
+ else
852
+ val = builder.create <tensor::InsertOp>(loc, v, val, crds);
856
853
}
857
854
858
855
Value finalize (OpBuilder &builder, Location loc, RankedTensorType rtp) const {
859
- if (isSparse)
856
+ if (isSparse () )
860
857
return builder.create <LoadOp>(loc, val, true );
861
- return builder. create <bufferization::ToTensorOp>(loc, rtp, val) ;
858
+ return val;
862
859
}
863
860
864
- void updateSSA (Value v) {
865
- // Dense memref is a non-SSA value.
866
- assert (isSparse);
867
- val = v;
861
+ bool isSparse () const {
862
+ return getSparseTensorEncoding (val.getType ()) != nullptr ;
868
863
}
869
864
870
- private:
871
- bool isSparse;
872
- Value val; // either a memref (for dense tensor) or a sparse tensor.
865
+ Value val;
873
866
};
874
867
875
868
struct ConcatenateRewriter : public OpRewritePattern <ConcatenateOp> {
@@ -901,14 +894,14 @@ struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
901
894
902
895
TensorLike dstBuf (rewriter, loc, dstTp.getRankedTensorType (), sizes);
903
896
Value offset = constantIndex (rewriter, loc, 0 );
904
- Value iterArg = dstBuf.getSSA () ;
897
+ Value iterArg = dstBuf.val ;
905
898
906
899
ForeachOp foreachOp;
907
900
for (Value input : op.getInputs ()) {
908
901
// Builds a for op for each input tensor to append new values into the
909
902
// output tensor.
910
903
foreachOp = rewriter.create <ForeachOp>(
911
- loc, input, iterArg ? ValueRange{iterArg} : ValueRange{} ,
904
+ loc, input, iterArg,
912
905
[&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v,
913
906
ValueRange reduc) {
914
907
SmallVector<Value> dstLcvs (dstTp.getLvlRank ());
@@ -920,32 +913,26 @@ struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
920
913
// FIXME: `toStoredDim` is deprecated
921
914
dstLcvs[toStoredDim (dstTp.getEncoding (), d)] = crd;
922
915
}
923
-
924
- if (!reduc.empty ())
925
- dstBuf.updateSSA (reduc.front ());
926
-
916
+ // Enters foreach, updates the SSA chain.
917
+ dstBuf.val = reduc.front ();
927
918
if (!dstTp.isAllDense ()) {
928
919
Value cond = genIsNonzero (builder, loc, v);
929
920
auto ifOp = builder.create <scf::IfOp>(loc, reduc.getTypes (), cond,
930
921
/* else*/ true );
931
922
builder.setInsertionPointToStart (&ifOp.getElseRegion ().front ());
932
- builder.create <scf::YieldOp>(loc, dstBuf.getSSA () );
923
+ builder.create <scf::YieldOp>(loc, dstBuf.val );
933
924
934
925
builder.setInsertionPointToStart (&ifOp.getThenRegion ().front ());
935
- dstBuf.insertOrStore (builder, loc, v, dstLcvs);
936
- builder.create <scf::YieldOp>(loc, dstBuf.getSSA () );
926
+ dstBuf.insert (builder, loc, v, dstLcvs);
927
+ builder.create <scf::YieldOp>(loc, dstBuf.val );
937
928
938
929
// Exits the ifOp, update the sparse tensor SSA value.
939
930
builder.setInsertionPointAfter (ifOp);
940
- assert (!reduc.empty ());
941
- dstBuf.updateSSA (ifOp.getResult (0 ));
931
+ dstBuf.val = ifOp.getResult (0 );
942
932
} else {
943
- dstBuf.insertOrStore (builder, loc, v, dstLcvs);
933
+ dstBuf.insert (builder, loc, v, dstLcvs);
944
934
}
945
- if (reduc.empty ())
946
- builder.create <sparse_tensor::YieldOp>(loc);
947
- else
948
- builder.create <sparse_tensor::YieldOp>(loc, dstBuf.getSSA ());
935
+ builder.create <sparse_tensor::YieldOp>(loc, dstBuf.val );
949
936
});
950
937
// Accumulates the offset. Note that only static-shaped inputs are allowed
951
938
// by concatenate op verifier, which saves us from computing the offset
@@ -955,15 +942,11 @@ struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
955
942
offset = rewriter.create <arith::AddIOp>(
956
943
loc, offset, constantIndex (rewriter, loc, *sh));
957
944
958
- if (!foreachOp.getResults ().empty ()) {
959
- iterArg = foreachOp.getResult (0 );
960
- dstBuf.updateSSA (iterArg);
961
- }
945
+ iterArg = foreachOp.getResult (0 );
946
+ dstBuf.val = iterArg;
962
947
}
963
948
964
- if (!foreachOp.getResults ().empty ())
965
- dstBuf.updateSSA (iterArg);
966
-
949
+ dstBuf.val = iterArg;
967
950
Value ret = dstBuf.finalize (rewriter, loc, dstTp.getRankedTensorType ());
968
951
rewriter.replaceOp (op, ret);
969
952
return success ();
@@ -1010,15 +993,12 @@ struct DirectConvertRewriter : public OpRewritePattern<ConvertOp> {
1010
993
ValueRange vs;
1011
994
TensorLike dstBuf (rewriter, loc, dstStt.getRankedTensorType (), sizes);
1012
995
1013
- Value iterArg = dstBuf.getSSA ();
1014
996
auto foreachOp = rewriter.create <ForeachOp>(
1015
- loc, src, iterArg ? ValueRange{iterArg} : ValueRange{} , foreachOrder,
997
+ loc, src, dstBuf. val , foreachOrder,
1016
998
[&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v,
1017
999
ValueRange reduc) {
1018
1000
// Enters the loop, update the SSA value for insertion chain.
1019
- if (!reduc.empty ())
1020
- dstBuf.updateSSA (reduc.front ());
1021
-
1001
+ dstBuf.val = reduc.front ();
1022
1002
const Dimension dimRank = dstStt.getDimRank ();
1023
1003
const Level lvlRank = dstStt.getLvlRank ();
1024
1004
SmallVector<Value> lcvs (lvlRank);
@@ -1028,34 +1008,29 @@ struct DirectConvertRewriter : public OpRewritePattern<ConvertOp> {
1028
1008
}
1029
1009
1030
1010
if (!skipZeroCheck) {
1031
- assert (!reduc.empty ());
1032
1011
Value cond = genIsNonzero (builder, loc, v);
1033
1012
auto ifOp = builder.create <scf::IfOp>(loc, reduc.getTypes (), cond,
1034
1013
/* else*/ true );
1035
1014
builder.setInsertionPointToStart (&ifOp.getElseRegion ().front ());
1036
- builder.create <scf::YieldOp>(loc, dstBuf.getSSA () );
1015
+ builder.create <scf::YieldOp>(loc, dstBuf.val );
1037
1016
1038
1017
builder.setInsertionPointToStart (&ifOp.getThenRegion ().front ());
1039
- dstBuf.insertOrStore (builder, loc, v, lcvs);
1040
- builder.create <scf::YieldOp>(loc, dstBuf.getSSA () );
1018
+ dstBuf.insert (builder, loc, v, lcvs);
1019
+ builder.create <scf::YieldOp>(loc, dstBuf.val );
1041
1020
1042
1021
// Exits the ifOp, update the sparse tensor SSA value.
1043
1022
builder.setInsertionPointAfter (ifOp);
1044
- dstBuf.updateSSA ( ifOp.getResult (0 ) );
1023
+ dstBuf.val = ifOp.getResult (0 );
1045
1024
} else {
1046
- dstBuf.insertOrStore (builder, loc, v, lcvs);
1025
+ dstBuf.insert (builder, loc, v, lcvs);
1047
1026
}
1048
- if (reduc.empty ())
1049
- builder.create <sparse_tensor::YieldOp>(loc);
1050
- else
1051
- builder.create <sparse_tensor::YieldOp>(loc, dstBuf.getSSA ());
1027
+ builder.create <sparse_tensor::YieldOp>(loc, dstBuf.val );
1052
1028
});
1053
1029
1054
1030
rewriter.setInsertionPointAfter (foreachOp);
1055
1031
1056
1032
// Exits the for loop, links the SSA chain.
1057
- if (!foreachOp.getResults ().empty ())
1058
- dstBuf.updateSSA (foreachOp.getResult (0 ));
1033
+ dstBuf.val = foreachOp.getResult (0 );
1059
1034
1060
1035
Value ret = dstBuf.finalize (rewriter, loc, dstStt.getRankedTensorType ());
1061
1036
rewriter.replaceOp (op, ret);
0 commit comments