Skip to content

Commit 99b3849

Browse files
committed
[mlir][sparse] introduce vectorization pass for sparse loops
This brings back previous SIMD functionality, but in a separate pass. The idea is to improve this new pass incrementally, going beyond for-loops to while-loops for co-iteration as well (masking), while introducing new abstractions to make the lowering more progressive. The separation of sparsification and vectorization is a very good first step on this journey. Also brings back ArmSVE support. Still to be fine-tuned: (1) use of "index" in SIMD loop (viz. a[i] = i); (2) check that all ops really have SIMD support; (3) check all forms of reductions; (4) chain reduction SIMD values. Reviewed By: dcaballe. Differential Revision: https://reviews.llvm.org/D138236
1 parent 9df8ba6 commit 99b3849

File tree

7 files changed

+1019
-92
lines changed

7 files changed

+1019
-92
lines changed

mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,16 @@ std::unique_ptr<Pass> createSparseBufferRewritePass();
172172
std::unique_ptr<Pass>
173173
createSparseBufferRewritePass(bool enableBufferInitialization);
174174

175+
void populateSparseVectorizationPatterns(RewritePatternSet &patterns,
176+
unsigned vectorLength,
177+
bool enableVLAVectorization,
178+
bool enableSIMDIndex32);
179+
180+
std::unique_ptr<Pass> createSparseVectorizationPass();
181+
std::unique_ptr<Pass> createSparseVectorizationPass(unsigned vectorLength,
182+
bool enableVLAVectorization,
183+
bool enableSIMDIndex32);
184+
175185
//===----------------------------------------------------------------------===//
176186
// Registration.
177187
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,4 +225,64 @@ def SparseBufferRewrite : Pass<"sparse-buffer-rewrite", "ModuleOp"> {
225225
];
226226
}
227227

228+
def SparseVectorization : Pass<"sparse-vectorization", "ModuleOp"> {
229+
let summary = "Vectorizes loops after sparsification";
230+
let description = [{
231+
A pass that converts loops after sparsification into vector loops.
232+
The vector dialect is used as target to provide an architecturally
233+
neutral way of exploiting any platform that supports SIMD instructions.
234+
235+
The vector length (viz. `vl`) describes the number of packed data elements
236+
(e.g. both vector<16xf32> and vector<16xf64> have a vector length of 16 even
237+
though the actual bitwidths differ). A small multiple of the actual lengths
238+
supported in hardware typically results in efficient SIMD code, since the
239+
backend will map longer vectors to multiple vector registers, thereby
240+
effectively unrolling an additional level within the generated for-loop.
241+
242+
Example of the conversion:
243+
244+
```mlir
245+
Before:
246+
%3 = memref.load %2[] : memref<f32>
247+
%4 = scf.for %arg3 = %c0 to %c1024 step %c1 iter_args(%arg4 = %3) -> (f32) {
248+
%6 = memref.load %0[%arg3] : memref<?xf32>
249+
%7 = memref.load %1[%arg3] : memref<1024xf32>
250+
%8 = arith.mulf %6, %7 : f32
251+
%9 = arith.addf %arg4, %8 : f32
252+
scf.yield %9 : f32
253+
}
254+
memref.store %4, %2[] : memref<f32>
255+
256+
After:
257+
%3 = memref.load %2[] : memref<f32>
258+
%4 = vector.insertelement %3, %cst[%c0 : index] : vector<32xf32>
259+
%5 = scf.for %arg3 = %c0 to %c1024 step %c32 iter_args(%arg4 = %4) -> (vector<32xf32>) {
260+
%8 = vector.load %0[%arg3] : memref<?xf32>, vector<32xf32>
261+
%9 = vector.load %1[%arg3] : memref<1024xf32>, vector<32xf32>
262+
%10 = arith.mulf %8, %9 : vector<32xf32>
263+
%11 = arith.addf %arg4, %10 : vector<32xf32>
264+
scf.yield %11 : vector<32xf32>
265+
}
266+
%6 = vector.reduction <add>, %5 : vector<32xf32> into f32
267+
memref.store %6, %2[] : memref<f32>
268+
```
269+
}];
270+
let constructor = "mlir::createSparseVectorizationPass()";
271+
let dependentDialects = [
272+
"arith::ArithDialect",
273+
"memref::MemRefDialect",
274+
"scf::SCFDialect",
275+
"sparse_tensor::SparseTensorDialect",
276+
"vector::VectorDialect",
277+
];
278+
let options = [
279+
Option<"vectorLength", "vl", "int32_t", "0",
280+
"Set the vector length (use 0 to disable vectorization)">,
281+
Option<"enableVLAVectorization", "enable-vla-vectorization", "bool",
282+
"false", "Enable vector length agnostic vectorization">,
283+
Option<"enableSIMDIndex32", "enable-simd-index32", "bool", "false",
284+
"Enable i32 indexing into vectors (for efficient gather/scatter)">,
285+
];
286+
}
287+
228288
#endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_PASSES

mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ add_mlir_dialect_library(MLIRSparseTensorTransforms
88
SparseTensorConversion.cpp
99
SparseTensorPasses.cpp
1010
SparseTensorRewriting.cpp
11+
SparseVectorization.cpp
1112

1213
ADDITIONAL_HEADER_DIRS
1314
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SparseTensor

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ namespace mlir {
2727
#define GEN_PASS_DEF_SPARSETENSORCONVERSIONPASS
2828
#define GEN_PASS_DEF_SPARSETENSORCODEGEN
2929
#define GEN_PASS_DEF_SPARSEBUFFERREWRITE
30+
#define GEN_PASS_DEF_SPARSEVECTORIZATION
3031
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
3132
} // namespace mlir
3233

@@ -67,10 +68,9 @@ struct SparsificationPass
6768
auto *ctx = &getContext();
6869
// Translate strategy flags to strategy options.
6970
SparsificationOptions options(parallelization);
70-
// Apply sparsification and vector cleanup rewriting.
71+
// Apply sparsification and cleanup rewriting.
7172
RewritePatternSet patterns(ctx);
7273
populateSparsificationPatterns(patterns, options);
73-
vector::populateVectorToVectorCanonicalizationPatterns(patterns);
7474
scf::ForOp::getCanonicalizationPatterns(patterns, ctx);
7575
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
7676
}
@@ -250,6 +250,27 @@ struct SparseBufferRewritePass
250250
}
251251
};
252252

253+
struct SparseVectorizationPass
254+
: public impl::SparseVectorizationBase<SparseVectorizationPass> {
255+
256+
SparseVectorizationPass() = default;
257+
SparseVectorizationPass(const SparseVectorizationPass &pass) = default;
258+
SparseVectorizationPass(unsigned vl, bool vla, bool sidx32) {
259+
vectorLength = vl;
260+
enableVLAVectorization = vla;
261+
enableSIMDIndex32 = sidx32;
262+
}
263+
264+
void runOnOperation() override {
265+
auto *ctx = &getContext();
266+
RewritePatternSet patterns(ctx);
267+
populateSparseVectorizationPatterns(
268+
patterns, vectorLength, enableVLAVectorization, enableSIMDIndex32);
269+
vector::populateVectorToVectorCanonicalizationPatterns(patterns);
270+
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
271+
}
272+
};
273+
253274
} // namespace
254275

255276
//===----------------------------------------------------------------------===//
@@ -322,3 +343,15 @@ std::unique_ptr<Pass>
322343
mlir::createSparseBufferRewritePass(bool enableBufferInitialization) {
323344
return std::make_unique<SparseBufferRewritePass>(enableBufferInitialization);
324345
}
346+
347+
std::unique_ptr<Pass> mlir::createSparseVectorizationPass() {
348+
return std::make_unique<SparseVectorizationPass>();
349+
}
350+
351+
std::unique_ptr<Pass>
352+
mlir::createSparseVectorizationPass(unsigned vectorLength,
353+
bool enableVLAVectorization,
354+
bool enableSIMDIndex32) {
355+
return std::make_unique<SparseVectorizationPass>(
356+
vectorLength, enableVLAVectorization, enableSIMDIndex32);
357+
}

0 commit comments

Comments
 (0)