Skip to content

Commit f82bee1

Browse files
authored
[mlir][sparse] split post-sparsification-rewriting into two passes. (llvm#70727)
1 parent b1c59b5 commit f82bee1

17 files changed

+92
-55
lines changed

mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -114,17 +114,23 @@ void populateStageSparseOperationsPatterns(RewritePatternSet &patterns);
114114
std::unique_ptr<Pass> createStageSparseOperationsPass();
115115

116116
//===----------------------------------------------------------------------===//
117-
// The PostSparsificationRewriting pass.
117+
// The LowerSparseOpsToForeach pass.
118118
//===----------------------------------------------------------------------===//
119119

120-
void populatePostSparsificationRewriting(RewritePatternSet &patterns,
121-
bool enableRT, bool enableForeach,
122-
bool enableConvert);
120+
void populateLowerSparseOpsToForeachPatterns(RewritePatternSet &patterns,
121+
bool enableRT, bool enableConvert);
123122

124-
std::unique_ptr<Pass> createPostSparsificationRewritePass();
125-
std::unique_ptr<Pass>
126-
createPostSparsificationRewritePass(bool enableRT, bool enableForeach = true,
127-
bool enableConvert = true);
123+
std::unique_ptr<Pass> createLowerSparseOpsToForeachPass();
124+
std::unique_ptr<Pass> createLowerSparseOpsToForeachPass(bool enableRT,
125+
bool enableConvert);
126+
127+
//===----------------------------------------------------------------------===//
128+
// The LowerForeachToSCF pass.
129+
//===----------------------------------------------------------------------===//
130+
131+
void populateLowerForeachToSCFPatterns(RewritePatternSet &patterns);
132+
133+
std::unique_ptr<Pass> createLowerForeachToSCFPass();
128134

129135
//===----------------------------------------------------------------------===//
130136
// The SparseTensorConversion pass.

mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -167,13 +167,12 @@ def StageSparseOperations : Pass<"stage-sparse-ops", "func::FuncOp"> {
167167
];
168168
}
169169

