Skip to content

[mlir][sparse] split post-sparsification-rewriting into two passes. #70727

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 14 additions & 8 deletions mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,17 +114,23 @@ void populateStageSparseOperationsPatterns(RewritePatternSet &patterns);
std::unique_ptr<Pass> createStageSparseOperationsPass();

//===----------------------------------------------------------------------===//
// The PostSparsificationRewriting pass.
// The LowerSparseOpsToForeach pass.
//===----------------------------------------------------------------------===//

void populatePostSparsificationRewriting(RewritePatternSet &patterns,
bool enableRT, bool enableForeach,
bool enableConvert);
void populateLowerSparseOpsToForeachPatterns(RewritePatternSet &patterns,
bool enableRT, bool enableConvert);

std::unique_ptr<Pass> createPostSparsificationRewritePass();
std::unique_ptr<Pass>
createPostSparsificationRewritePass(bool enableRT, bool enableForeach = true,
bool enableConvert = true);
std::unique_ptr<Pass> createLowerSparseOpsToForeachPass();
std::unique_ptr<Pass> createLowerSparseOpsToForeachPass(bool enableRT,
bool enableConvert);

//===----------------------------------------------------------------------===//
// The LowerForeachToSCF pass.
//===----------------------------------------------------------------------===//

void populateLowerForeachToSCFPatterns(RewritePatternSet &patterns);

std::unique_ptr<Pass> createLowerForeachToSCFPass();

//===----------------------------------------------------------------------===//
// The SparseTensorConversion pass.
Expand Down
23 changes: 17 additions & 6 deletions mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -167,13 +167,12 @@ def StageSparseOperations : Pass<"stage-sparse-ops", "func::FuncOp"> {
];
}

