//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
+ #include "mlir/Dialect/Linalg/IR/Linalg.h"
+ #include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
@@ -18,10 +20,130 @@ using namespace mlir::sparse_tensor;

namespace {

- // TODO:
- // (1) insert the zero-cost sparse_tensor.reinterpret_map ops
- // (2) rewrite linalg.generic ops traits on level crds
- // (3) compute topsort, and resolve cyles with sparse_tensor.convert ops
+ //===----------------------------------------------------------------------===//
+ // Helper methods.
+ //===----------------------------------------------------------------------===//
+
+ // Translates a "simple" map according to an identity lvl-map.
+ static AffineMap translateMap(OpBuilder &builder, SparseTensorType stt,
+                               AffineMap map) {
+   unsigned lvlRank = stt.getLvlRank();
+   AffineMap lvl2dim = stt.getLvlToDim();
+   assert(lvl2dim.getNumInputs() == lvlRank);
+   SmallVector<AffineExpr> exps;
+   for (unsigned i = 0, n = map.getNumResults(); i < n; i++) {
+     unsigned pos = map.getResult(i).cast<AffineDimExpr>().getPosition();
+     exps.push_back(lvl2dim.getResult(pos));
+   }
+   return AffineMap::get(lvlRank, 0, exps, builder.getContext());
+ }
+
+ // Generates a "de"mapping reinterpretation of the map.
+ static Value genDemap(OpBuilder &builder, SparseTensorEncodingAttr enc,
+                       Value val) {
+   return builder.create<ReinterpretMapOp>(val.getLoc(), enc.withoutDimToLvl(),
+                                           val);
+ }
+
+ // Generates a "re"mapping reinterpretation of the map.
+ static Value genRemap(OpBuilder &builder, SparseTensorEncodingAttr enc,
+                       Value val) {
+   return builder.create<ReinterpretMapOp>(val.getLoc(), enc, val);
+ }
+
+ // Generates a clone of the given linalg generic operation, but with
+ // remapped arguments, index maps, and iteration types.
+ //
+ // TODO: As described below, this is proof-of-concept code which makes a lot
+ //       of simplifying assumptions for now.
+ //
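+ // The clone keeps the original body, but operates in the level space of the
+ // sparse output `out`: the output map becomes the lvl-rank identity, the
+ // other index maps are translated with translateMap(), and (lvlRank - dimRank)
+ // extra "parallel" iterator types are prepended for the additional levels.
+ //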
+ static linalg::GenericOp genGenericLinalg(PatternRewriter &rewriter,
+                                           linalg::GenericOp linalgOp,
+                                           SparseTensorType stt, Value out) {
+   unsigned dimRank = stt.getDimRank();
+   unsigned lvlRank = stt.getLvlRank();
+   SmallVector<Value> inputOps = linalgOp.getInputs();
+   SmallVector<Value> outputOps = {out};
+   SmallVector<AffineMap> indexMaps;
+   SmallVector<utils::IteratorType> iterTypes;
+   // Translate the index maps, except output map, which is lvl-identity.
+   auto maps = linalgOp.getIndexingMapsArray();
+   for (unsigned i = 0, n = maps.size() - 1; i < n; i++)
+     indexMaps.push_back(translateMap(rewriter, stt, maps[i]));
+   indexMaps.push_back(
+       AffineMap::getMultiDimIdentityMap(lvlRank, rewriter.getContext()));
+   // Add additional "parallel" iteration types at the top.
+   for (unsigned i = 0, diff = lvlRank - dimRank; i < diff; i++)
+     iterTypes.push_back(utils::IteratorType::parallel);
+   for (auto &i : linalgOp.getIteratorTypesArray())
+     iterTypes.push_back(i);
+   // Generate the new linalg generic operation and clone body.
+   auto newOp = rewriter.create<linalg::GenericOp>(
+       linalgOp.getLoc(), out.getType(), inputOps, outputOps, indexMaps,
+       iterTypes);
+   rewriter.cloneRegionBefore(linalgOp.getRegion(), newOp.getRegion(),
+                              newOp.getRegion().begin());
+   return newOp;
+ }
+
+ //===----------------------------------------------------------------------===//
+ // Rewriting rules for linalg generic ops.
+ //===----------------------------------------------------------------------===//
+
+ /// Sparse rewriting rule for the generic `linalg` operation.
+ struct GenericOpReinterpretMap : public OpRewritePattern<linalg::GenericOp> {
+ public:
+   GenericOpReinterpretMap(MLIRContext *context)
+       : OpRewritePattern<linalg::GenericOp>(context) {}
+
+   LogicalResult matchAndRewrite(linalg::GenericOp linalgOp,
+                                 PatternRewriter &rewriter) const override {
+     // Only rewrite single output operations with pure tensor semantics.
+     if (linalgOp.getNumDpsInits() != 1 || !linalgOp.hasTensorSemantics())
+       return failure();
+     // Scan all operands, inspect sparse tensors.
+     //
+     // TODO: generalize this proof-of-concept algorithm, since the current
+     //       implementation accepts only simple indexing maps, and one
+     //       non-permutation sparse tensor, which must have an identity
+     //       indexing map and be the output.
+     //
+     OpOperand *tx = nullptr;
+     for (OpOperand &t : linalgOp->getOpOperands()) {
+       // Ensure every index map is "simple".
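+       // I.e., every result is a bare dimension d_i, with no affine arithmetic.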
+       const auto map = linalgOp.getMatchingIndexingMap(&t);
+       for (unsigned i = 0, n = map.getNumResults(); i < n; i++)
+         if (map.getResult(i).getKind() != AffineExprKind::DimId)
+           return failure();
+       // Inspect sparse operands.
+       auto stt = getSparseTensorType(t.get());
+       if (stt.hasEncoding()) {
+         if (stt.isPermutation())
+           continue;
+         assert(stt.getDimRank() < stt.getLvlRank()); // only allowed non-perm
+         if (tx)
+           return failure(); // more than one non-perm
+         if (!map.isIdentity())
+           return failure(); // no ID indexing map on the non-perm
+         tx = &t;
+       }
+     }
+     // Found a non-permutation, rewrite when this is the output.
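+     // The output operand is demapped into level space, a cloned generic op
+     // computes on the demapped value, and the result is remapped back to the
+     // original encoding, so users of the result keep observing the original
+     // dim-space type.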
+     if (tx && tx == linalgOp.getDpsInitOperand(0)) {
+       auto stt = getSparseTensorType(tx->get());
+       auto demap = genDemap(rewriter, stt.getEncoding(), tx->get());
+       auto newOp = genGenericLinalg(rewriter, linalgOp, stt, demap);
+       auto remap = genRemap(rewriter, stt.getEncoding(), newOp.getResult(0));
+       rewriter.replaceOp(linalgOp, remap);
+       return success();
+     }
+     return failure();
+   }
+ };
+
+ //===----------------------------------------------------------------------===//
+ // Rewriting rules for operations other than linalg generic ops.
+ //===----------------------------------------------------------------------===//

// CRTP to help implementing a rewriter that demaps all its inputs and remaps
// all its outputs.
@@ -59,10 +181,6 @@ struct DemapInsRemapOutsRewriter : public OpRewritePattern<SourceOp> {
  }
};

- //===----------------------------------------------------------------------===//
- // Reinterpret Map Rewriters for operations other than linalg.generics
- //===----------------------------------------------------------------------===//
-
struct CrdTranslateRewriter : public OpRewritePattern<CrdTranslateOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(CrdTranslateOp op,
@@ -110,6 +228,10 @@ struct TensorInsertRewriter

void mlir::populateSparseReinterpretMap(RewritePatternSet &patterns,
                                        ReinterpretMapScope scope) {
+   if (scope == ReinterpretMapScope::kAll ||
+       scope == ReinterpretMapScope::kGenericOnly) {
+     patterns.add<GenericOpReinterpretMap>(patterns.getContext());
+   }
  if (scope == ReinterpretMapScope::kAll ||
      scope == ReinterpretMapScope::kExceptGeneric) {
    patterns.add<CrdTranslateRewriter, TensorInsertRewriter>(