@@ -61,20 +61,34 @@ class SparsificationAndBufferizationPass
61
61
: public impl::SparsificationAndBufferizationBase<
62
62
SparsificationAndBufferizationPass> {
63
63
public:
64
+ // Private pass options only.
64
65
SparsificationAndBufferizationPass (
65
66
const bufferization::OneShotBufferizationOptions &bufferizationOptions,
66
67
const SparsificationOptions &sparsificationOptions,
67
68
bool createSparseDeallocs, bool enableRuntimeLibrary,
68
- bool enableBufferInitialization, unsigned vectorLength,
69
- bool enableVLAVectorization, bool enableSIMDIndex32, bool enableGPULibgen)
69
+ bool enableBufferInitialization)
70
70
: bufferizationOptions(bufferizationOptions),
71
71
sparsificationOptions (sparsificationOptions),
72
72
createSparseDeallocs(createSparseDeallocs),
73
73
enableRuntimeLibrary(enableRuntimeLibrary),
74
- enableBufferInitialization(enableBufferInitialization),
75
- vectorLength(vectorLength),
76
- enableVLAVectorization(enableVLAVectorization),
77
- enableSIMDIndex32(enableSIMDIndex32), enableGPULibgen(enableGPULibgen) {
74
+ enableBufferInitialization(enableBufferInitialization) {}
75
+ // Private pass options and visible pass options.
76
+ SparsificationAndBufferizationPass (
77
+ const bufferization::OneShotBufferizationOptions &bufferizationOptions,
78
+ const SparsificationOptions &sparsificationOptions,
79
+ bool createSparseDeallocs, bool enableRuntimeLibrary,
80
+ bool enableBufferInitialization, unsigned vl, bool vla, bool index32,
81
+ bool gpu)
82
+ : bufferizationOptions(bufferizationOptions),
83
+ sparsificationOptions(sparsificationOptions),
84
+ createSparseDeallocs(createSparseDeallocs),
85
+ enableRuntimeLibrary(enableRuntimeLibrary),
86
+ enableBufferInitialization(enableBufferInitialization) {
87
+ // Set the visible pass options explicitly.
88
+ vectorLength = vl;
89
+ enableVLAVectorization = vla;
90
+ enableSIMDIndex32 = index32;
91
+ enableGPULibgen = gpu;
78
92
}
79
93
80
94
// / Bufferize all dense ops. This assumes that no further analysis is needed
@@ -178,10 +192,6 @@ class SparsificationAndBufferizationPass
178
192
bool createSparseDeallocs;
179
193
bool enableRuntimeLibrary;
180
194
bool enableBufferInitialization;
181
- unsigned vectorLength;
182
- bool enableVLAVectorization;
183
- bool enableSIMDIndex32;
184
- bool enableGPULibgen;
185
195
};
186
196
187
197
} // namespace sparse_tensor
@@ -213,16 +223,13 @@ mlir::getBufferizationOptionsForSparsification(bool analysisOnly) {
213
223
214
224
std::unique_ptr<mlir::Pass> mlir::createSparsificationAndBufferizationPass () {
215
225
SparsificationOptions sparseOptions;
216
- return createSparsificationAndBufferizationPass (
226
+ return std::make_unique<
227
+ mlir::sparse_tensor::SparsificationAndBufferizationPass>(
217
228
getBufferizationOptionsForSparsification (/* analysisOnly=*/ false ),
218
229
sparseOptions,
219
230
/* createSparseDeallocs=*/ false ,
220
231
/* enableRuntimeLibrary=*/ false ,
221
- /* enableBufferInitialization=*/ false ,
222
- /* vectorLength=*/ 0 ,
223
- /* enableVLAVectorization=*/ false ,
224
- /* enableSIMDIndex32=*/ false ,
225
- /* enableGPULibgen=*/ false );
232
+ /* enableBufferInitialization=*/ false );
226
233
}
227
234
228
235
std::unique_ptr<mlir::Pass> mlir::createSparsificationAndBufferizationPass (
0 commit comments