8
8
//
9
9
// This header file defines prototypes of all sparse tensor passes.
10
10
//
11
+ // In general, this file takes the approach of keeping "mechanism" (the
12
+ // actual steps of applying a transformation) completely separate from
13
+ // "policy" (heuristics for when and where to apply transformations).
14
+ // The only exception is in `SparseToSparseConversionStrategy`; for which,
15
+ // see further discussion there.
16
+ //
11
17
// ===----------------------------------------------------------------------===//
12
18
13
19
#ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_PASSES_H_
@@ -21,6 +27,10 @@ namespace mlir {
21
27
// Forward.
22
28
class TypeConverter ;
23
29
30
+ // ===----------------------------------------------------------------------===//
31
+ // The Sparsification pass.
32
+ // ===----------------------------------------------------------------------===//
33
+
24
34
// / Defines a parallelization strategy. Any independent loop is a candidate
25
35
// / for parallelization. The loop is made parallel if (1) allowed by the
26
36
// / strategy (e.g., AnyStorageOuterLoop considers either a dense or sparse
@@ -51,7 +61,7 @@ enum class SparseVectorizationStrategy {
51
61
// / Converts command-line vectorization flag to the strategy enum.
52
62
SparseVectorizationStrategy sparseVectorizationStrategy (int32_t flag);
53
63
54
- // / Sparsification options .
64
+ // / Options for the Sparsification pass .
55
65
struct SparsificationOptions {
56
66
SparsificationOptions (SparseParallelizationStrategy p,
57
67
SparseVectorizationStrategy v, unsigned vl, bool e)
@@ -71,14 +81,56 @@ void populateSparsificationPatterns(
71
81
RewritePatternSet &patterns,
72
82
const SparsificationOptions &options = SparsificationOptions());
73
83
74
- // / Sets up sparse tensor conversion rules.
75
- void populateSparseTensorConversionPatterns (TypeConverter &typeConverter,
76
- RewritePatternSet &patterns);
77
-
78
84
std::unique_ptr<Pass> createSparsificationPass ();
79
85
std::unique_ptr<Pass>
80
86
createSparsificationPass (const SparsificationOptions &options);
87
+
88
+ // ===----------------------------------------------------------------------===//
89
+ // The SparseTensorConversion pass.
90
+ // ===----------------------------------------------------------------------===//
91
+
92
+ // / Defines a strategy for implementing sparse-to-sparse conversion.
93
+ // / `kAuto` leaves it up to the compiler to automatically determine
94
+ // / the method used. `kViaCOO` converts the source tensor to COO and
95
+ // / then converts the COO to the target format. `kDirect` converts
96
+ // / directly via the algorithm in <https://arxiv.org/abs/2001.02609>;
97
+ // / however, beware that there are many formats not supported by this
98
+ // / conversion method.
99
+ // /
100
+ // / The presence of the `kAuto` option violates our usual goal of keeping
101
+ // / policy completely separated from mechanism. The reason it exists is
102
+ // / because (at present) this strategy can only be specified on a per-file
103
+ // / basis. To see why this is a problem, note that `kDirect` cannot
104
+ // / support certain conversions; so if there is no `kAuto` setting,
105
+ // / then whenever a file contains a single non-`kDirect`-able conversion
106
+ // / the user would be forced to use `kViaCOO` for all conversions in
107
+ // / that file! In the future, instead of using this enum as a `Pass`
108
+ // / option, we could instead move it to being an attribute on the
109
+ // / conversion op; at which point `kAuto` would no longer be necessary.
110
+ enum class SparseToSparseConversionStrategy { kAuto , kViaCOO , kDirect };
111
+
112
+ // / Converts command-line sparse2sparse flag to the strategy enum.
113
+ SparseToSparseConversionStrategy sparseToSparseConversionStrategy (int32_t flag);
114
+
115
+ // / SparseTensorConversion options.
116
+ struct SparseTensorConversionOptions {
117
+ SparseTensorConversionOptions (SparseToSparseConversionStrategy s2s)
118
+ : sparseToSparseStrategy(s2s) {}
119
+ SparseTensorConversionOptions ()
120
+ : SparseTensorConversionOptions(SparseToSparseConversionStrategy::kAuto ) {
121
+ }
122
+ SparseToSparseConversionStrategy sparseToSparseStrategy;
123
+ };
124
+
125
+ // / Sets up sparse tensor conversion rules.
126
+ void populateSparseTensorConversionPatterns (
127
+ TypeConverter &typeConverter, RewritePatternSet &patterns,
128
+ const SparseTensorConversionOptions &options =
129
+ SparseTensorConversionOptions ());
130
+
81
131
std::unique_ptr<Pass> createSparseTensorConversionPass ();
132
+ std::unique_ptr<Pass>
133
+ createSparseTensorConversionPass (const SparseTensorConversionOptions &options);
82
134
83
135
// ===----------------------------------------------------------------------===//
84
136
// Registration.
0 commit comments