Skip to content

[mlir][sparse] add parallelization options to mini pipeline #104233

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Aug 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,8 @@ std::unique_ptr<Pass> createSparsificationAndBufferizationPass(
bool createSparseDeallocs, bool enableRuntimeLibrary,
bool enableBufferInitialization, unsigned vectorLength,
bool enableVLAVectorization, bool enableSIMDIndex32, bool enableGPULibgen,
SparseEmitStrategy emitStrategy);
SparseEmitStrategy emitStrategy,
SparseParallelizationStrategy parallelizationStrategy);

//===----------------------------------------------------------------------===//
// Sparse Iteration Transform Passes
Expand Down
17 changes: 17 additions & 0 deletions mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,23 @@ def SparsificationAndBufferization : Pass<"sparsification-and-bufferization", "M
"Emit (experimental) loops (with sparse.iterate)."),
clEnumValN(mlir::SparseEmitStrategy::kDebugInterface, "debug-interface",
"Emit non-functional but easy-to-read interfaces to debug."))}]>,
Option<"parallelization", "parallelization-strategy", "mlir::SparseParallelizationStrategy",
"mlir::SparseParallelizationStrategy::kNone",
"Set the parallelization strategy", [{llvm::cl::values(
clEnumValN(mlir::SparseParallelizationStrategy::kNone, "none",
"Turn off sparse parallelization."),
clEnumValN(mlir::SparseParallelizationStrategy::kDenseOuterLoop,
"dense-outer-loop",
"Enable dense outer loop sparse parallelization."),
clEnumValN(mlir::SparseParallelizationStrategy::kAnyStorageOuterLoop,
"any-storage-outer-loop",
"Enable sparse parallelization regardless of storage for the outer loop."),
clEnumValN(mlir::SparseParallelizationStrategy::kDenseAnyLoop,
"dense-any-loop",
"Enable dense parallelization for any loop."),
clEnumValN(mlir::SparseParallelizationStrategy::kAnyStorageAnyLoop,
"any-storage-any-loop",
"Enable sparse parallelization for any storage and loop."))}]>,
];
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ void mlir::sparse_tensor::buildSparsifier(OpPassManager &pm,
/*enableVLAVectorization=*/options.armSVE,
/*enableSIMDIndex32=*/options.force32BitVectorIndices,
options.enableGPULibgen,
options.sparsificationOptions().sparseEmitStrategy));
options.sparsificationOptions().sparseEmitStrategy,
options.sparsificationOptions().parallelizationStrategy));

// Bail-early for test setup.
if (options.testBufferizationAnalysisOnly)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,8 @@ class SparsificationAndBufferizationPass
const SparsificationOptions &sparsificationOptions,
bool createSparseDeallocs, bool enableRuntimeLibrary,
bool enableBufferInitialization, unsigned vl, bool vla, bool index32,
bool gpu, SparseEmitStrategy emitStrategy)
bool gpu, SparseEmitStrategy emitStrategy,
SparseParallelizationStrategy parallelizationStrategy)
: bufferizationOptions(bufferizationOptions),
sparsificationOptions(sparsificationOptions),
createSparseDeallocs(createSparseDeallocs),
Expand All @@ -90,6 +91,7 @@ class SparsificationAndBufferizationPass
enableSIMDIndex32 = index32;
enableGPULibgen = gpu;
sparseEmitStrategy = emitStrategy;
parallelization = parallelizationStrategy;
}

/// Bufferize all dense ops. This assumes that no further analysis is needed
Expand Down Expand Up @@ -124,6 +126,9 @@ class SparsificationAndBufferizationPass
// Overrides the default emit strategy using user-provided value.
this->sparsificationOptions.sparseEmitStrategy = sparseEmitStrategy;

// Overrides the default parallelization strategy using user-provided value.
this->sparsificationOptions.parallelizationStrategy = parallelization;

// Run enabling transformations.
{
OpPassManager pm("builtin.module");
Expand Down Expand Up @@ -248,10 +253,12 @@ std::unique_ptr<mlir::Pass> mlir::createSparsificationAndBufferizationPass(
bool createSparseDeallocs, bool enableRuntimeLibrary,
bool enableBufferInitialization, unsigned vectorLength,
bool enableVLAVectorization, bool enableSIMDIndex32, bool enableGPULibgen,
SparseEmitStrategy emitStrategy) {
SparseEmitStrategy emitStrategy,
SparseParallelizationStrategy parallelizationStrategy) {
return std::make_unique<
mlir::sparse_tensor::SparsificationAndBufferizationPass>(
bufferizationOptions, sparsificationOptions, createSparseDeallocs,
enableRuntimeLibrary, enableBufferInitialization, vectorLength,
enableVLAVectorization, enableSIMDIndex32, enableGPULibgen, emitStrategy);
enableVLAVectorization, enableSIMDIndex32, enableGPULibgen, emitStrategy,
parallelizationStrategy);
}
38 changes: 38 additions & 0 deletions mlir/test/Dialect/SparseTensor/minipipeline_parallel.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// RUN: mlir-opt %s --sparsification-and-bufferization | FileCheck %s --check-prefix=CHECK-NOPARA
// RUN: mlir-opt %s --sparsification-and-bufferization="parallelization-strategy=any-storage-any-loop" | FileCheck %s --check-prefix=CHECK-PARA

// Test to ensure we can pass parallelization flags into
// the mini sparsification and bufferization pipeline.

#SparseMatrix = #sparse_tensor.encoding<{
map = (d0, d1) -> (d0 : compressed, d1 : compressed)
}>

#trait_ss = {
indexing_maps = [
affine_map<(i,j) -> (i,j)>, // A
affine_map<(i,j) -> (i,j)> // X (out)
],
iterator_types = ["parallel", "parallel"],
doc = "X(i,j) = A(i,j) * SCALE"
}

//
// CHECK-NOPARA-LABEL: func.func @scale_ss
// CHECK-NOPARA: scf.for
//
// CHECK-PARA-LABEL: func.func @scale_ss
// CHECK-PARA: scf.parallel
//
func.func @scale_ss(%scale: f32,
%arga: tensor<?x?xf32, #SparseMatrix>,
%argx: tensor<?x?xf32>) -> tensor<?x?xf32> {
%0 = linalg.generic #trait_ss
ins(%arga: tensor<?x?xf32, #SparseMatrix>)
outs(%argx: tensor<?x?xf32>) {
^bb(%a: f32, %x: f32):
%0 = arith.mulf %a, %scale : f32
linalg.yield %0 : f32
} -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}
Loading