Commit c81a2c0

[mlir][sparse] add helper class to implement common rewriter to re/demap sparse tensors. (#70750)
1 parent 297230a commit c81a2c0

3 files changed: +59 -16 lines changed


mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h

Lines changed: 1 addition & 1 deletion
@@ -97,7 +97,7 @@ template <typename T>
 inline RankedTensorType getRankedTensorType(T &&t) {
   assert(static_cast<bool>(std::forward<T>(t)) &&
          "getRankedTensorType got null argument");
-  return cast<RankedTensorType>(std::forward<T>(t).getType());
+  return dyn_cast<RankedTensorType>(std::forward<T>(t).getType());
 }
 
 /// Convenience method to abbreviate casting `getType()`.

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h

Lines changed: 8 additions & 1 deletion
@@ -336,11 +336,18 @@ class SparseTensorType {
   const AffineMap lvlToDim;
 };
 
-/// Convenience method to abbreviate wrapping `getRankedTensorType`.
+/// Convenience methods to abbreviate wrapping `getRankedTensorType`.
 template <typename T>
 inline SparseTensorType getSparseTensorType(T t) {
   return SparseTensorType(getRankedTensorType(t));
 }
+template <typename T>
+inline std::optional<SparseTensorType> tryGetSparseTensorType(T t) {
+  RankedTensorType rtp = getRankedTensorType(t);
+  if (rtp)
+    return SparseTensorType(rtp);
+  return std::nullopt;
+}
 
 } // namespace sparse_tensor
 } // namespace mlir
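
For orientation, a minimal usage sketch of the new helper (illustration only, not part of this commit; `val` is an assumed mlir::Value in scope):

// Probe a value that may not be a ranked tensor; with the dyn_cast change
// above, tryGetSparseTensorType() yields std::nullopt instead of asserting.
if (std::optional<SparseTensorType> stt = tryGetSparseTensorType(val)) {
  if (!stt->isIdentity()) {
    // Non-trivial dim-to-lvl mapping; a caller could demap here, e.g. by
    // creating a sparse_tensor.reinterpret_map to stt->getDemappedType().
  }
}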

mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp

Lines changed: 50 additions & 14 deletions
@@ -23,8 +23,44 @@ namespace {
 // (2) rewrite linalg.generic ops traits on level crds
 // (3) compute topsort, and resolve cyles with sparse_tensor.convert ops
 
+// CRTP to help implementing a rewriter that demaps all its inputs and remaps
+// all its outputs.
+template <typename SubClass, typename SourceOp>
+struct DemapInsRemapOutsRewriter : public OpRewritePattern<SourceOp> {
+  using OpRewritePattern<SourceOp>::OpRewritePattern;
+  using OpAdaptor = typename SourceOp::Adaptor;
+
+  LogicalResult matchAndRewrite(SourceOp op,
+                                PatternRewriter &rewriter) const override {
+    if (!static_cast<const SubClass *>(this)->matchOp(op))
+      return failure();
+
+    Location loc = op.getLoc();
+    // Demaps non-trivial inputs.
+    SmallVector<Value> deMappedIns(op->getOperands());
+    for (Value &in : deMappedIns)
+      if (auto stt = tryGetSparseTensorType(in); stt && !stt->isIdentity())
+        in = rewriter.create<ReinterpretMapOp>(loc, stt->getDemappedType(), in);
+
+    // CRTP call.
+    OpAdaptor adaptor(deMappedIns);
+    ValueRange outs =
+        static_cast<const SubClass *>(this)->rewriteOp(op, adaptor, rewriter);
+    assert(outs.size() == op->getResults().size());
+
+    // Remap outputs.
+    SmallVector<Value> reMappedOuts(outs);
+    for (auto [r, a] : llvm::zip(reMappedOuts, op->getResults()))
+      if (r.getType() != a.getType())
+        r = rewriter.create<ReinterpretMapOp>(loc, a.getType(), r);
+
+    rewriter.replaceOp(op, reMappedOuts);
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
-// Reiterpret Map Rewriters for operations other than linalg.generics
+// Reinterpret Map Rewriters for operations other than linalg.generics
 //===----------------------------------------------------------------------===//
 
 struct CrdTranslateRewriter : public OpRewritePattern<CrdTranslateOp> {
@@ -34,6 +70,7 @@ struct CrdTranslateRewriter : public OpRewritePattern<CrdTranslateOp> {
     AffineMap map = op.getDirection() == CrdTransDirectionKind::dim2lvl
                         ? op.getEncoder().getDimToLvl()
                         : op.getEncoder().getLvlToDim();
+
     SmallVector<Value> outCrds;
     for (AffineExpr result : map.getResults()) {
       // TODO: we should probably expand the affine map to IR using our own
@@ -49,24 +86,23 @@ struct CrdTranslateRewriter : public OpRewritePattern<CrdTranslateOp> {
   }
 };
 
-struct TensorInsertRewriter : public OpRewritePattern<tensor::InsertOp> {
-  using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(tensor::InsertOp op,
-                                PatternRewriter &rewriter) const override {
+struct TensorInsertRewriter
+    : public DemapInsRemapOutsRewriter<TensorInsertRewriter, tensor::InsertOp> {
+  using DemapInsRemapOutsRewriter::DemapInsRemapOutsRewriter;
 
-    if (!op.getResult().getType().getEncoding())
-      return failure();
+  bool matchOp(tensor::InsertOp op) const {
+    return op.getResult().getType().getEncoding() != nullptr;
+  }
+
+  ValueRange rewriteOp(tensor::InsertOp op, OpAdaptor adaptor,
+                       PatternRewriter &rewriter) const {
     Location loc = op.getLoc();
     auto stt = getSparseTensorType(op.getResult());
     ValueRange lvlCrd = stt.translateCrds(rewriter, loc, op.getIndices(),
                                           CrdTransDirectionKind::dim2lvl);
-
-    Value t = rewriter.create<ReinterpretMapOp>(
-        loc, stt.getEncoding().withoutDimToLvl(), op.getDest());
-    t = rewriter.create<sparse_tensor::InsertOp>(loc, op.getScalar(), t,
-                                                 lvlCrd);
-    rewriter.replaceOpWithNewOp<ReinterpretMapOp>(op, op.getType(), t);
-    return success();
+    Operation *insertOp = rewriter.create<sparse_tensor::InsertOp>(
+        loc, op.getScalar(), adaptor.getDest(), lvlCrd);
+    return insertOp->getResults();
   }
 };
 
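As a closing note, a hypothetical sketch of how another rewriter could plug into the new CRTP helper; `SomeSparseOp` and `SomeSparseOpRewriter` are made-up names for illustration and do not appear in this commit:

// Hypothetical subclass (illustration only): the base class handles demapping
// of sparse inputs and remapping of results, so the subclass supplies just
// matchOp() and rewriteOp().
struct SomeSparseOpRewriter
    : public DemapInsRemapOutsRewriter<SomeSparseOpRewriter, SomeSparseOp> {
  using DemapInsRemapOutsRewriter::DemapInsRemapOutsRewriter;

  // Called first by the base matchAndRewrite(); return false to bail out.
  bool matchOp(SomeSparseOp op) const {
    return getSparseTensorType(op.getResult()).hasEncoding();
  }

  // Operates on already-demapped operands via `adaptor`; the returned values
  // are reinterpreted back to the original result types by the base class.
  ValueRange rewriteOp(SomeSparseOp op, OpAdaptor adaptor,
                       PatternRewriter &rewriter) const {
    // ... build replacement ops from the adaptor's operands and return their
    // results; a placeholder is shown here for brevity ...
    return op->getResults();
  }
};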