Commit fb8f492

Author: Peiming Liu

[mlir][sparse] clone an empty sparse tensor when fusing convert into producer (#92158)
1 parent 1202837 commit fb8f492

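For context, here is a minimal before/after sketch of the rewrite this commit adjusts. It is illustrative only: the 8x8 shape, the #CSR encoding, and the copy kernel are placeholders that do not appear in this commit. The point it shows is that, instead of retyping the producer's existing dense tensor.empty in place, the pattern now clones that materialization with the sparse result type and points the producer's init at the clone, so any other users of the original dense tensor.empty keep seeing a dense tensor.

#CSR = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }>

#trait = {
  indexing_maps = [
    affine_map<(d0, d1) -> (d0, d1)>,
    affine_map<(d0, d1) -> (d0, d1)>
  ],
  iterator_types = ["parallel", "parallel"]
}

// Before: the producer writes into a dense empty tensor and its result is
// converted to sparse afterwards.
func.func @before(%a: tensor<8x8xf32>) -> tensor<8x8xf32, #CSR> {
  %0 = tensor.empty() : tensor<8x8xf32>
  %1 = linalg.generic #trait
       ins(%a : tensor<8x8xf32>)
       outs(%0 : tensor<8x8xf32>) {
    ^bb0(%in: f32, %out: f32):
      linalg.yield %in : f32
  } -> tensor<8x8xf32>
  %2 = sparse_tensor.convert %1 : tensor<8x8xf32> to tensor<8x8xf32, #CSR>
  return %2 : tensor<8x8xf32, #CSR>
}

// After folding: the dense tensor.empty is cloned as a sparse tensor.empty,
// the producer's init operand and result type are updated, and the convert
// disappears. The original dense %0 is dead here, but it survives untouched
// whenever it has other users (the multi-use case exercised by the new test).
func.func @after(%a: tensor<8x8xf32>) -> tensor<8x8xf32, #CSR> {
  %0 = tensor.empty() : tensor<8x8xf32>
  %e = tensor.empty() : tensor<8x8xf32, #CSR>
  %1 = linalg.generic #trait
       ins(%a : tensor<8x8xf32>)
       outs(%e : tensor<8x8xf32, #CSR>) {
    ^bb0(%in: f32, %out: f32):
      linalg.yield %in : f32
  } -> tensor<8x8xf32, #CSR>
  return %1 : tensor<8x8xf32, #CSR>
}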
2 files changed: 51 additions, 7 deletions

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp

Lines changed: 7 additions & 7 deletions
@@ -302,17 +302,17 @@ struct FoldConvertIntoProducer : public OpRewritePattern<ConvertOp> {
         !producer.getResult(0).hasOneUse()) {
       return failure();
     }
+    // Clone the materialization operation, but update the result to sparse.
+    rewriter.setInsertionPoint(producer);
+    Operation *init = producer.getDpsInitOperand(0)->get().getDefiningOp();
+    Operation *cloned = rewriter.clone(*init);
+    cloned->getResult(0).setType(op.getResult().getType());
+
     rewriter.modifyOpInPlace(producer, [&]() {
+      producer.getDpsInitsMutable().assign(cloned->getResults());
       producer.getResult(0).setType(op.getResult().getType());
     });

-    Operation *materializeOp =
-        producer.getDpsInitOperand(0)->get().getDefiningOp();
-
-    rewriter.modifyOpInPlace(materializeOp, [&]() {
-      materializeOp->getResult(0).setType(op.getResult().getType());
-    });
-
     rewriter.replaceAllOpUsesWith(op, producer);
     op->erase();

mlir/test/Dialect/SparseTensor/fuse_sparse_convert_into_producer.mlir

Lines changed: 44 additions & 0 deletions
@@ -54,6 +54,50 @@ func.func @fold_convert(%arg0: tensor<128x32x32x1xf32>, %arg1: tensor<128x32x32x
   return %2 : tensor<128x32x32x1xf32, #CCCD>
 }

+#trait_bin = {
+  indexing_maps = [
+    affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+    affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+    affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+  ],
+  iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+}
+
+// CHECK-FOLD-LABEL: func.func @fold_convert_multi_use(
+// CHECK-FOLD:         tensor.empty() : tensor<128x32x32x1xf32>
+// CHECK-FOLD:         linalg.generic
+// CHECK-FOLD:         tensor.empty() : tensor<128x32x32x1xf32, #sparse>
+// CHECK-FOLD:         linalg.generic
+// CHECK-FOLD-NOT:     sparse_tensor.convert
+func.func @fold_convert_multi_use(%arg0: tensor<128x32x32x1xf32>, %arg1: tensor<128x32x32x1xf32>,
+                                  %arg2: tensor<128x32x32x1xf32>, %arg3: tensor<128x32x32x1xf32>) -> (tensor<128x32x32x1xf32>, tensor<128x32x32x1xf32, #CCCD>) {
+  %cst = arith.constant 0.000000e+00 : f32
+  %cst_0 = arith.constant 1.000000e+00 : f32
+  %cst_1 = arith.constant 1.000000e+00 : f32
+
+  %0 = tensor.empty() : tensor<128x32x32x1xf32>
+  %1 = linalg.generic #trait_bin
+       ins(%arg0, %arg1 : tensor<128x32x32x1xf32>, tensor<128x32x32x1xf32>)
+       outs(%0 : tensor<128x32x32x1xf32>) {
+    ^bb0(%in: f32, %in_1: f32, %out: f32):
+      %3 = arith.mulf %in, %in_1 : f32
+      linalg.yield %3 : f32
+  } -> tensor<128x32x32x1xf32>
+
+  // A second kernel that uses %0 as the init operand.
+  %3 = linalg.generic #trait_bin
+       ins(%arg2, %arg3 : tensor<128x32x32x1xf32>, tensor<128x32x32x1xf32>)
+       outs(%0 : tensor<128x32x32x1xf32>) {
+    ^bb0(%in: f32, %in_1: f32, %out: f32):
+      %3 = arith.mulf %in, %in_1 : f32
+      linalg.yield %3 : f32
+  } -> tensor<128x32x32x1xf32>
+  %4 = sparse_tensor.convert %3 : tensor<128x32x32x1xf32> to tensor<128x32x32x1xf32, #CCCD>
+
+  return %1, %4 : tensor<128x32x32x1xf32>, tensor<128x32x32x1xf32, #CCCD>
+}
+
+

 // FIXME: The following kernel is not sparsifiable because `arith.select`
 // operations is not handled by the sparse compiler at the moment.
