Skip to content

Commit e599978

Browse files
authored
[mlir][sparse] first proof-of-concept non-permutation rewriter (#70863)
Rather than extending sparsifier codegen with higher order non-permutations, we follow the path of rewriting linalg generic ops into higher order operations. That way, code generation will simply work out of the box. This is a very first proof-of-concept rewriting of that idea.
1 parent c5dafd1 commit e599978

File tree

2 files changed

+178
-9
lines changed

2 files changed

+178
-9
lines changed

mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp

Lines changed: 130 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "mlir/Dialect/Affine/IR/AffineOps.h"
10+
#include "mlir/Dialect/Linalg/IR/Linalg.h"
11+
#include "mlir/Dialect/Linalg/Utils/Utils.h"
1012
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
1113
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
1214
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
@@ -18,10 +20,130 @@ using namespace mlir::sparse_tensor;
1820

1921
namespace {
2022

21-
// TODO:
22-
// (1) insert the zero-cost sparse_tensor.reinterpret_map ops
23-
// (2) rewrite linalg.generic ops traits on level crds
24-
// (3) compute topsort, and resolve cycles with sparse_tensor.convert ops
23+
//===----------------------------------------------------------------------===//
24+
// Helper methods.
25+
//===----------------------------------------------------------------------===//
26+
27+
// Translates a "simple" map according to an identity lvl-map.
28+
static AffineMap translateMap(OpBuilder &builder, SparseTensorType stt,
29+
AffineMap map) {
30+
unsigned lvlRank = stt.getLvlRank();
31+
AffineMap lvl2dim = stt.getLvlToDim();
32+
assert(lvl2dim.getNumInputs() == lvlRank);
33+
SmallVector<AffineExpr> exps;
34+
for (unsigned i = 0, n = map.getNumResults(); i < n; i++) {
35+
unsigned pos = map.getResult(i).cast<AffineDimExpr>().getPosition();
36+
exps.push_back(lvl2dim.getResult(pos));
37+
}
38+
return AffineMap::get(lvlRank, 0, exps, builder.getContext());
39+
}
40+
41+
// Generates a "de"mapping reinterpretation of the map.
42+
static Value genDemap(OpBuilder &builder, SparseTensorEncodingAttr enc,
43+
Value val) {
44+
return builder.create<ReinterpretMapOp>(val.getLoc(), enc.withoutDimToLvl(),
45+
val);
46+
}
47+
48+
// Generates a "re"mapping reinterpretation of the map.
49+
static Value genRemap(OpBuilder &builder, SparseTensorEncodingAttr enc,
50+
Value val) {
51+
return builder.create<ReinterpretMapOp>(val.getLoc(), enc, val);
52+
}
53+
54+
// Generates a clone of the given linalg generic operation, but with
55+
// remapped arguments, index maps, and iteration types.
56+
//
57+
// TODO: As decribed below, this is proof-of-concept code which makes a lot
58+
// of simplifying assumptions for now.
59+
//
60+
static linalg::GenericOp genGenericLinalg(PatternRewriter &rewriter,
61+
linalg::GenericOp linalgOp,
62+
SparseTensorType stt, Value out) {
63+
unsigned dimRank = stt.getDimRank();
64+
unsigned lvlRank = stt.getLvlRank();
65+
SmallVector<Value> inputOps = linalgOp.getInputs();
66+
SmallVector<Value> outputOps = {out};
67+
SmallVector<AffineMap> indexMaps;
68+
SmallVector<utils::IteratorType> iterTypes;
69+
// Translate the index maps, except output map, which is lvl-identity.
70+
auto maps = linalgOp.getIndexingMapsArray();
71+
for (unsigned i = 0, n = maps.size() - 1; i < n; i++)
72+
indexMaps.push_back(translateMap(rewriter, stt, maps[i]));
73+
indexMaps.push_back(
74+
AffineMap::getMultiDimIdentityMap(lvlRank, rewriter.getContext()));
75+
// Add additional "parallel" iteration types at the top.
76+
for (unsigned i = 0, diff = lvlRank = dimRank; i < diff; i++)
77+
iterTypes.push_back(utils::IteratorType::parallel);
78+
for (auto &i : linalgOp.getIteratorTypesArray())
79+
iterTypes.push_back(i);
80+
// Generate the new linalg generic operation and clone body.
81+
auto newOp = rewriter.create<linalg::GenericOp>(
82+
linalgOp.getLoc(), out.getType(), inputOps, outputOps, indexMaps,
83+
iterTypes);
84+
rewriter.cloneRegionBefore(linalgOp.getRegion(), newOp.getRegion(),
85+
newOp.getRegion().begin());
86+
return newOp;
87+
}
88+
89+
//===----------------------------------------------------------------------===//
90+
// Rewriting rules for linalg generic ops.
91+
//===----------------------------------------------------------------------===//
92+
93+
/// Sparse rewriting rule for the generic `linalg` operation: when the single
/// output is a non-permutation sparse tensor with an identity indexing map,
/// rewrite the op into a demap + higher-order generic + remap sequence so
/// that subsequent sparsifier codegen only sees permutation-style maps.
struct GenericOpReinterpretMap : public OpRewritePattern<linalg::GenericOp> {
public:
  GenericOpReinterpretMap(MLIRContext *context)
      : OpRewritePattern<linalg::GenericOp>(context) {}

  LogicalResult matchAndRewrite(linalg::GenericOp linalgOp,
                                PatternRewriter &rewriter) const override {
    // Only rewrite single output operations with pure tensor semantics.
    if (linalgOp.getNumDpsInits() != 1 || !linalgOp.hasTensorSemantics())
      return failure();
    // Scan all operands, inspect sparse tensors.
    //
    // TODO: generalize this proof-of-concept algorithm, since the current
    //       implementation accepts only simple indexing maps, and one
    //       non-permutation sparse tensor, which must have an identity
    //       indexing map and be the output.
    //
    OpOperand *tx = nullptr; // the unique non-permutation operand, if any
    for (OpOperand &t : linalgOp->getOpOperands()) {
      // Ensure every index map is "simple", i.e. each result is a bare
      // dimension expression (no affine arithmetic).
      const auto map = linalgOp.getMatchingIndexingMap(&t);
      for (unsigned i = 0, n = map.getNumResults(); i < n; i++)
        if (map.getResult(i).getKind() != AffineExprKind::DimId)
          return failure();
      // Inspect sparse operands.
      auto stt = getSparseTensorType(t.get());
      if (stt.hasEncoding()) {
        if (stt.isPermutation())
          continue;
        assert(stt.getDimRank() < stt.getLvlRank()); // only allowed non-perm
        if (tx)
          return failure(); // more than one non-perm
        if (!map.isIdentity())
          return failure(); // no ID indexing map on the non-perm
        tx = &t;
      }
    }
    // Found a non-permutation, rewrite when this is the output.
    if (tx && tx == linalgOp.getDpsInitOperand(0)) {
      auto stt = getSparseTensorType(tx->get());
      // Demap the output, regenerate the generic op in lvl space, and remap
      // the result back to the original encoding.
      auto demap = genDemap(rewriter, stt.getEncoding(), tx->get());
      auto newOp = genGenericLinalg(rewriter, linalgOp, stt, demap);
      auto remap = genRemap(rewriter, stt.getEncoding(), newOp.getResult(0));
      rewriter.replaceOp(linalgOp, remap);
      return success();
    }
    return failure();
  }
};
143+
144+
//===----------------------------------------------------------------------===//
145+
// Rewriting rules for operations other than linalg generic ops.
146+
//===----------------------------------------------------------------------===//
25147

26148
// CRTP to help implementing a rewriter that demaps all its inputs and remaps
27149
// all its outputs.
@@ -59,10 +181,6 @@ struct DemapInsRemapOutsRewriter : public OpRewritePattern<SourceOp> {
59181
}
60182
};
61183

62-
//===----------------------------------------------------------------------===//
63-
// Reinterpret Map Rewriters for operations other than linalg.generics
64-
//===----------------------------------------------------------------------===//
65-
66184
struct CrdTranslateRewriter : public OpRewritePattern<CrdTranslateOp> {
67185
using OpRewritePattern::OpRewritePattern;
68186
LogicalResult matchAndRewrite(CrdTranslateOp op,
@@ -110,6 +228,10 @@ struct TensorInsertRewriter
110228

111229
void mlir::populateSparseReinterpretMap(RewritePatternSet &patterns,
112230
ReinterpretMapScope scope) {
231+
if (scope == ReinterpretMapScope::kAll ||
232+
scope == ReinterpretMapScope::kGenericOnly) {
233+
patterns.add<GenericOpReinterpretMap>(patterns.getContext());
234+
}
113235
if (scope == ReinterpretMapScope::kAll ||
114236
scope == ReinterpretMapScope::kExceptGeneric) {
115237
patterns.add<CrdTranslateRewriter, TensorInsertRewriter>(

mlir/test/Dialect/SparseTensor/sparse_reinterpret_map.mlir

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt %s --sparse-reinterpret-map | FileCheck %s
1+
// RUN: mlir-opt %s -split-input-file --sparse-reinterpret-map | FileCheck %s
22

33
#SparseVector = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }>
44

@@ -8,3 +8,50 @@
88
// A pass-through of a sparse vector with a permutation (identity) map;
// the reinterpret-map pass should leave it untouched.
func.func @sparse_nop(%arg0: tensor<?xf64, #SparseVector>) -> tensor<?xf64, #SparseVector> {
  return %arg0 : tensor<?xf64, #SparseVector>
}
11+
12+
// -----
13+
14+
#trait_mul = {
  indexing_maps = [
    affine_map<(i,j) -> (i,j)>,  // A (in)
    affine_map<(i,j) -> (j,i)>,  // B (in, transposed)
    affine_map<(i,j) -> (i,j)>   // X (out)
  ],
  iterator_types = ["parallel", "parallel"],
  doc = "X(i,j) *= A(i,j) * B(j,i)"
}

// Block-sparse-row encoding: a non-permutation map (2-D dims, 4 levels).
#BSR = #sparse_tensor.encoding<{ // 2x4 blocks
  map = (i, j) ->
    ( i floordiv 2 : dense
    , j floordiv 4 : compressed
    , i mod 2 : dense
    , j mod 4 : dense
    )
}>

// The pass should demap the BSR output, rewrite the generic op into a
// 4-D (higher-order) generic with translated index maps, and remap back.
// CHECK: #[[$map0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0 * 2 + d2, d1 * 4 + d3)>
// CHECK: #[[$map1:.*]] = affine_map<(d0, d1, d2, d3) -> (d1 * 4 + d3, d0 * 2 + d2)>
// CHECK: #[[$map2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK-LABEL: func @mul(
// CHECK-SAME: %[[A0:.*0]]: tensor<32x32xf32>,
// CHECK-SAME: %[[A1:.*1]]: tensor<32x32xf32>,
// CHECK-SAME: %[[A2:.*2]]: tensor<32x32xf32, #sparse_tensor.encoding<{{{.*}}}>>)
// CHECK: %[[T0:.*]] = sparse_tensor.reinterpret_map %[[A2]]
// CHECK: %[[T1:.*]] = linalg.generic {indexing_maps = [#[[$map0]], #[[$map1]], #[[$map2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
// CHECK: %[[T2:.*]] = sparse_tensor.reinterpret_map %[[T1]]
// CHECK: return %[[T2]] : tensor<32x32xf32, #sparse_tensor.encoding<{{{.*}}}>>
func.func @mul(%arg0: tensor<32x32xf32>,
               %arg1: tensor<32x32xf32>,
               %arg2: tensor<32x32xf32, #BSR>) -> tensor<32x32xf32, #BSR> {
  %0 = linalg.generic #trait_mul
    ins(%arg0, %arg1: tensor<32x32xf32>, tensor<32x32xf32>)
    outs(%arg2: tensor<32x32xf32, #BSR>) {
      ^bb(%x: f32, %y : f32, %z : f32):
        %1 = arith.mulf %x, %y : f32
        %2 = arith.mulf %1, %z : f32
        linalg.yield %2 : f32
  } -> tensor<32x32xf32, #BSR>
  return %0 : tensor<32x32xf32, #BSR>
}
57+

0 commit comments

Comments
 (0)