
Commit 6a38c77

[mlir][sparse] fixed bug with unary op, dense output
Note that, by sparse compiler convention, a dense output is zeroed out wherever nothing is stored, so the complement results in zeros where elements were present.

Reviewed By: wrengr

Differential Revision: https://reviews.llvm.org/D152046
1 parent 0a16813 commit 6a38c77
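
For context, here is a minimal sketch of the pattern this commit addresses, condensed from the vector_complement_dense kernel added in the test below (the #SparseVector encoding, the #trait_vec trait, and the %arga/%xv values are all defined in that test). With a dense output, any position the kernel never stores keeps the buffer's zero initialization, so present input entries end up as 0 and absent ones as 1:

// Condensed sketch (see the full kernel in the test diff below):
// complement a sparse f64 vector into a dense i32 tensor.
%0 = linalg.generic #trait_vec
    ins(%arga: tensor<?xf64, #SparseVector>)
    outs(%xv: tensor<?xi32>) {
    ^bb(%a: f64, %x: i32):
      // No present region: stored entries produce no value, so the
      // zero-initialized dense buffer stays 0 at those positions.
      %1 = sparse_tensor.unary %a : f64 to i32
        present={}
        absent={
          %ci1 = arith.constant 1 : i32
          sparse_tensor.yield %ci1 : i32
        }
      linalg.yield %1 : i32
} -> tensor<?xi32>

As the restructured genTensorStore below suggests, the missing-value case (an uninitialized rhs from a unary or binary op) is now handled with an early return before either the dense memref store or the sparse insertion, rather than only on the sparse-output path.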

File tree

2 files changed, +78 -46 lines


mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp

Lines changed: 38 additions & 36 deletions
@@ -1049,50 +1049,52 @@ static Value genTensorLoad(CodegenEnv &env, OpBuilder &builder, ExprId exp) {
 /// Generates a store on a dense or sparse tensor.
 static void genTensorStore(CodegenEnv &env, OpBuilder &builder, ExprId exp,
                            Value rhs) {
-  linalg::GenericOp op = env.op();
-  Location loc = op.getLoc();
+  // Only unary and binary are allowed to return uninitialized rhs
+  // to indicate missing output.
+  if (!rhs) {
+    assert(env.exp(exp).kind == TensorExp::Kind::kUnary ||
+           env.exp(exp).kind == TensorExp::Kind::kBinary);
+    return;
+  }
   // Test if this is a scalarized reduction.
   if (env.isReduc()) {
     env.updateReduc(rhs);
     return;
   }
-  // Store during insertion.
+  // Regular store.
+  linalg::GenericOp op = env.op();
+  Location loc = op.getLoc();
   OpOperand *t = op.getDpsInitOperand(0);
-  if (env.isSparseOutput(t)) {
-    if (!rhs) {
-      // Only unary and binary are allowed to return uninitialized rhs
-      // to indicate missing output.
-      assert(env.exp(exp).kind == TensorExp::Kind::kUnary ||
-             env.exp(exp).kind == TensorExp::Kind::kBinary);
-    } else if (env.exp(exp).kind == TensorExp::Kind::kSelect) {
-      // Select operation insertion.
-      Value chain = env.getInsertionChain();
-      scf::IfOp ifOp =
-          builder.create<scf::IfOp>(loc, chain.getType(), rhs, /*else=*/true);
-      builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
-      // Existing value was preserved to be used here.
-      assert(env.exp(exp).val);
-      Value v0 = env.exp(exp).val;
-      genInsertionStore(env, builder, t, v0);
-      env.merger().clearExprValue(exp);
-      // Yield modified insertion chain along true branch.
-      Value mchain = env.getInsertionChain();
-      builder.create<scf::YieldOp>(op.getLoc(), mchain);
-      // Yield original insertion chain along false branch.
-      builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
-      builder.create<scf::YieldOp>(loc, chain);
-      // Done with if statement.
-      env.updateInsertionChain(ifOp->getResult(0));
-      builder.setInsertionPointAfter(ifOp);
-    } else {
-      genInsertionStore(env, builder, t, rhs);
-    }
+  if (!env.isSparseOutput(t)) {
+    SmallVector<Value> args;
+    Value ptr = genSubscript(env, builder, t, args);
+    builder.create<memref::StoreOp>(loc, rhs, ptr, args);
     return;
   }
-  // Actual store.
-  SmallVector<Value> args;
-  Value ptr = genSubscript(env, builder, t, args);
-  builder.create<memref::StoreOp>(loc, rhs, ptr, args);
+  // Store during sparse insertion.
+  if (env.exp(exp).kind != TensorExp::Kind::kSelect) {
+    genInsertionStore(env, builder, t, rhs);
+    return;
+  }
+  // Select operation insertion.
+  Value chain = env.getInsertionChain();
+  scf::IfOp ifOp =
+      builder.create<scf::IfOp>(loc, chain.getType(), rhs, /*else=*/true);
+  builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+  // Existing value was preserved to be used here.
+  assert(env.exp(exp).val);
+  Value v0 = env.exp(exp).val;
+  genInsertionStore(env, builder, t, v0);
+  env.merger().clearExprValue(exp);
+  // Yield modified insertion chain along true branch.
+  Value mchain = env.getInsertionChain();
+  builder.create<scf::YieldOp>(op.getLoc(), mchain);
+  // Yield original insertion chain along false branch.
+  builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
+  builder.create<scf::YieldOp>(loc, chain);
+  // Done with if statement.
+  env.updateInsertionChain(ifOp->getResult(0));
+  builder.setInsertionPointAfter(ifOp);
 }
 
 /// Generates an invariant value.

mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_unary.mlir

Lines changed: 40 additions & 10 deletions
@@ -32,14 +32,14 @@
 //
 // Traits for tensor operations.
 //
-#trait_vec_scale = {
+#trait_vec = {
   indexing_maps = [
     affine_map<(i) -> (i)>,  // a (in)
     affine_map<(i) -> (i)>   // x (out)
   ],
   iterator_types = ["parallel"]
 }
-#trait_mat_scale = {
+#trait_mat = {
   indexing_maps = [
     affine_map<(i,j) -> (i,j)>,  // A (in)
     affine_map<(i,j) -> (i,j)>   // X (out)
@@ -49,13 +49,13 @@
 
 module {
   // Invert the structure of a sparse vector. Present values become missing.
-  // Missing values are filled with 1 (i32).
-  func.func @vector_complement(%arga: tensor<?xf64, #SparseVector>) -> tensor<?xi32, #SparseVector> {
+  // Missing values are filled with 1 (i32). Output is sparse.
+  func.func @vector_complement_sparse(%arga: tensor<?xf64, #SparseVector>) -> tensor<?xi32, #SparseVector> {
     %c = arith.constant 0 : index
     %ci1 = arith.constant 1 : i32
     %d = tensor.dim %arga, %c : tensor<?xf64, #SparseVector>
     %xv = bufferization.alloc_tensor(%d) : tensor<?xi32, #SparseVector>
-    %0 = linalg.generic #trait_vec_scale
+    %0 = linalg.generic #trait_vec
        ins(%arga: tensor<?xf64, #SparseVector>)
        outs(%xv: tensor<?xi32, #SparseVector>) {
        ^bb(%a: f64, %x: i32):
@@ -69,13 +69,35 @@ module {
     return %0 : tensor<?xi32, #SparseVector>
   }
 
+  // Invert the structure of a sparse vector, where missing values are
+  // filled with 1. For a dense output, the sparse compiler initializes
+  // the buffer to all zero at all other places.
+  func.func @vector_complement_dense(%arga: tensor<?xf64, #SparseVector>) -> tensor<?xi32> {
+    %c = arith.constant 0 : index
+    %d = tensor.dim %arga, %c : tensor<?xf64, #SparseVector>
+    %xv = bufferization.alloc_tensor(%d) : tensor<?xi32>
+    %0 = linalg.generic #trait_vec
+       ins(%arga: tensor<?xf64, #SparseVector>)
+       outs(%xv: tensor<?xi32>) {
+       ^bb(%a: f64, %x: i32):
+         %1 = sparse_tensor.unary %a : f64 to i32
+           present={}
+           absent={
+             %ci1 = arith.constant 1 : i32
+             sparse_tensor.yield %ci1 : i32
+           }
+         linalg.yield %1 : i32
+    } -> tensor<?xi32>
+    return %0 : tensor<?xi32>
+  }
+
   // Negate existing values. Fill missing ones with +1.
   func.func @vector_negation(%arga: tensor<?xf64, #SparseVector>) -> tensor<?xf64, #SparseVector> {
     %c = arith.constant 0 : index
     %cf1 = arith.constant 1.0 : f64
     %d = tensor.dim %arga, %c : tensor<?xf64, #SparseVector>
     %xv = bufferization.alloc_tensor(%d) : tensor<?xf64, #SparseVector>
-    %0 = linalg.generic #trait_vec_scale
+    %0 = linalg.generic #trait_vec
        ins(%arga: tensor<?xf64, #SparseVector>)
        outs(%xv: tensor<?xf64, #SparseVector>) {
        ^bb(%a: f64, %x: f64):
@@ -98,7 +120,7 @@ module {
     %c = arith.constant 0 : index
     %d = tensor.dim %arga, %c : tensor<?xf64, #SparseVector>
     %xv = bufferization.alloc_tensor(%d) : tensor<?xf64, #SparseVector>
-    %0 = linalg.generic #trait_vec_scale
+    %0 = linalg.generic #trait_vec
        ins(%arga: tensor<?xf64, #SparseVector>)
        outs(%xv: tensor<?xf64, #SparseVector>) {
        ^bb(%a: f64, %x: f64):
@@ -126,7 +148,7 @@ module {
     %d0 = tensor.dim %argx, %c0 : tensor<?x?xf64, #DCSR>
     %d1 = tensor.dim %argx, %c1 : tensor<?x?xf64, #DCSR>
     %xv = bufferization.alloc_tensor(%d0, %d1) : tensor<?x?xf64, #DCSR>
-    %0 = linalg.generic #trait_mat_scale
+    %0 = linalg.generic #trait_mat
        ins(%argx: tensor<?x?xf64, #DCSR>)
        outs(%xv: tensor<?x?xf64, #DCSR>) {
        ^bb(%a: f64, %x: f64):
@@ -153,7 +175,7 @@ module {
     %d0 = tensor.dim %argx, %c0 : tensor<?x?xf64, #DCSR>
     %d1 = tensor.dim %argx, %c1 : tensor<?x?xf64, #DCSR>
     %xv = bufferization.alloc_tensor(%d0, %d1) : tensor<?x?xf64, #DCSR>
-    %0 = linalg.generic #trait_mat_scale
+    %0 = linalg.generic #trait_mat
        ins(%argx: tensor<?x?xf64, #DCSR>)
        outs(%xv: tensor<?x?xf64, #DCSR>) {
        ^bb(%a: f64, %x: f64):
@@ -223,6 +245,7 @@ module {
 
   // Driver method to call and verify vector kernels.
   func.func @entry() {
+    %cmu = arith.constant -99 : i32
     %c0 = arith.constant 0 : index
 
     // Setup sparse vectors.
@@ -240,7 +263,7 @@ module {
     %sm1 = sparse_tensor.convert %m1 : tensor<4x8xf64> to tensor<?x?xf64, #DCSR>
 
     // Call sparse vector kernels.
-    %0 = call @vector_complement(%sv1)
+    %0 = call @vector_complement_sparse(%sv1)
        : (tensor<?xf64, #SparseVector>) -> tensor<?xi32, #SparseVector>
     %1 = call @vector_negation(%sv1)
        : (tensor<?xf64, #SparseVector>) -> tensor<?xf64, #SparseVector>
@@ -253,6 +276,9 @@ module {
     %4 = call @matrix_slice(%sm1)
        : (tensor<?x?xf64, #DCSR>) -> tensor<?x?xf64, #DCSR>
 
+    // Call kernel with dense output.
+    %5 = call @vector_complement_dense(%sv1) : (tensor<?xf64, #SparseVector>) -> tensor<?xi32>
+
     //
     // Verify the results.
     //
@@ -268,13 +294,16 @@ module {
     // CHECK-NEXT: ( ( 3, 3, 0, 0, 0, 0, 0, 0 ), ( 0, 0, 0, 0, 0, 0, 0, 3 ), ( 0, 0, 4, 0, 5, 0, 0, 6 ), ( 7, 0, 7, 7, 0, 0, 0, 0 ) )
     // CHECK-NEXT: ( 99, 99, 99, 99, 5, 6, 99, 99, 99, 0, 0, 0, 0, 0, 0, 0 )
     // CHECK-NEXT: ( ( 99, 99, 0, 0, 0, 0, 0, 0 ), ( 0, 0, 0, 0, 0, 0, 0, 99 ), ( 0, 0, 99, 0, 5, 0, 0, 6 ), ( 99, 0, 99, 99, 0, 0, 0, 0 ) )
+    // CHECK-NEXT: ( 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0 )
     //
     call @dump_vec_f64(%sv1) : (tensor<?xf64, #SparseVector>) -> ()
     call @dump_vec_i32(%0) : (tensor<?xi32, #SparseVector>) -> ()
     call @dump_vec_f64(%1) : (tensor<?xf64, #SparseVector>) -> ()
     call @dump_vec_f64(%2) : (tensor<?xf64, #SparseVector>) -> ()
     call @dump_mat(%3) : (tensor<?x?xf64, #DCSR>) -> ()
     call @dump_mat(%4) : (tensor<?x?xf64, #DCSR>) -> ()
+    %v = vector.transfer_read %5[%c0], %cmu: tensor<?xi32>, vector<32xi32>
+    vector.print %v : vector<32xi32>
 
     // Release the resources.
     bufferization.dealloc_tensor %sv1 : tensor<?xf64, #SparseVector>
@@ -284,6 +313,7 @@ module {
     bufferization.dealloc_tensor %2 : tensor<?xf64, #SparseVector>
     bufferization.dealloc_tensor %3 : tensor<?x?xf64, #DCSR>
     bufferization.dealloc_tensor %4 : tensor<?x?xf64, #DCSR>
+    bufferization.dealloc_tensor %5 : tensor<?xi32>
     return
   }
 }
