
Commit 438a7d4

[mlir][sparse] expose optimization flags to mini pipeline (#95158)
Some of these optimization options previously fed only into the full sparse pipeline. However, some backends prefer to use the sparse mini-pipeline instead. This change exposes the important optimization flags on that pass as well, in preparation for SIMDizing PyTorch-sparsified code.
1 parent c6ee562 commit 438a7d4

File tree: 3 files changed, +78 -16 lines changed

mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td

Lines changed: 12 additions & 0 deletions

@@ -462,6 +462,18 @@ def SparsificationAndBufferization : Pass<"sparsification-and-bufferization", "ModuleOp"> {
     "sparse_tensor::SparseTensorDialect",
     "vector::VectorDialect"
   ];
+  // Important optimization options are made visible to the mini-pipeline
+  // so that clients can set these (when not using the full pipeline).
+  let options = [
+    Option<"vectorLength", "vl", "int32_t", "0",
+           "Set the vector length (use 0 to disable vectorization)">,
+    Option<"enableVLAVectorization", "enable-vla-vectorization", "bool", "false",
+           "Enable vector length agnostic vectorization">,
+    Option<"enableSIMDIndex32", "enable-simd-index32", "bool", "false",
+           "Enable i32 indexing into vectors (for efficient gather/scatter)">,
+    Option<"enableGPULibgen", "enable-gpu-libgen", "bool", "false",
+           "Enable GPU acceleration by means of direct library calls">,
+  ];
 }

 //===----------------------------------------------------------------------===//
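Declaring the options on the pass itself means the mini-pipeline can now be tuned directly from the command line, without going through the full sparse compiler pipeline. A minimal illustration (the input file name and the chosen flag values are placeholders; the option string syntax matches the "vl=8" usage in the new test below):

  mlir-opt input.mlir \
      --sparsification-and-bufferization="vl=8 enable-simd-index32=true"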

mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp

Lines changed: 23 additions & 16 deletions
@@ -61,20 +61,34 @@ class SparsificationAndBufferizationPass
     : public impl::SparsificationAndBufferizationBase<
           SparsificationAndBufferizationPass> {
 public:
+  // Private pass options only.
   SparsificationAndBufferizationPass(
       const bufferization::OneShotBufferizationOptions &bufferizationOptions,
       const SparsificationOptions &sparsificationOptions,
       bool createSparseDeallocs, bool enableRuntimeLibrary,
-      bool enableBufferInitialization, unsigned vectorLength,
-      bool enableVLAVectorization, bool enableSIMDIndex32, bool enableGPULibgen)
+      bool enableBufferInitialization)
       : bufferizationOptions(bufferizationOptions),
         sparsificationOptions(sparsificationOptions),
         createSparseDeallocs(createSparseDeallocs),
         enableRuntimeLibrary(enableRuntimeLibrary),
-        enableBufferInitialization(enableBufferInitialization),
-        vectorLength(vectorLength),
-        enableVLAVectorization(enableVLAVectorization),
-        enableSIMDIndex32(enableSIMDIndex32), enableGPULibgen(enableGPULibgen) {
+        enableBufferInitialization(enableBufferInitialization) {}
+  // Private pass options and visible pass options.
+  SparsificationAndBufferizationPass(
+      const bufferization::OneShotBufferizationOptions &bufferizationOptions,
+      const SparsificationOptions &sparsificationOptions,
+      bool createSparseDeallocs, bool enableRuntimeLibrary,
+      bool enableBufferInitialization, unsigned vl, bool vla, bool index32,
+      bool gpu)
+      : bufferizationOptions(bufferizationOptions),
+        sparsificationOptions(sparsificationOptions),
+        createSparseDeallocs(createSparseDeallocs),
+        enableRuntimeLibrary(enableRuntimeLibrary),
+        enableBufferInitialization(enableBufferInitialization) {
+    // Set the visible pass options explicitly.
+    vectorLength = vl;
+    enableVLAVectorization = vla;
+    enableSIMDIndex32 = index32;
+    enableGPULibgen = gpu;
   }

   /// Bufferize all dense ops. This assumes that no further analysis is needed
@@ -178,10 +192,6 @@ class SparsificationAndBufferizationPass
   bool createSparseDeallocs;
   bool enableRuntimeLibrary;
   bool enableBufferInitialization;
-  unsigned vectorLength;
-  bool enableVLAVectorization;
-  bool enableSIMDIndex32;
-  bool enableGPULibgen;
 };

 } // namespace sparse_tensor
@@ -213,16 +223,13 @@ mlir::getBufferizationOptionsForSparsification(bool analysisOnly) {

 std::unique_ptr<mlir::Pass> mlir::createSparsificationAndBufferizationPass() {
   SparsificationOptions sparseOptions;
-  return createSparsificationAndBufferizationPass(
+  return std::make_unique<
+      mlir::sparse_tensor::SparsificationAndBufferizationPass>(
       getBufferizationOptionsForSparsification(/*analysisOnly=*/false),
       sparseOptions,
       /*createSparseDeallocs=*/false,
       /*enableRuntimeLibrary=*/false,
-      /*enableBufferInitialization=*/false,
-      /*vectorLength=*/0,
-      /*enableVLAVectorization=*/false,
-      /*enableSIMDIndex32=*/false,
-      /*enableGPULibgen=*/false);
+      /*enableBufferInitialization=*/false);
 }

 std::unique_ptr<mlir::Pass> mlir::createSparsificationAndBufferizationPass(
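For clients that assemble the pipeline in C++ rather than via textual pass options, the existing overload that takes the full set of flags (kept by this change, as the trailing context line above shows) remains available. A minimal sketch, assuming an out-of-tree backend wiring the pass into its own pass manager; the helper name buildMiniPipeline and the chosen flag values are illustrative:

  #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
  #include "mlir/Pass/PassManager.h"

  // Illustrative helper: add the sparse mini-pipeline with vectorization
  // enabled (vector length 8, i32 indices for gather/scatter).
  static void buildMiniPipeline(mlir::PassManager &pm) {
    mlir::SparsificationOptions sparseOptions;
    pm.addPass(mlir::createSparsificationAndBufferizationPass(
        mlir::getBufferizationOptionsForSparsification(/*analysisOnly=*/false),
        sparseOptions,
        /*createSparseDeallocs=*/false,
        /*enableRuntimeLibrary=*/false,
        /*enableBufferInitialization=*/false,
        /*vectorLength=*/8,
        /*enableVLAVectorization=*/false,
        /*enableSIMDIndex32=*/true,
        /*enableGPULibgen=*/false));
  }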
Lines changed: 43 additions & 0 deletions (new test file)

@@ -0,0 +1,43 @@
+// RUN: mlir-opt %s --sparsification-and-bufferization | FileCheck %s --check-prefix=CHECK-NOVEC
+// RUN: mlir-opt %s --sparsification-and-bufferization="vl=8" | FileCheck %s --check-prefix=CHECK-VEC
+
+// Test to ensure we can pass optimization flags into
+// the mini sparsification and bufferization pipeline.
+
+#SV = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>
+
+#trait_sum_reduction = {
+  indexing_maps = [
+    affine_map<(i) -> (i)>, // a
+    affine_map<(i) -> ()>   // x (scalar out)
+  ],
+  iterator_types = ["reduction"],
+  doc = "x += SUM_i a(i)"
+}
+
+//
+// CHECK-NOVEC-LABEL: func.func @sum_reduction
+// CHECK-NOVEC: scf.for
+// CHECK-NOVEC: arith.addf %{{.*}} %{{.*}} : f32
+// CHECK-NOVEC: }
+//
+// CHECK-VEC-LABEL: func.func @sum_reduction
+// CHECK-VEC: vector.insertelement
+// CHECK-VEC: scf.for
+// CHECK-VEC: vector.create_mask
+// CHECK-VEC: vector.maskedload
+// CHECK-VEC: arith.addf %{{.*}} %{{.*}} : vector<8xf32>
+// CHECK-VEC: }
+// CHECK-VEC: vector.reduction <add>
+//
+func.func @sum_reduction(%arga: tensor<?xf32, #SV>,
+                         %argx: tensor<f32>) -> tensor<f32> {
+  %0 = linalg.generic #trait_sum_reduction
+    ins(%arga: tensor<?xf32, #SV>)
+    outs(%argx: tensor<f32>) {
+      ^bb(%a: f32, %x: f32):
+        %0 = arith.addf %x, %a : f32
+        linalg.yield %0 : f32
+  } -> tensor<f32>
+  return %0 : tensor<f32>
+}
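The two RUN lines exercise the same kernel both ways: without flags, CHECK-NOVEC expects a scalar scf.for loop performing an f32 arith.addf per element; with vl=8, CHECK-VEC expects masked vector<8xf32> loads inside the loop and a final vector.reduction <add> that folds the vector accumulator back into a scalar.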