def PostSparsificationRewrite : Pass<"post-sparsification-rewrite", "ModuleOp"> {
def LowerSparseOpsToForeach : Pass<"lower-sparse-ops-to-foreach", "ModuleOp"> {
let summary = "Applies sparse tensor rewriting rules after sparsification";
let description = [{
A pass that applies rewriting rules to sparse tensor operations after
running the actual sparsification pass.
A pass that lowers high-level sparse operations to sparse_tensor.foreach.
}];
let constructor = "mlir::createPostSparsificationRewritePass()";
let constructor = "mlir::createLowerSparseOpsToForeachPass()";
let dependentDialects = [
"affine::AffineDialect",
"arith::ArithDialect",
Expand All @@ -186,13 +185,25 @@ def PostSparsificationRewrite : Pass<"post-sparsification-rewrite", "ModuleOp">
let options = [
Option<"enableRuntimeLibrary", "enable-runtime-library", "bool",
"true", "Enable runtime library for manipulating sparse tensors">,
Option<"enableForeach", "enable-foreach", "bool",
"true", "Enable rewriting rules for the foreach operator">,
Option<"enableConvert", "enable-convert", "bool",
"true", "Enable rewriting rules for the convert operator">,
];
}

def LowerForeachToSCF : Pass<"lower-sparse-foreach-to-scf", "func::FuncOp"> {
let summary = "Decompose a complex sparse operation into multiple stages";
let description = [{
A pass that lowers sparse_tensor.foreach operation to scf dialect.
}];
let constructor = "mlir::createLowerForeachToSCFPass()";
let dependentDialects = [
"memref::MemRefDialect",
"scf::SCFDialect",
"sparse_tensor::SparseTensorDialect",
];
}


def SparseTensorConversionPass : Pass<"sparse-tensor-conversion", "ModuleOp"> {
let summary = "Convert sparse tensors and primitives to library calls";
let description = [{
Expand Down
46 changes: 30 additions & 16 deletions mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ namespace mlir {
#define GEN_PASS_DEF_SPARSEREINTERPRETMAP
#define GEN_PASS_DEF_PRESPARSIFICATIONREWRITE
#define GEN_PASS_DEF_SPARSIFICATIONPASS
#define GEN_PASS_DEF_POSTSPARSIFICATIONREWRITE
#define GEN_PASS_DEF_LOWERSPARSEOPSTOFOREACH
#define GEN_PASS_DEF_LOWERFOREACHTOSCF
#define GEN_PASS_DEF_SPARSETENSORCONVERSIONPASS
#define GEN_PASS_DEF_SPARSETENSORCODEGEN
#define GEN_PASS_DEF_SPARSEBUFFERREWRITE
Expand Down Expand Up @@ -120,23 +121,34 @@ struct StageSparseOperationsPass
}
};

struct PostSparsificationRewritePass
: public impl::PostSparsificationRewriteBase<
PostSparsificationRewritePass> {
PostSparsificationRewritePass() = default;
PostSparsificationRewritePass(const PostSparsificationRewritePass &pass) =
struct LowerSparseOpsToForeachPass
: public impl::LowerSparseOpsToForeachBase<LowerSparseOpsToForeachPass> {
LowerSparseOpsToForeachPass() = default;
LowerSparseOpsToForeachPass(const LowerSparseOpsToForeachPass &pass) =
default;
PostSparsificationRewritePass(bool enableRT, bool foreach, bool convert) {
LowerSparseOpsToForeachPass(bool enableRT, bool convert) {
enableRuntimeLibrary = enableRT;
enableForeach = foreach;
enableConvert = convert;
}

void runOnOperation() override {
auto *ctx = &getContext();
RewritePatternSet patterns(ctx);
populatePostSparsificationRewriting(patterns, enableRuntimeLibrary,
enableForeach, enableConvert);
populateLowerSparseOpsToForeachPatterns(patterns, enableRuntimeLibrary,
enableConvert);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
};

struct LowerForeachToSCFPass
: public impl::LowerForeachToSCFBase<LowerForeachToSCFPass> {
LowerForeachToSCFPass() = default;
LowerForeachToSCFPass(const LowerForeachToSCFPass &pass) = default;

void runOnOperation() override {
auto *ctx = &getContext();
RewritePatternSet patterns(ctx);
populateLowerForeachToSCFPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
};
Expand Down Expand Up @@ -399,15 +411,17 @@ std::unique_ptr<Pass> mlir::createStageSparseOperationsPass() {
return std::make_unique<StageSparseOperationsPass>();
}

std::unique_ptr<Pass> mlir::createPostSparsificationRewritePass() {
return std::make_unique<PostSparsificationRewritePass>();
std::unique_ptr<Pass> mlir::createLowerSparseOpsToForeachPass() {
return std::make_unique<LowerSparseOpsToForeachPass>();
}

std::unique_ptr<Pass>
mlir::createPostSparsificationRewritePass(bool enableRT, bool enableForeach,
bool enableConvert) {
return std::make_unique<PostSparsificationRewritePass>(
enableRT, enableForeach, enableConvert);
mlir::createLowerSparseOpsToForeachPass(bool enableRT, bool enableConvert) {
return std::make_unique<LowerSparseOpsToForeachPass>(enableRT, enableConvert);
}

std::unique_ptr<Pass> mlir::createLowerForeachToSCFPass() {
return std::make_unique<LowerForeachToSCFPass>();
}

std::unique_ptr<Pass> mlir::createSparseTensorConversionPass() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1303,21 +1303,23 @@ void mlir::populatePreSparsificationRewriting(RewritePatternSet &patterns) {
GenSemiRingReduction, GenSemiRingSelect>(patterns.getContext());
}

void mlir::populatePostSparsificationRewriting(RewritePatternSet &patterns,
bool enableRT,
bool enableForeach,
bool enableConvert) {
void mlir::populateLowerSparseOpsToForeachPatterns(RewritePatternSet &patterns,
bool enableRT,
bool enableConvert) {
patterns.add<ConcatenateRewriter, CrdTranslateRewriter,
ReshapeRewriter<tensor::ExpandShapeOp>,
ReshapeRewriter<tensor::CollapseShapeOp>,
Sparse2SparseReshapeRewriter<tensor::ExpandShapeOp>,
Sparse2SparseReshapeRewriter<tensor::CollapseShapeOp>,
SparseTensorDimOpRewriter, TensorReshapeRewriter, OutRewriter>(
patterns.getContext());
if (enableForeach)
patterns.add<ForeachRewriter>(patterns.getContext());

if (enableConvert)
patterns.add<DirectConvertRewriter>(patterns.getContext());
if (!enableRT)
patterns.add<NewRewriter>(patterns.getContext());
}

void mlir::populateLowerForeachToSCFPatterns(RewritePatternSet &patterns) {
patterns.add<ForeachRewriter>(patterns.getContext());
}
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,10 @@ class SparsificationAndBufferizationPass
OpPassManager pm("builtin.module");
pm.addPass(createSparsificationPass(sparsificationOptions));
pm.addNestedPass<func::FuncOp>(createStageSparseOperationsPass());
pm.addPass(createPostSparsificationRewritePass(enableRuntimeLibrary));
pm.addPass(createLowerSparseOpsToForeachPass(enableRuntimeLibrary,
/*enableConvert=*/true));
// TODO: DemapPass here!
pm.addNestedPass<func::FuncOp>(createLowerForeachToSCFPass());
if (vectorLength > 0) {
pm.addPass(mlir::createLoopInvariantCodeMotionPass());
pm.addPass(createSparseVectorizationPass(
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Dialect/SparseTensor/codegen.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: mlir-opt %s --post-sparsification-rewrite --sparse-tensor-codegen --canonicalize -cse | FileCheck %s
// RUN: mlir-opt %s --lower-sparse-ops-to-foreach --lower-sparse-foreach-to-scf --sparse-tensor-codegen --canonicalize -cse | FileCheck %s

#SV = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>

Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Dialect/SparseTensor/conversion.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: mlir-opt %s --post-sparsification-rewrite --sparse-tensor-conversion --canonicalize --cse | FileCheck %s
// RUN: mlir-opt %s --lower-sparse-ops-to-foreach --lower-sparse-foreach-to-scf --sparse-tensor-conversion --canonicalize --cse | FileCheck %s

#SparseVector = #sparse_tensor.encoding<{
map = (d0) -> (d0 : compressed)
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: mlir-opt %s --stage-sparse-ops --post-sparsification-rewrite="enable-foreach=false" --canonicalize --cse | FileCheck %s
// RUN: mlir-opt %s --stage-sparse-ops --lower-sparse-ops-to-foreach --canonicalize --cse | FileCheck %s

#SparseVector = #sparse_tensor.encoding<{
map = (d0) -> (d0 : compressed)
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Dialect/SparseTensor/convert_sparse2dense.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: mlir-opt %s --stage-sparse-ops --post-sparsification-rewrite="enable-foreach=false" --canonicalize --cse | FileCheck %s
// RUN: mlir-opt %s --stage-sparse-ops --lower-sparse-ops-to-foreach --canonicalize --cse | FileCheck %s

#SparseVector = #sparse_tensor.encoding<{
map = (d0) -> (d0 : compressed)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: mlir-opt %s --stage-sparse-ops --post-sparsification-rewrite="enable-foreach=false" --canonicalize --cse | FileCheck %s
// RUN: mlir-opt %s --stage-sparse-ops --lower-sparse-ops-to-foreach --canonicalize --cse | FileCheck %s

#SparseVector64 = #sparse_tensor.encoding<{
map = (d0) -> (d0 : compressed),
Expand Down
4 changes: 2 additions & 2 deletions mlir/test/Dialect/SparseTensor/rewriting_for_codegen.mlir
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// RUN: mlir-opt %s -post-sparsification-rewrite="enable-runtime-library=false enable-convert=false" | \
// RUN: FileCheck %s
// RUN: mlir-opt %s --lower-sparse-ops-to-foreach="enable-runtime-library=false enable-convert=false" \
// RUN: --lower-sparse-foreach-to-scf | FileCheck %s

#CSR = #sparse_tensor.encoding<{
map = (d0, d1) -> (d0 : dense, d1 : compressed)
Expand Down
4 changes: 2 additions & 2 deletions mlir/test/Dialect/SparseTensor/sparse_concat.mlir
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// RUN: mlir-opt %s --post-sparsification-rewrite="enable-runtime-library=false enable-convert=false" \
// RUN: mlir-opt %s --lower-sparse-ops-to-foreach="enable-runtime-library=false enable-convert=false" --lower-sparse-foreach-to-scf \
// RUN: | FileCheck %s
// RUN: mlir-opt %s --post-sparsification-rewrite="enable-runtime-library=true enable-convert=false" \
// RUN: mlir-opt %s --lower-sparse-ops-to-foreach="enable-runtime-library=true enable-convert=false" --lower-sparse-foreach-to-scf \
// RUN: | FileCheck %s


Expand Down
3 changes: 2 additions & 1 deletion mlir/test/Dialect/SparseTensor/sparse_expand.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
// RUN: FileCheck %s --check-prefix=CHECK-SPARSE
// RUN: mlir-opt %s --linalg-generalize-named-ops \
// RUN: --linalg-fuse-elementwise-ops \
// RUN: --sparsification --post-sparsification-rewrite \
// RUN: --sparsification --lower-sparse-ops-to-foreach \
// RUN: --lower-sparse-foreach-to-scf \
// RUN: --sparse-tensor-conversion --cse | \
// RUN: FileCheck %s --check-prefix=CHECK-CONVERT

Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Dialect/SparseTensor/sparse_foreach.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: mlir-opt %s --post-sparsification-rewrite="enable-runtime-library=false enable-foreach=true" --canonicalize | FileCheck %s
// RUN: mlir-opt %s --lower-sparse-foreach-to-scf --canonicalize | FileCheck %s

// CHECK-LABEL: func.func @sparse_foreach_constant
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Dialect/SparseTensor/sparse_pack.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: mlir-opt %s --canonicalize --post-sparsification-rewrite="enable-runtime-library=false" --sparse-tensor-codegen -cse --canonicalize | FileCheck %s
// RUN: mlir-opt %s --canonicalize --sparse-tensor-codegen -cse --canonicalize | FileCheck %s

#COO = #sparse_tensor.encoding<{
map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton),
Expand Down
8 changes: 4 additions & 4 deletions mlir/test/Dialect/SparseTensor/sparse_reshape.mlir
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
// RUN: mlir-opt %s | mlir-opt | FileCheck %s --check-prefix=CHECK-ROUND
// RUN: mlir-opt %s --post-sparsification-rewrite="enable-runtime-library=true enable-convert=false" \
// RUN: --cse --canonicalize | FileCheck %s
// RUN: mlir-opt %s --post-sparsification-rewrite="enable-runtime-library=false enable-convert=false" \
// RUN: --cse --canonicalize | FileCheck %s
// RUN: mlir-opt %s --lower-sparse-ops-to-foreach="enable-runtime-library=true enable-convert=false" \
// RUN: --lower-sparse-foreach-to-scf --cse --canonicalize | FileCheck %s
// RUN: mlir-opt %s --lower-sparse-ops-to-foreach="enable-runtime-library=false enable-convert=false" \
// RUN: --lower-sparse-foreach-to-scf --cse --canonicalize | FileCheck %s

#SparseVector = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>
#SparseMatrix = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed, d1 : compressed) }>
Expand Down
4 changes: 2 additions & 2 deletions mlir/test/Dialect/SparseTensor/sparse_tensor_reshape.mlir
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// RUN: mlir-opt %s --post-sparsification-rewrite="enable-runtime-library=false enable-convert=false" \
// RUN: --cse --canonicalize | FileCheck %s
// RUN: mlir-opt %s --lower-sparse-ops-to-foreach="enable-runtime-library=false enable-convert=false" \
// RUN: --lower-sparse-foreach-to-scf --cse --canonicalize | FileCheck %s

#SparseMatrix = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed, d1 : compressed) }>

Expand Down