170-
def PostSparsificationRewrite : Pass<"post-sparsification-rewrite", "ModuleOp"> {
170+
def LowerSparseOpsToForeach : Pass<"lower-sparse-ops-to-foreach", "ModuleOp"> {
171171
let summary = "Applies sparse tensor rewriting rules after sparsification";
172172
let description = [{
173-
A pass that applies rewriting rules to sparse tensor operations after
174-
running the actual sparsification pass.
173+
A pass that lowers high-level sparse operations to sparse_tensor.foreach.
175174
}];
176-
let constructor = "mlir::createPostSparsificationRewritePass()";
175+
let constructor = "mlir::createLowerSparseOpsToForeachPass()";
177176
let dependentDialects = [
178177
"affine::AffineDialect",
179178
"arith::ArithDialect",
@@ -186,13 +185,25 @@ def PostSparsificationRewrite : Pass<"post-sparsification-rewrite", "ModuleOp">
186185
let options = [
187186
Option<"enableRuntimeLibrary", "enable-runtime-library", "bool",
188187
"true", "Enable runtime library for manipulating sparse tensors">,
189-
Option<"enableForeach", "enable-foreach", "bool",
190-
"true", "Enable rewriting rules for the foreach operator">,
191188
Option<"enableConvert", "enable-convert", "bool",
192189
"true", "Enable rewriting rules for the convert operator">,
193190
];
194191
}
195192

193+
def LowerForeachToSCF : Pass<"lower-sparse-foreach-to-scf", "func::FuncOp"> {
194+
let summary = "Decompose a complex sparse operation into multiple stages";
195+
let description = [{
196+
A pass that lowers sparse_tensor.foreach operation to scf dialect.
197+
}];
198+
let constructor = "mlir::createLowerForeachToSCFPass()";
199+
let dependentDialects = [
200+
"memref::MemRefDialect",
201+
"scf::SCFDialect",
202+
"sparse_tensor::SparseTensorDialect",
203+
];
204+
}
205+
206+
196207
def SparseTensorConversionPass : Pass<"sparse-tensor-conversion", "ModuleOp"> {
197208
let summary = "Convert sparse tensors and primitives to library calls";
198209
let description = [{

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp

Lines changed: 30 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@ namespace mlir {
2525
#define GEN_PASS_DEF_SPARSEREINTERPRETMAP
2626
#define GEN_PASS_DEF_PRESPARSIFICATIONREWRITE
2727
#define GEN_PASS_DEF_SPARSIFICATIONPASS
28-
#define GEN_PASS_DEF_POSTSPARSIFICATIONREWRITE
28+
#define GEN_PASS_DEF_LOWERSPARSEOPSTOFOREACH
29+
#define GEN_PASS_DEF_LOWERFOREACHTOSCF
2930
#define GEN_PASS_DEF_SPARSETENSORCONVERSIONPASS
3031
#define GEN_PASS_DEF_SPARSETENSORCODEGEN
3132
#define GEN_PASS_DEF_SPARSEBUFFERREWRITE
@@ -120,23 +121,34 @@ struct StageSparseOperationsPass
120121
}
121122
};
122123

123-
struct PostSparsificationRewritePass
124-
: public impl::PostSparsificationRewriteBase<
125-
PostSparsificationRewritePass> {
126-
PostSparsificationRewritePass() = default;
127-
PostSparsificationRewritePass(const PostSparsificationRewritePass &pass) =
124+
struct LowerSparseOpsToForeachPass
125+
: public impl::LowerSparseOpsToForeachBase<LowerSparseOpsToForeachPass> {
126+
LowerSparseOpsToForeachPass() = default;
127+
LowerSparseOpsToForeachPass(const LowerSparseOpsToForeachPass &pass) =
128128
default;
129-
PostSparsificationRewritePass(bool enableRT, bool foreach, bool convert) {
129+
LowerSparseOpsToForeachPass(bool enableRT, bool convert) {
130130
enableRuntimeLibrary = enableRT;
131-
enableForeach = foreach;
132131
enableConvert = convert;
133132
}
134133

135134
void runOnOperation() override {
136135
auto *ctx = &getContext();
137136
RewritePatternSet patterns(ctx);
138-
populatePostSparsificationRewriting(patterns, enableRuntimeLibrary,
139-
enableForeach, enableConvert);
137+
populateLowerSparseOpsToForeachPatterns(patterns, enableRuntimeLibrary,
138+
enableConvert);
139+
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
140+
}
141+
};
142+
143+
struct LowerForeachToSCFPass
144+
: public impl::LowerForeachToSCFBase<LowerForeachToSCFPass> {
145+
LowerForeachToSCFPass() = default;
146+
LowerForeachToSCFPass(const LowerForeachToSCFPass &pass) = default;
147+
148+
void runOnOperation() override {
149+
auto *ctx = &getContext();
150+
RewritePatternSet patterns(ctx);
151+
populateLowerForeachToSCFPatterns(patterns);
140152
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
141153
}
142154
};
@@ -399,15 +411,17 @@ std::unique_ptr<Pass> mlir::createStageSparseOperationsPass() {
399411
return std::make_unique<StageSparseOperationsPass>();
400412
}
401413

402-
std::unique_ptr<Pass> mlir::createPostSparsificationRewritePass() {
403-
return std::make_unique<PostSparsificationRewritePass>();
414+
std::unique_ptr<Pass> mlir::createLowerSparseOpsToForeachPass() {
415+
return std::make_unique<LowerSparseOpsToForeachPass>();
404416
}
405417

406418
std::unique_ptr<Pass>
407-
mlir::createPostSparsificationRewritePass(bool enableRT, bool enableForeach,
408-
bool enableConvert) {
409-
return std::make_unique<PostSparsificationRewritePass>(
410-
enableRT, enableForeach, enableConvert);
419+
mlir::createLowerSparseOpsToForeachPass(bool enableRT, bool enableConvert) {
420+
return std::make_unique<LowerSparseOpsToForeachPass>(enableRT, enableConvert);
421+
}
422+
423+
std::unique_ptr<Pass> mlir::createLowerForeachToSCFPass() {
424+
return std::make_unique<LowerForeachToSCFPass>();
411425
}
412426

413427
std::unique_ptr<Pass> mlir::createSparseTensorConversionPass() {

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1303,21 +1303,23 @@ void mlir::populatePreSparsificationRewriting(RewritePatternSet &patterns) {
13031303
GenSemiRingReduction, GenSemiRingSelect>(patterns.getContext());
13041304
}
13051305

1306-
void mlir::populatePostSparsificationRewriting(RewritePatternSet &patterns,
1307-
bool enableRT,
1308-
bool enableForeach,
1309-
bool enableConvert) {
1306+
void mlir::populateLowerSparseOpsToForeachPatterns(RewritePatternSet &patterns,
1307+
bool enableRT,
1308+
bool enableConvert) {
13101309
patterns.add<ConcatenateRewriter, CrdTranslateRewriter,
13111310
ReshapeRewriter<tensor::ExpandShapeOp>,
13121311
ReshapeRewriter<tensor::CollapseShapeOp>,
13131312
Sparse2SparseReshapeRewriter<tensor::ExpandShapeOp>,
13141313
Sparse2SparseReshapeRewriter<tensor::CollapseShapeOp>,
13151314
SparseTensorDimOpRewriter, TensorReshapeRewriter, OutRewriter>(
13161315
patterns.getContext());
1317-
if (enableForeach)
1318-
patterns.add<ForeachRewriter>(patterns.getContext());
1316+
13191317
if (enableConvert)
13201318
patterns.add<DirectConvertRewriter>(patterns.getContext());
13211319
if (!enableRT)
13221320
patterns.add<NewRewriter>(patterns.getContext());
13231321
}
1322+
1323+
void mlir::populateLowerForeachToSCFPatterns(RewritePatternSet &patterns) {
1324+
patterns.add<ForeachRewriter>(patterns.getContext());
1325+
}

mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,10 @@ class SparsificationAndBufferizationPass
141141
OpPassManager pm("builtin.module");
142142
pm.addPass(createSparsificationPass(sparsificationOptions));
143143
pm.addNestedPass<func::FuncOp>(createStageSparseOperationsPass());
144-
pm.addPass(createPostSparsificationRewritePass(enableRuntimeLibrary));
144+
pm.addPass(createLowerSparseOpsToForeachPass(enableRuntimeLibrary,
145+
/*enableConvert=*/true));
146+
// TODO: DemapPass here!
147+
pm.addNestedPass<func::FuncOp>(createLowerForeachToSCFPass());
145148
if (vectorLength > 0) {
146149
pm.addPass(mlir::createLoopInvariantCodeMotionPass());
147150
pm.addPass(createSparseVectorizationPass(

mlir/test/Dialect/SparseTensor/codegen.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt %s --post-sparsification-rewrite --sparse-tensor-codegen --canonicalize -cse | FileCheck %s
1+
// RUN: mlir-opt %s --lower-sparse-ops-to-foreach --lower-sparse-foreach-to-scf --sparse-tensor-codegen --canonicalize -cse | FileCheck %s
22

33
#SV = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>
44

mlir/test/Dialect/SparseTensor/conversion.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt %s --post-sparsification-rewrite --sparse-tensor-conversion --canonicalize --cse | FileCheck %s
1+
// RUN: mlir-opt %s --lower-sparse-ops-to-foreach --lower-sparse-foreach-to-scf --sparse-tensor-conversion --canonicalize --cse | FileCheck %s
22

33
#SparseVector = #sparse_tensor.encoding<{
44
map = (d0) -> (d0 : compressed)

mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt %s --stage-sparse-ops --post-sparsification-rewrite="enable-foreach=false" --canonicalize --cse | FileCheck %s
1+
// RUN: mlir-opt %s --stage-sparse-ops --lower-sparse-ops-to-foreach --canonicalize --cse | FileCheck %s
22

33
#SparseVector = #sparse_tensor.encoding<{
44
map = (d0) -> (d0 : compressed)

mlir/test/Dialect/SparseTensor/convert_sparse2dense.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt %s --stage-sparse-ops --post-sparsification-rewrite="enable-foreach=false" --canonicalize --cse | FileCheck %s
1+
// RUN: mlir-opt %s --stage-sparse-ops --lower-sparse-ops-to-foreach --canonicalize --cse | FileCheck %s
22

33
#SparseVector = #sparse_tensor.encoding<{
44
map = (d0) -> (d0 : compressed)

mlir/test/Dialect/SparseTensor/convert_sparse2sparse.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt %s --stage-sparse-ops --post-sparsification-rewrite="enable-foreach=false" --canonicalize --cse | FileCheck %s
1+
// RUN: mlir-opt %s --stage-sparse-ops --lower-sparse-ops-to-foreach --canonicalize --cse | FileCheck %s
22

33
#SparseVector64 = #sparse_tensor.encoding<{
44
map = (d0) -> (d0 : compressed),

mlir/test/Dialect/SparseTensor/rewriting_for_codegen.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
// RUN: mlir-opt %s -post-sparsification-rewrite="enable-runtime-library=false enable-convert=false" | \
2-
// RUN: FileCheck %s
1+
// RUN: mlir-opt %s --lower-sparse-ops-to-foreach="enable-runtime-library=false enable-convert=false" \
2+
// RUN: --lower-sparse-foreach-to-scf | FileCheck %s
33

44
#CSR = #sparse_tensor.encoding<{
55
map = (d0, d1) -> (d0 : dense, d1 : compressed)

mlir/test/Dialect/SparseTensor/sparse_concat.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
// RUN: mlir-opt %s --post-sparsification-rewrite="enable-runtime-library=false enable-convert=false" \
1+
// RUN: mlir-opt %s --lower-sparse-ops-to-foreach="enable-runtime-library=false enable-convert=false" --lower-sparse-foreach-to-scf \
22
// RUN: | FileCheck %s
3-
// RUN: mlir-opt %s --post-sparsification-rewrite="enable-runtime-library=true enable-convert=false" \
3+
// RUN: mlir-opt %s --lower-sparse-ops-to-foreach="enable-runtime-library=true enable-convert=false" --lower-sparse-foreach-to-scf \
44
// RUN: | FileCheck %s
55

66

mlir/test/Dialect/SparseTensor/sparse_expand.mlir

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
// RUN: FileCheck %s --check-prefix=CHECK-SPARSE
55
// RUN: mlir-opt %s --linalg-generalize-named-ops \
66
// RUN: --linalg-fuse-elementwise-ops \
7-
// RUN: --sparsification --post-sparsification-rewrite \
7+
// RUN: --sparsification --lower-sparse-ops-to-foreach \
8+
// RUN: --lower-sparse-foreach-to-scf \
89
// RUN: --sparse-tensor-conversion --cse | \
910
// RUN: FileCheck %s --check-prefix=CHECK-CONVERT
1011

mlir/test/Dialect/SparseTensor/sparse_foreach.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt %s --post-sparsification-rewrite="enable-runtime-library=false enable-foreach=true" --canonicalize | FileCheck %s
1+
// RUN: mlir-opt %s --lower-sparse-foreach-to-scf --canonicalize | FileCheck %s
22

33
// CHECK-LABEL: func.func @sparse_foreach_constant
44
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index

mlir/test/Dialect/SparseTensor/sparse_pack.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt %s --canonicalize --post-sparsification-rewrite="enable-runtime-library=false" --sparse-tensor-codegen -cse --canonicalize | FileCheck %s
1+
// RUN: mlir-opt %s --canonicalize --sparse-tensor-codegen -cse --canonicalize | FileCheck %s
22

33
#COO = #sparse_tensor.encoding<{
44
map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton),

mlir/test/Dialect/SparseTensor/sparse_reshape.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
// RUN: mlir-opt %s | mlir-opt | FileCheck %s --check-prefix=CHECK-ROUND
2-
// RUN: mlir-opt %s --post-sparsification-rewrite="enable-runtime-library=true enable-convert=false" \
3-
// RUN: --cse --canonicalize | FileCheck %s
4-
// RUN: mlir-opt %s --post-sparsification-rewrite="enable-runtime-library=false enable-convert=false" \
5-
// RUN: --cse --canonicalize | FileCheck %s
2+
// RUN: mlir-opt %s --lower-sparse-ops-to-foreach="enable-runtime-library=true enable-convert=false" \
3+
// RUN: --lower-sparse-foreach-to-scf --cse --canonicalize | FileCheck %s
4+
// RUN: mlir-opt %s --lower-sparse-ops-to-foreach="enable-runtime-library=false enable-convert=false" \
5+
// RUN: --lower-sparse-foreach-to-scf --cse --canonicalize | FileCheck %s
66

77
#SparseVector = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>
88
#SparseMatrix = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed, d1 : compressed) }>

mlir/test/Dialect/SparseTensor/sparse_tensor_reshape.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
// RUN: mlir-opt %s --post-sparsification-rewrite="enable-runtime-library=false enable-convert=false" \
2-
// RUN: --cse --canonicalize | FileCheck %s
1+
// RUN: mlir-opt %s --lower-sparse-ops-to-foreach="enable-runtime-library=false enable-convert=false" \
2+
// RUN: --lower-sparse-foreach-to-scf --cse --canonicalize | FileCheck %s
33

44
#SparseMatrix = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed, d1 : compressed) }>
55

0 commit comments

Comments
 (0)