Skip to content

Commit c7e24db

Browse files
committed
[mlir][sparse] Introducing options for the SparseTensorConversion pass
This is work towards: llvm/llvm-project#51652 This differential sets up the options and threads them through everywhere, but doesn't actually use them yet. The differential that finally makes use of them is D122061, which is the final differential in the chain that fixes bug 51652. Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D122054
1 parent 110295e commit c7e24db

File tree

6 files changed

+124
-15
lines changed

6 files changed

+124
-15
lines changed

mlir/include/mlir/Dialect/SparseTensor/Pipelines/Passes.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,17 @@ struct SparseCompilerOptions
4949
vectorLength, enableSIMDIndex32);
5050
}
5151

52+
// These options must be kept in sync with `SparseTensorConversionBase`.
53+
PassOptions::Option<int32_t> sparseToSparse{
54+
*this, "s2s-strategy",
55+
desc("Set the strategy for sparse-to-sparse conversion"), init(0)};
56+
57+
/// Projects out the options for `createSparsificationPass`.
58+
SparseTensorConversionOptions sparseTensorConversionOptions() const {
59+
return SparseTensorConversionOptions(
60+
sparseToSparseConversionStrategy(sparseToSparse));
61+
}
62+
5263
// These options must be kept in sync with `ConvertVectorToLLVMBase`.
5364
// TODO(wrengr): does `indexOptimizations` differ from `enableSIMDIndex32`?
5465
PassOptions::Option<bool> reassociateFPReductions{

mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h

Lines changed: 57 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,12 @@
88
//
99
// This header file defines prototypes of all sparse tensor passes.
1010
//
11+
// In general, this file takes the approach of keeping "mechanism" (the
12+
// actual steps of applying a transformation) completely separate from
13+
// "policy" (heuristics for when and where to apply transformations).
14+
// The only exception is in `SparseToSparseConversionStrategy`; for which,
15+
// see further discussion there.
16+
//
1117
//===----------------------------------------------------------------------===//
1218

1319
#ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_PASSES_H_
@@ -21,6 +27,10 @@ namespace mlir {
2127
// Forward.
2228
class TypeConverter;
2329

30+
//===----------------------------------------------------------------------===//
31+
// The Sparsification pass.
32+
//===----------------------------------------------------------------------===//
33+
2434
/// Defines a parallelization strategy. Any independent loop is a candidate
2535
/// for parallelization. The loop is made parallel if (1) allowed by the
2636
/// strategy (e.g., AnyStorageOuterLoop considers either a dense or sparse
@@ -51,7 +61,7 @@ enum class SparseVectorizationStrategy {
5161
/// Converts command-line vectorization flag to the strategy enum.
5262
SparseVectorizationStrategy sparseVectorizationStrategy(int32_t flag);
5363

54-
/// Sparsification options.
64+
/// Options for the Sparsification pass.
5565
struct SparsificationOptions {
5666
SparsificationOptions(SparseParallelizationStrategy p,
5767
SparseVectorizationStrategy v, unsigned vl, bool e)
@@ -71,14 +81,56 @@ void populateSparsificationPatterns(
7181
RewritePatternSet &patterns,
7282
const SparsificationOptions &options = SparsificationOptions());
7383

74-
/// Sets up sparse tensor conversion rules.
75-
void populateSparseTensorConversionPatterns(TypeConverter &typeConverter,
76-
RewritePatternSet &patterns);
77-
7884
std::unique_ptr<Pass> createSparsificationPass();
7985
std::unique_ptr<Pass>
8086
createSparsificationPass(const SparsificationOptions &options);
87+
88+
//===----------------------------------------------------------------------===//
89+
// The SparseTensorConversion pass.
90+
//===----------------------------------------------------------------------===//
91+
92+
/// Defines a strategy for implementing sparse-to-sparse conversion.
93+
/// `kAuto` leaves it up to the compiler to automatically determine
94+
/// the method used. `kViaCOO` converts the source tensor to COO and
95+
/// then converts the COO to the target format. `kDirect` converts
96+
/// directly via the algorithm in <https://arxiv.org/abs/2001.02609>;
97+
/// however, beware that there are many formats not supported by this
98+
/// conversion method.
99+
///
100+
/// The presence of the `kAuto` option violates our usual goal of keeping
101+
/// policy completely separated from mechanism. The reason it exists is
102+
/// because (at present) this strategy can only be specified on a per-file
103+
/// basis. To see why this is a problem, note that `kDirect` cannot
104+
/// support certain conversions; so if there is no `kAuto` setting,
105+
/// then whenever a file contains a single non-`kDirect`-able conversion
106+
/// the user would be forced to use `kViaCOO` for all conversions in
107+
/// that file! In the future, instead of using this enum as a `Pass`
108+
/// option, we could instead move it to being an attribute on the
109+
/// conversion op; at which point `kAuto` would no longer be necessary.
110+
enum class SparseToSparseConversionStrategy { kAuto, kViaCOO, kDirect };
111+
112+
/// Converts command-line sparse2sparse flag to the strategy enum.
113+
SparseToSparseConversionStrategy sparseToSparseConversionStrategy(int32_t flag);
114+
115+
/// SparseTensorConversion options.
116+
struct SparseTensorConversionOptions {
117+
SparseTensorConversionOptions(SparseToSparseConversionStrategy s2s)
118+
: sparseToSparseStrategy(s2s) {}
119+
SparseTensorConversionOptions()
120+
: SparseTensorConversionOptions(SparseToSparseConversionStrategy::kAuto) {
121+
}
122+
SparseToSparseConversionStrategy sparseToSparseStrategy;
123+
};
124+
125+
/// Sets up sparse tensor conversion rules.
126+
void populateSparseTensorConversionPatterns(
127+
TypeConverter &typeConverter, RewritePatternSet &patterns,
128+
const SparseTensorConversionOptions &options =
129+
SparseTensorConversionOptions());
130+
81131
std::unique_ptr<Pass> createSparseTensorConversionPass();
132+
std::unique_ptr<Pass>
133+
createSparseTensorConversionPass(const SparseTensorConversionOptions &options);
82134

83135
//===----------------------------------------------------------------------===//
84136
// Registration.

mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,10 @@ def SparseTensorConversion : Pass<"sparse-tensor-conversion", "ModuleOp"> {
114114
"sparse_tensor::SparseTensorDialect",
115115
"vector::VectorDialect",
116116
];
117+
let options = [
118+
Option<"sparseToSparse", "s2s-strategy", "int32_t", "0",
119+
"Set the strategy for sparse-to-sparse conversion">,
120+
];
117121
}
118122

119123
#endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_PASSES

mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@ void mlir::sparse_tensor::buildSparseCompiler(
3333
pm.addNestedPass<FuncOp>(createLinalgGeneralizationPass());
3434
pm.addPass(createLinalgElementwiseOpFusionPass());
3535
pm.addPass(createSparsificationPass(options.sparsificationOptions()));
36-
pm.addPass(createSparseTensorConversionPass());
36+
pm.addPass(createSparseTensorConversionPass(
37+
options.sparseTensorConversionOptions()));
3738
pm.addNestedPass<FuncOp>(createLinalgBufferizePass());
3839
pm.addNestedPass<FuncOp>(vector::createVectorBufferizePass());
3940
pm.addNestedPass<FuncOp>(createConvertLinalgToLoopsPass());

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -453,7 +453,18 @@ class SparseTensorInitConverter : public OpConversionPattern<InitOp> {
453453

454454
/// Sparse conversion rule for the convert operator.
455455
class SparseTensorConvertConverter : public OpConversionPattern<ConvertOp> {
456+
/// Options to control sparse code generation.
457+
SparseTensorConversionOptions options;
458+
459+
public:
456460
using OpConversionPattern::OpConversionPattern;
461+
SparseTensorConvertConverter(MLIRContext *context,
462+
SparseTensorConversionOptions o)
463+
: OpConversionPattern<ConvertOp>(context), options(o) {}
464+
SparseTensorConvertConverter(TypeConverter &typeConv, MLIRContext *context,
465+
SparseTensorConversionOptions o)
466+
: OpConversionPattern<ConvertOp>(typeConv, context), options(o) {}
467+
457468
LogicalResult
458469
matchAndRewrite(ConvertOp op, OpAdaptor adaptor,
459470
ConversionPatternRewriter &rewriter) const override {
@@ -825,14 +836,17 @@ class SparseTensorOutConverter : public OpConversionPattern<OutOp> {
825836

826837
/// Populates the given patterns list with conversion rules required for
827838
/// the sparsification of linear algebra operations.
828-
void mlir::populateSparseTensorConversionPatterns(TypeConverter &typeConverter,
829-
RewritePatternSet &patterns) {
839+
void mlir::populateSparseTensorConversionPatterns(
840+
TypeConverter &typeConverter, RewritePatternSet &patterns,
841+
const SparseTensorConversionOptions &options) {
830842
patterns.add<SparseReturnConverter, SparseTensorToDimSizeConverter,
831843
SparseCastConverter, SparseTensorNewConverter,
832-
SparseTensorInitConverter, SparseTensorConvertConverter,
833-
SparseTensorReleaseConverter, SparseTensorToPointersConverter,
834-
SparseTensorToIndicesConverter, SparseTensorToValuesConverter,
835-
SparseTensorLoadConverter, SparseTensorLexInsertConverter,
836-
SparseTensorExpandConverter, SparseTensorCompressConverter,
837-
SparseTensorOutConverter>(typeConverter, patterns.getContext());
844+
SparseTensorInitConverter, SparseTensorReleaseConverter,
845+
SparseTensorToPointersConverter, SparseTensorToIndicesConverter,
846+
SparseTensorToValuesConverter, SparseTensorLoadConverter,
847+
SparseTensorLexInsertConverter, SparseTensorExpandConverter,
848+
SparseTensorCompressConverter, SparseTensorOutConverter>(
849+
typeConverter, patterns.getContext());
850+
patterns.add<SparseTensorConvertConverter>(typeConverter,
851+
patterns.getContext(), options);
838852
}

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,13 @@ class SparseTensorTypeConverter : public TypeConverter {
7373

7474
struct SparseTensorConversionPass
7575
: public SparseTensorConversionBase<SparseTensorConversionPass> {
76+
77+
SparseTensorConversionPass() = default;
78+
SparseTensorConversionPass(const SparseTensorConversionPass &pass) = default;
79+
SparseTensorConversionPass(const SparseTensorConversionOptions &options) {
80+
sparseToSparse = static_cast<int32_t>(options.sparseToSparseStrategy);
81+
}
82+
7683
void runOnOperation() override {
7784
auto *ctx = &getContext();
7885
RewritePatternSet patterns(ctx);
@@ -106,11 +113,14 @@ struct SparseTensorConversionPass
106113
target
107114
.addLegalDialect<bufferization::BufferizationDialect, LLVM::LLVMDialect,
108115
memref::MemRefDialect, scf::SCFDialect>();
116+
// Translate strategy flags to strategy options.
117+
SparseTensorConversionOptions options(
118+
sparseToSparseConversionStrategy(sparseToSparse));
109119
// Populate with rules and apply rewriting rules.
110120
populateFunctionOpInterfaceTypeConversionPattern<FuncOp>(patterns,
111121
converter);
112122
populateCallOpTypeConversionPattern(patterns, converter);
113-
populateSparseTensorConversionPatterns(converter, patterns);
123+
populateSparseTensorConversionPatterns(converter, patterns, options);
114124
if (failed(applyPartialConversion(getOperation(), target,
115125
std::move(patterns))))
116126
signalPassFailure();
@@ -146,6 +156,18 @@ SparseVectorizationStrategy mlir::sparseVectorizationStrategy(int32_t flag) {
146156
}
147157
}
148158

159+
SparseToSparseConversionStrategy
160+
mlir::sparseToSparseConversionStrategy(int32_t flag) {
161+
switch (flag) {
162+
default:
163+
return SparseToSparseConversionStrategy::kAuto;
164+
case 1:
165+
return SparseToSparseConversionStrategy::kViaCOO;
166+
case 2:
167+
return SparseToSparseConversionStrategy::kDirect;
168+
}
169+
}
170+
149171
std::unique_ptr<Pass> mlir::createSparsificationPass() {
150172
return std::make_unique<SparsificationPass>();
151173
}
@@ -158,3 +180,8 @@ mlir::createSparsificationPass(const SparsificationOptions &options) {
158180
std::unique_ptr<Pass> mlir::createSparseTensorConversionPass() {
159181
return std::make_unique<SparseTensorConversionPass>();
160182
}
183+
184+
std::unique_ptr<Pass> mlir::createSparseTensorConversionPass(
185+
const SparseTensorConversionOptions &options) {
186+
return std::make_unique<SparseTensorConversionPass>(options);
187+
}

0 commit comments

Comments
 (0)