
Commit 71c97c7

[mlir][sparse] avoid tensor to memref conversion in sparse tensor rewriting rules. (#69362)
1 parent fd31112 commit 71c97c7

3 files changed: +132 −158 lines changed
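The change, in short: when a sparse tensor rewriting rule needs to build a dense result, it no longer allocates a memref, stores into it, and converts back with bufferization.to_tensor. Instead it stays with tensor values: the destination is created with bufferization.alloc_tensor, zero-initialized with linalg.fill, and populated with tensor.insert inside the sparse_tensor.foreach, so the whole insertion chain remains SSA values. Below is a minimal sketch of the kind of IR the rewritten convert path now aims for, assuming a 1-D conversion like the one in the updated test; %src, %c0_i32, and the #SparseVector encoding are illustrative assumptions, not lines from this commit.

  // Sketch only: the new all-tensor lowering for a sparse -> dense convert.
  // Assumed for illustration: %src : tensor<13xi32, #SparseVector> and
  // %c0_i32 = arith.constant 0 : i32.
  %empty = bufferization.alloc_tensor() : tensor<13xi32>
  %init = linalg.fill ins(%c0_i32 : i32) outs(%empty : tensor<13xi32>) -> tensor<13xi32>
  %dst = sparse_tensor.foreach in %src init(%init)
      : tensor<13xi32, #SparseVector>, tensor<13xi32> -> tensor<13xi32> do {
    ^bb0(%i: index, %v: i32, %acc: tensor<13xi32>):
      // The destination is threaded through as an iteration argument and each
      // stored element is inserted with tensor.insert instead of memref.store.
      %ins = tensor.insert %v into %acc[%i] : tensor<13xi32>
      sparse_tensor.yield %ins : tensor<13xi32>
  }
  // No bufferization.to_tensor is needed afterwards: %dst is already a tensor.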

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp

Lines changed: 41 additions & 66 deletions
@@ -829,47 +829,40 @@ struct ReshapeRewriter : public OpRewritePattern<ReshapeOp> {
   }
 };
 
+// A trivial wrapper to help generate different operations for dense/sparse
+// tensors.
 struct TensorLike {
   TensorLike(OpBuilder &builder, Location loc, RankedTensorType rtt,
-             ValueRange sizes)
-      : isSparse(rtt.getEncoding() != nullptr) {
+             ValueRange sizes) {
     SmallVector<Value> dynSzs;
     getDynamicSizes(rtt, sizes, dynSzs);
 
-    if (isSparse)
-      val = builder.create<AllocTensorOp>(loc, rtt, dynSzs);
-    else
-      val = allocDenseTensor(builder, loc, rtt, sizes);
-  };
-
-  void insertOrStore(OpBuilder &builder, Location loc, Value v,
-                     ValueRange crds) {
-    if (isSparse)
-      val = builder.create<InsertOp>(loc, v, val, crds);
-    else
-      builder.create<memref::StoreOp>(loc, v, val, crds);
+    val = builder.create<AllocTensorOp>(loc, rtt, dynSzs);
+    if (!isSparse()) {
+      Value c0 = constantZero(builder, loc, rtt.getElementType());
+      val = builder.create<linalg::FillOp>(loc, c0, val).getResult(0);
+    }
   }
 
-  Value getSSA() const {
-    // We don't need to maintain the SSA chain for a memref value.
-    return isSparse ? val : nullptr;
+  void insert(OpBuilder &builder, Location loc, Value v, ValueRange crds) {
+    // TODO: Unify these two.
+    if (isSparse())
+      val = builder.create<sparse_tensor::InsertOp>(loc, v, val, crds);
+    else
+      val = builder.create<tensor::InsertOp>(loc, v, val, crds);
   }
 
   Value finalize(OpBuilder &builder, Location loc, RankedTensorType rtp) const {
-    if (isSparse)
+    if (isSparse())
       return builder.create<LoadOp>(loc, val, true);
-    return builder.create<bufferization::ToTensorOp>(loc, rtp, val);
+    return val;
   }
 
-  void updateSSA(Value v) {
-    // Dense memref is a non-SSA value.
-    assert(isSparse);
-    val = v;
+  bool isSparse() const {
+    return getSparseTensorEncoding(val.getType()) != nullptr;
   }
 
-private:
-  bool isSparse;
-  Value val; // either a memref (for dense tensor) or a sparse tensor.
+  Value val;
 };
 
 struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
@@ -901,14 +894,14 @@ struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
 
     TensorLike dstBuf(rewriter, loc, dstTp.getRankedTensorType(), sizes);
     Value offset = constantIndex(rewriter, loc, 0);
-    Value iterArg = dstBuf.getSSA();
+    Value iterArg = dstBuf.val;
 
     ForeachOp foreachOp;
     for (Value input : op.getInputs()) {
      // Builds a for op for each input tensor to append new values into the
      // output tensor.
      foreachOp = rewriter.create<ForeachOp>(
-          loc, input, iterArg ? ValueRange{iterArg} : ValueRange{},
+          loc, input, iterArg,
          [&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v,
              ValueRange reduc) {
            SmallVector<Value> dstLcvs(dstTp.getLvlRank());
@@ -920,32 +913,26 @@ struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
              // FIXME: `toStoredDim` is deprecated
              dstLcvs[toStoredDim(dstTp.getEncoding(), d)] = crd;
            }
-
-            if (!reduc.empty())
-              dstBuf.updateSSA(reduc.front());
-
+            // Enters foreach, updates the SSA chain.
+            dstBuf.val = reduc.front();
            if (!dstTp.isAllDense()) {
              Value cond = genIsNonzero(builder, loc, v);
              auto ifOp = builder.create<scf::IfOp>(loc, reduc.getTypes(), cond,
                                                    /*else*/ true);
              builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
-              builder.create<scf::YieldOp>(loc, dstBuf.getSSA());
+              builder.create<scf::YieldOp>(loc, dstBuf.val);
 
              builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
-              dstBuf.insertOrStore(builder, loc, v, dstLcvs);
-              builder.create<scf::YieldOp>(loc, dstBuf.getSSA());
+              dstBuf.insert(builder, loc, v, dstLcvs);
+              builder.create<scf::YieldOp>(loc, dstBuf.val);
 
              // Exits the ifOp, update the sparse tensor SSA value.
              builder.setInsertionPointAfter(ifOp);
-              assert(!reduc.empty());
-              dstBuf.updateSSA(ifOp.getResult(0));
+              dstBuf.val = ifOp.getResult(0);
            } else {
-              dstBuf.insertOrStore(builder, loc, v, dstLcvs);
+              dstBuf.insert(builder, loc, v, dstLcvs);
            }
-            if (reduc.empty())
-              builder.create<sparse_tensor::YieldOp>(loc);
-            else
-              builder.create<sparse_tensor::YieldOp>(loc, dstBuf.getSSA());
+            builder.create<sparse_tensor::YieldOp>(loc, dstBuf.val);
          });
      // Accumulates the offset. Note that only static-shaped inputs are allowed
      // by concatenate op verifier, which saves us from computing the offset
@@ -955,15 +942,11 @@ struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
      offset = rewriter.create<arith::AddIOp>(
          loc, offset, constantIndex(rewriter, loc, *sh));
 
-      if (!foreachOp.getResults().empty()) {
-        iterArg = foreachOp.getResult(0);
-        dstBuf.updateSSA(iterArg);
-      }
+      iterArg = foreachOp.getResult(0);
+      dstBuf.val = iterArg;
    }
 
-    if (!foreachOp.getResults().empty())
-      dstBuf.updateSSA(iterArg);
-
+    dstBuf.val = iterArg;
    Value ret = dstBuf.finalize(rewriter, loc, dstTp.getRankedTensorType());
    rewriter.replaceOp(op, ret);
    return success();
@@ -1010,15 +993,12 @@ struct DirectConvertRewriter : public OpRewritePattern<ConvertOp> {
    ValueRange vs;
    TensorLike dstBuf(rewriter, loc, dstStt.getRankedTensorType(), sizes);
 
-    Value iterArg = dstBuf.getSSA();
    auto foreachOp = rewriter.create<ForeachOp>(
-        loc, src, iterArg ? ValueRange{iterArg} : ValueRange{}, foreachOrder,
+        loc, src, dstBuf.val, foreachOrder,
        [&](OpBuilder &builder, Location loc, ValueRange dcvs, Value v,
            ValueRange reduc) {
          // Enters the loop, update the SSA value for insertion chain.
-          if (!reduc.empty())
-            dstBuf.updateSSA(reduc.front());
-
+          dstBuf.val = reduc.front();
          const Dimension dimRank = dstStt.getDimRank();
          const Level lvlRank = dstStt.getLvlRank();
          SmallVector<Value> lcvs(lvlRank);
@@ -1028,34 +1008,29 @@ struct DirectConvertRewriter : public OpRewritePattern<ConvertOp> {
          }
 
          if (!skipZeroCheck) {
-            assert(!reduc.empty());
            Value cond = genIsNonzero(builder, loc, v);
            auto ifOp = builder.create<scf::IfOp>(loc, reduc.getTypes(), cond,
                                                  /*else*/ true);
            builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
-            builder.create<scf::YieldOp>(loc, dstBuf.getSSA());
+            builder.create<scf::YieldOp>(loc, dstBuf.val);
 
            builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
-            dstBuf.insertOrStore(builder, loc, v, lcvs);
-            builder.create<scf::YieldOp>(loc, dstBuf.getSSA());
+            dstBuf.insert(builder, loc, v, lcvs);
+            builder.create<scf::YieldOp>(loc, dstBuf.val);
 
            // Exits the ifOp, update the sparse tensor SSA value.
            builder.setInsertionPointAfter(ifOp);
-            dstBuf.updateSSA(ifOp.getResult(0));
+            dstBuf.val = ifOp.getResult(0);
          } else {
-            dstBuf.insertOrStore(builder, loc, v, lcvs);
+            dstBuf.insert(builder, loc, v, lcvs);
          }
-          if (reduc.empty())
-            builder.create<sparse_tensor::YieldOp>(loc);
-          else
-            builder.create<sparse_tensor::YieldOp>(loc, dstBuf.getSSA());
+          builder.create<sparse_tensor::YieldOp>(loc, dstBuf.val);
        });
 
    rewriter.setInsertionPointAfter(foreachOp);
 
    // Exits the for loop, links the SSA chain.
-    if (!foreachOp.getResults().empty())
-      dstBuf.updateSSA(foreachOp.getResult(0));
+    dstBuf.val = foreachOp.getResult(0);
 
    Value ret = dstBuf.finalize(rewriter, loc, dstStt.getRankedTensorType());
    rewriter.replaceOp(op, ret);
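Two details of the rewritten patterns above, shown as a hedged sketch (value names, the f64 element type, and the dynamic tensor shape are assumptions for illustration, not code from this commit). First, where the non-zero guard is kept, both branches of the scf.if now yield the accumulated tensor value, so the insertion chain (dstBuf.val) flows through the foreach body as ordinary SSA values; for a sparse destination the insertion would be sparse_tensor.insert rather than tensor.insert:

  // Sketch only: an insertion guarded by a non-zero check inside the foreach body.
  // Assumed for illustration: %v, %acc, %i, %zero, and tensor<?xf64>.
  %nz = arith.cmpf une, %v, %zero : f64
  %next = scf.if %nz -> (tensor<?xf64>) {
    %ins = tensor.insert %v into %acc[%i] : tensor<?xf64>
    scf.yield %ins : tensor<?xf64>
  } else {
    // Nothing to insert: forward the accumulated tensor unchanged.
    scf.yield %acc : tensor<?xf64>
  }
  sparse_tensor.yield %next : tensor<?xf64>

Second, finalize() becomes trivial for a dense destination (the accumulated tensor is returned as-is), while a sparse destination is still materialized with a load that commits pending insertions, roughly:

  // Sketch only: finalizing a sparse destination (the `true` flag passed to LoadOp above).
  %res = sparse_tensor.load %acc hasInserts : tensor<13xi32, #SparseVector>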

mlir/test/Dialect/SparseTensor/convert_sparse2dense.mlir

Lines changed: 14 additions & 21 deletions
@@ -14,83 +14,76 @@
 
 // CHECK-LABEL: func.func @sparse_convert_1d
 // CHECK-NOT: sparse_tensor.reorder_coo
-// CHECK: memref.alloc
+// CHECK: bufferization.alloc_tensor
 // CHECK: linalg.fill
 // CHECK: sparse_tensor.foreach
-// CHECK: memref.store
-// CHECK: bufferization.to_tensor
+// CHECK: tensor.insert
 func.func @sparse_convert_1d(%arg0: tensor<13xi32, #SparseVector>) -> tensor<13xi32> {
   %0 = sparse_tensor.convert %arg0 : tensor<13xi32, #SparseVector> to tensor<13xi32>
   return %0 : tensor<13xi32>
 }
 
 // CHECK-LABEL: func.func @sparse_convert_1d_dyn
 // CHECK-NOT: sparse_tensor.reorder_coo
-// CHECK: memref.alloc
+// CHECK: bufferization.alloc_tensor
 // CHECK: linalg.fill
 // CHECK: sparse_tensor.foreach
-// CHECK: memref.store
-// CHECK: bufferization.to_tensor
+// CHECK: tensor.insert
 func.func @sparse_convert_1d_dyn(%arg0: tensor<?xi32, #SparseVector>) -> tensor<?xi32> {
   %0 = sparse_tensor.convert %arg0 : tensor<?xi32, #SparseVector> to tensor<?xi32>
   return %0 : tensor<?xi32>
 }
 
 // CHECK-LABEL: func.func @sparse_convert_2d
 // CHECK-NOT: sparse_tensor.reorder_coo
-// CHECK: memref.alloc
+// CHECK: bufferization.alloc_tensor
 // CHECK: linalg.fill
 // CHECK: sparse_tensor.foreach
-// CHECK: memref.store
-// CHECK: bufferization.to_tensor
+// CHECK: tensor.insert
 func.func @sparse_convert_2d(%arg0: tensor<2x4xf64, #SparseMatrix>) -> tensor<2x4xf64> {
   %0 = sparse_tensor.convert %arg0 : tensor<2x4xf64, #SparseMatrix> to tensor<2x4xf64>
   return %0 : tensor<2x4xf64>
 }
 
 // CHECK-LABEL: func.func @sparse_convert_2d_dyn
 // CHECK-NOT: sparse_tensor.reorder_coo
-// CHECK: memref.alloc
+// CHECK: bufferization.alloc_tensor
 // CHECK: linalg.fill
 // CHECK: sparse_tensor.foreach
-// CHECK: memref.store
-// CHECK: bufferization.to_tensor
+// CHECK: tensor.insert
 func.func @sparse_convert_2d_dyn0(%arg0: tensor<?x4xf64, #SparseMatrix>) -> tensor<?x4xf64> {
   %0 = sparse_tensor.convert %arg0 : tensor<?x4xf64, #SparseMatrix> to tensor<?x4xf64>
   return %0 : tensor<?x4xf64>
 }
 
 // CHECK-LABEL: func.func @sparse_convert_2d_dyn1
 // CHECK-NOT: sparse_tensor.reorder_coo
-// CHECK: memref.alloc
+// CHECK: bufferization.alloc_tensor
 // CHECK: linalg.fill
 // CHECK: sparse_tensor.foreach
-// CHECK: memref.store
-// CHECK: bufferization.to_tensor
+// CHECK: tensor.insert
 func.func @sparse_convert_2d_dyn1(%arg0: tensor<2x?xf64, #SparseMatrix>) -> tensor<2x?xf64> {
   %0 = sparse_tensor.convert %arg0 : tensor<2x?xf64, #SparseMatrix> to tensor<2x?xf64>
   return %0 : tensor<2x?xf64>
 }
 
 // CHECK-LABEL: func.func @sparse_convert_2d_dyn2
 // CHECK-NOT: sparse_tensor.reorder_coo
-// CHECK: memref.alloc
+// CHECK: bufferization.alloc_tensor
 // CHECK: linalg.fill
 // CHECK: sparse_tensor.foreach
-// CHECK: memref.store
-// CHECK: bufferization.to_tensor
+// CHECK: tensor.insert
 func.func @sparse_convert_2d_dyn2(%arg0: tensor<?x?xf64, #SparseMatrix>) -> tensor<?x?xf64> {
   %0 = sparse_tensor.convert %arg0 : tensor<?x?xf64, #SparseMatrix> to tensor<?x?xf64>
   return %0 : tensor<?x?xf64>
 }
 
 // CHECK-LABEL: func.func @sparse_convert_3d
 // CHECK-NOT: sparse_tensor.reorder_coo
-// CHECK: memref.alloc
+// CHECK: bufferization.alloc_tensor
 // CHECK: linalg.fill
 // CHECK: sparse_tensor.foreach
-// CHECK: memref.store
-// CHECK: bufferization.to_tensor
+// CHECK: tensor.insert
 func.func @sparse_convert_3d(%arg0: tensor<2x3x4xf64, #SparseTensor>) -> tensor<2x3x4xf64> {
   %0 = sparse_tensor.convert %arg0 : tensor<2x3x4xf64, #SparseTensor> to tensor<2x3x4xf64>
   return %0 : tensor<2x3x4xf64>
