@@ -23,8 +23,44 @@ namespace {
23
23
// (2) rewrite linalg.generic ops traits on level crds
24
24
// (3) compute topsort, and resolve cyles with sparse_tensor.convert ops
25
25
26
+ // CRTP to help implementing a rewriter that demaps all its inputs and remaps
27
+ // all its outputs.
28
+ template <typename SubClass, typename SourceOp>
29
+ struct DemapInsRemapOutsRewriter : public OpRewritePattern <SourceOp> {
30
+ using OpRewritePattern<SourceOp>::OpRewritePattern;
31
+ using OpAdaptor = typename SourceOp::Adaptor;
32
+
33
+ LogicalResult matchAndRewrite (SourceOp op,
34
+ PatternRewriter &rewriter) const override {
35
+ if (!static_cast <const SubClass *>(this )->matchOp (op))
36
+ return failure ();
37
+
38
+ Location loc = op.getLoc ();
39
+ // Demaps non-trivial inputs.
40
+ SmallVector<Value> deMappedIns (op->getOperands ());
41
+ for (Value &in : deMappedIns)
42
+ if (auto stt = tryGetSparseTensorType (in); stt && !stt->isIdentity ())
43
+ in = rewriter.create <ReinterpretMapOp>(loc, stt->getDemappedType (), in);
44
+
45
+ // CRTP call.
46
+ OpAdaptor adaptor (deMappedIns);
47
+ ValueRange outs =
48
+ static_cast <const SubClass *>(this )->rewriteOp (op, adaptor, rewriter);
49
+ assert (outs.size () == op->getResults ().size ());
50
+
51
+ // Remap outputs.
52
+ SmallVector<Value> reMappedOuts (outs);
53
+ for (auto [r, a] : llvm::zip (reMappedOuts, op->getResults ()))
54
+ if (r.getType () != a.getType ())
55
+ r = rewriter.create <ReinterpretMapOp>(loc, a.getType (), r);
56
+
57
+ rewriter.replaceOp (op, reMappedOuts);
58
+ return success ();
59
+ }
60
+ };
61
+
26
62
// ===----------------------------------------------------------------------===//
27
- // Reiterpret Map Rewriters for operations other than linalg.generics
63
+ // Reinterpret Map Rewriters for operations other than linalg.generics
28
64
// ===----------------------------------------------------------------------===//
29
65
30
66
struct CrdTranslateRewriter : public OpRewritePattern <CrdTranslateOp> {
@@ -34,6 +70,7 @@ struct CrdTranslateRewriter : public OpRewritePattern<CrdTranslateOp> {
34
70
AffineMap map = op.getDirection () == CrdTransDirectionKind::dim2lvl
35
71
? op.getEncoder ().getDimToLvl ()
36
72
: op.getEncoder ().getLvlToDim ();
73
+
37
74
SmallVector<Value> outCrds;
38
75
for (AffineExpr result : map.getResults ()) {
39
76
// TODO: we should probably expand the affine map to IR using our own
@@ -49,24 +86,23 @@ struct CrdTranslateRewriter : public OpRewritePattern<CrdTranslateOp> {
49
86
}
50
87
};
51
88
52
- struct TensorInsertRewriter : public OpRewritePattern <tensor::InsertOp> {
53
- using OpRewritePattern::OpRewritePattern;
54
- LogicalResult matchAndRewrite (tensor::InsertOp op,
55
- PatternRewriter &rewriter) const override {
89
+ struct TensorInsertRewriter
90
+ : public DemapInsRemapOutsRewriter<TensorInsertRewriter, tensor::InsertOp> {
91
+ using DemapInsRemapOutsRewriter::DemapInsRemapOutsRewriter;
56
92
57
- if (!op.getResult ().getType ().getEncoding ())
58
- return failure ();
93
+ bool matchOp (tensor::InsertOp op) const {
94
+ return op.getResult ().getType ().getEncoding () != nullptr ;
95
+ }
96
+
97
+ ValueRange rewriteOp (tensor::InsertOp op, OpAdaptor adaptor,
98
+ PatternRewriter &rewriter) const {
59
99
Location loc = op.getLoc ();
60
100
auto stt = getSparseTensorType (op.getResult ());
61
101
ValueRange lvlCrd = stt.translateCrds (rewriter, loc, op.getIndices (),
62
102
CrdTransDirectionKind::dim2lvl);
63
-
64
- Value t = rewriter.create <ReinterpretMapOp>(
65
- loc, stt.getEncoding ().withoutDimToLvl (), op.getDest ());
66
- t = rewriter.create <sparse_tensor::InsertOp>(loc, op.getScalar (), t,
67
- lvlCrd);
68
- rewriter.replaceOpWithNewOp <ReinterpretMapOp>(op, op.getType (), t);
69
- return success ();
103
+ Operation *insertOp = rewriter.create <sparse_tensor::InsertOp>(
104
+ loc, op.getScalar (), adaptor.getDest (), lvlCrd);
105
+ return insertOp->getResults ();
70
106
}
71
107
};
72
108
0 commit comments