Skip to content

Commit b19c40c

Browse files
authored
[mlir][sparse] first end-to-end linalg.generic op on BSR (#70880)
1 parent 4c1c32c commit b19c40c

File tree

4 files changed

+36
-13
lines changed

4 files changed

+36
-13
lines changed

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -341,11 +341,9 @@ template <typename T>
341341
inline SparseTensorType getSparseTensorType(T t) {
342342
return SparseTensorType(getRankedTensorType(t));
343343
}
344-
template <typename T>
345-
inline std::optional<SparseTensorType> tryGetSparseTensorType(T t) {
346-
RankedTensorType rtp = getRankedTensorType(t);
347-
if (rtp)
348-
return SparseTensorType(rtp);
344+
inline std::optional<SparseTensorType> tryGetSparseTensorType(Value v) {
345+
if (isa<RankedTensorType>(v.getType()))
346+
return getSparseTensorType(v);
349347
return std::nullopt;
350348
}
351349

mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -116,11 +116,11 @@ struct GenericOpReinterpretMap : public OpRewritePattern<linalg::GenericOp> {
116116
if (map.getResult(i).getKind() != AffineExprKind::DimId)
117117
return failure();
118118
// Inspect sparse operands.
119-
auto stt = getSparseTensorType(t.get());
120-
if (stt.hasEncoding()) {
121-
if (stt.isPermutation())
119+
auto stt = tryGetSparseTensorType(t.get());
120+
if (stt && stt->hasEncoding()) {
121+
if (stt->isPermutation())
122122
continue;
123-
assert(stt.getDimRank() < stt.getLvlRank()); // only allowed non-perm
123+
assert(stt->getDimRank() < stt->getLvlRank()); // only allowed non-perm
124124
if (tx)
125125
return failure(); // more than one non-perm
126126
if (!map.isIdentity())

mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,8 +104,8 @@ class SparsificationAndBufferizationPass
104104
}
105105

106106
void runOnOperation() override {
107+
// Run enabling transformations.
107108
{
108-
// Run enabling transformations.
109109
OpPassManager pm("builtin.module");
110110
pm.addPass(createPreSparsificationRewritePass());
111111
pm.addNestedPass<func::FuncOp>(
@@ -128,7 +128,7 @@ class SparsificationAndBufferizationPass
128128
bufferizationOptions)))
129129
return signalPassFailure();
130130

131-
// `testAnalysisOnly` is a debug/testing flag. If set, the results of
131+
// Option `testAnalysisOnly` is a debug/testing flag. If set, the results of
132132
// OneShotAnalysis are added to the IR via attributes. In that case, do not
133133
// continue with the remaining pipeline.
134134
if (bufferizationOptions.testAnalysisOnly)
@@ -139,6 +139,8 @@ class SparsificationAndBufferizationPass
139139
// of `bufferization.alloc_tensor` ops.
140140
{
141141
OpPassManager pm("builtin.module");
142+
pm.addPass(
143+
createSparseReinterpretMapPass(ReinterpretMapScope::kGenericOnly));
142144
pm.addPass(createSparsificationPass(sparsificationOptions));
143145
pm.addNestedPass<func::FuncOp>(createStageSparseOperationsPass());
144146
pm.addPass(createLowerSparseOpsToForeachPass(enableRuntimeLibrary,

mlir/test/Integration/Dialect/SparseTensor/CPU/block.mlir

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
// REDEFINE: %{sparse_compiler_opts} = enable-runtime-library=false
2626
// R_UN: %{compile} | env %{env} %{run} | FileCheck %s
2727

28+
!Filename = !llvm.ptr<i8>
29+
2830
#BSR = #sparse_tensor.encoding<{
2931
map = (i, j) ->
3032
( i floordiv 2 : dense
@@ -38,8 +40,12 @@
3840
map = (i, j, k, l) -> ( i : dense, j : compressed, k : dense, l : dense)
3941
}>
4042

41-
42-
!Filename = !llvm.ptr<i8>
43+
#trait_scale_inplace = {
44+
indexing_maps = [
45+
affine_map<(i,j) -> (i,j)> // X (out)
46+
],
47+
iterator_types = ["parallel", "parallel"]
48+
}
4349

4450
//
4551
// Example 2x2 block storage:
@@ -62,6 +68,17 @@ module {
6268

6369
func.func private @getTensorFilename(index) -> (!Filename)
6470

71+
func.func @scale(%arg0: tensor<?x?xf64, #BSR>) -> tensor<?x?xf64, #BSR> {
72+
%c = arith.constant 3.0 : f64
73+
%0 = linalg.generic #trait_scale_inplace
74+
outs(%arg0: tensor<?x?xf64, #BSR>) {
75+
^bb(%x: f64):
76+
%1 = arith.mulf %x, %c : f64
77+
linalg.yield %1 : f64
78+
} -> tensor<?x?xf64, #BSR>
79+
return %0 : tensor<?x?xf64, #BSR>
80+
}
81+
6582
func.func @entry() {
6683
%c0 = arith.constant 0 : index
6784
%f0 = arith.constant 0.0 : f64
@@ -89,6 +106,12 @@ module {
89106
%vecdsdd = vector.transfer_read %vdsdd[%c0], %f0 : memref<?xf64>, vector<12xf64>
90107
vector.print %vecdsdd : vector<12xf64>
91108

109+
// CHECK-NEXT: ( 3, 6, 0, 9, 12, 0, 0, 15, 18, 21, 24, 0 )
110+
%As = call @scale(%A) : (tensor<?x?xf64, #BSR>) -> (tensor<?x?xf64, #BSR>)
111+
%vals = sparse_tensor.values %As : tensor<?x?xf64, #BSR> to memref<?xf64>
112+
%vecs = vector.transfer_read %vals[%c0], %f0 : memref<?xf64>, vector<12xf64>
113+
vector.print %vecs : vector<12xf64>
114+
92115
// Release the resources.
93116
bufferization.dealloc_tensor %A: tensor<?x?xf64, #BSR>
94117

0 commit comments

Comments
 (0)