
Commit 6456e0b

[mlir][sparse] implement sparse_tensor.crd_translate operation (llvm#69653)
1 parent 6243d7d commit 6456e0b

5 files changed: +68 −17 lines

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td

Lines changed: 5 additions & 0 deletions
@@ -429,6 +429,11 @@ def SparseTensorEncodingAttr : SparseTensor_Attr<"SparseTensorEncoding",
     std::optional<uint64_t> getStaticLvlSliceSize(::mlir::sparse_tensor::Level lvl) const;
     std::optional<uint64_t> getStaticLvlSliceStride(::mlir::sparse_tensor::Level lvl) const;
 
+    //
+    // Helper function to build IR related to the encoding.
+    //
+    ValueRange translateCrds(::mlir::OpBuilder &builder, ::mlir::Location loc, ::mlir::ValueRange crds, ::mlir::sparse_tensor::CrdTransDirectionKind) const;
+
     //
     // Printing methods.
     //
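
For context, a minimal usage sketch of the new helper (not part of this commit; the wrapper function and variable names are hypothetical). It shows how a caller holding an OpBuilder can translate dimension coordinates into level coordinates through the encoding:

#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/IR/Builders.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

// Hypothetical caller: maps dimension coordinates to level coordinates.
// For a null (dense) encoding, translateCrds simply returns the inputs;
// otherwise it materializes a sparse_tensor.crd_translate op.
static SmallVector<Value> dimCrdsToLvlCrds(OpBuilder &builder, Location loc,
                                           SparseTensorEncodingAttr enc,
                                           ValueRange dimCrds) {
  ValueRange lvlCrds =
      enc.translateCrds(builder, loc, dimCrds, CrdTransDirectionKind::dim2lvl);
  return SmallVector<Value>(lvlCrds.begin(), lvlCrds.end());
}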

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorType.h

Lines changed: 13 additions & 3 deletions
@@ -176,6 +176,10 @@ class SparseTensorType {
   /// Returns the encoding (or the null-attribute for dense-tensors).
   SparseTensorEncodingAttr getEncoding() const { return enc; }
 
+  //
+  // SparseTensorEncodingAttr delegators
+  //
+
   /// Returns true for tensors which have an encoding, and false for
   /// those which do not. Therefore tensors with an all-dense encoding
   /// return true.
@@ -189,14 +193,20 @@
   /// (This is always true for dense-tensors.)
   bool isAllOrdered() const { return enc.isAllOrdered(); }
 
-  /// Returns true if the dimToLvl mapping is the identity.
-  /// (This is always true for dense-tensors.)
-  bool isIdentity() const { return !dimToLvl; }
+  /// Translates between level / dimension coordinate space.
+  ValueRange translateCrds(OpBuilder &builder, Location loc, ValueRange crds,
+                           CrdTransDirectionKind dir) const {
+    return enc.translateCrds(builder, loc, crds, dir);
+  }
 
   /// Returns true if the dimToLvl mapping is a permutation.
   /// (This is always true for dense-tensors.)
   bool isPermutation() const { return enc.isPermutation(); }
 
+  /// Returns true if the dimToLvl mapping is the identity.
+  /// (This is always true for dense-tensors.)
+  bool isIdentity() const { return enc.isIdentity(); }
+
   /// Returns the dimToLvl mapping (or the null-map for the identity).
   /// If you intend to compare the results of this method for equality,
   /// see `hasSameDimToLvl` instead.
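
A similar sketch at the SparseTensorType level (again not part of the commit; the helper name is made up). Because SparseTensorType wraps a RankedTensorType together with its encoding, callers can translate in the opposite direction, level back to dimension coordinates, without touching the attribute directly:

#include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h"
#include "mlir/IR/Builders.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

// Hypothetical helper: translates level coordinates back to dimension
// coordinates for a (possibly sparse) ranked tensor type.
static SmallVector<Value> lvlCrdsToDimCrds(OpBuilder &builder, Location loc,
                                           RankedTensorType rtp,
                                           ValueRange lvlCrds) {
  SparseTensorType stt(rtp);
  ValueRange dimCrds =
      stt.translateCrds(builder, loc, lvlCrds, CrdTransDirectionKind::lvl2dim);
  return SmallVector<Value>(dimCrds.begin(), dimCrds.end());
}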

mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td

Lines changed: 1 addition & 0 deletions
@@ -143,6 +143,7 @@ def PostSparsificationRewrite : Pass<"post-sparsification-rewrite", "ModuleOp">
   }];
   let constructor = "mlir::createPostSparsificationRewritePass()";
   let dependentDialects = [
+    "affine::AffineDialect",
     "arith::ArithDialect",
     "bufferization::BufferizationDialect",
     "linalg::LinalgDialect",

mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp

Lines changed: 18 additions & 0 deletions
@@ -415,6 +415,20 @@ SparseTensorEncodingAttr::getStaticLvlSliceStride(Level lvl) const {
   return getStaticDimSliceStride(toOrigDim(*this, lvl));
 }
 
+ValueRange
+SparseTensorEncodingAttr::translateCrds(OpBuilder &builder, Location loc,
+                                        ValueRange crds,
+                                        CrdTransDirectionKind dir) const {
+  if (!getImpl())
+    return crds;
+
+  SmallVector<Type> retType(
+      dir == CrdTransDirectionKind::lvl2dim ? getDimRank() : getLvlRank(),
+      builder.getIndexType());
+  auto transOp = builder.create<CrdTranslateOp>(loc, retType, crds, dir, *this);
+  return transOp.getOutCrds();
+}
+
 Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
 #define RETURN_ON_FAIL(stmt) \
   if (failed(stmt)) { \
@@ -1155,6 +1169,10 @@ LogicalResult CrdTranslateOp::verify() {
 
 LogicalResult CrdTranslateOp::fold(FoldAdaptor adaptor,
                                    SmallVectorImpl<OpFoldResult> &results) {
+  if (getEncoder().isIdentity()) {
+    results.assign(getInCrds().begin(), getInCrds().end());
+    return success();
+  }
   if (getEncoder().isPermutation()) {
     AffineMap perm = getDirection() == CrdTransDirectionKind::dim2lvl
                          ? getEncoder().getDimToLvl()
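
A small sketch (not in this commit; the helper is hypothetical) of how the new identity fold pays off. It goes through OpBuilder::createOrFold, the standard builder facility that invokes an op's fold hook right after construction: when the encoding's dimToLvl map is the identity, the input coordinates are returned as-is and no crd_translate op is left in the IR.

#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/IR/Builders.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

// Hypothetical variant of translateCrds that lets folding kick in immediately.
static SmallVector<Value> translateOrFold(OpBuilder &builder, Location loc,
                                          SparseTensorEncodingAttr enc,
                                          ValueRange crds,
                                          CrdTransDirectionKind dir) {
  SmallVector<Type> retTypes(
      dir == CrdTransDirectionKind::lvl2dim ? enc.getDimRank() : enc.getLvlRank(),
      builder.getIndexType());
  SmallVector<Value> outCrds;
  // createOrFold invokes CrdTranslateOp::fold; with an identity dimToLvl the
  // inputs are returned directly and the op is never inserted.
  builder.createOrFold<CrdTranslateOp>(outCrds, loc, retTypes, crds, dir, enc);
  return outCrds;
}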

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp

Lines changed: 31 additions & 14 deletions
@@ -13,6 +13,7 @@
 #include "CodegenUtils.h"
 #include "LoopEmitter.h"
 
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
@@ -865,6 +866,28 @@ struct TensorLike {
   Value val;
 };
 
+struct CrdTranslateRewriter : public OpRewritePattern<CrdTranslateOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(CrdTranslateOp op,
+                                PatternRewriter &rewriter) const override {
+    AffineMap map = op.getDirection() == CrdTransDirectionKind::dim2lvl
+                        ? op.getEncoder().getDimToLvl()
+                        : op.getEncoder().getLvlToDim();
+    SmallVector<Value> outCrds;
+    for (AffineExpr result : map.getResults()) {
+      // TODO: we should probably expand the affine map to IR using our own
+      // rules, since affine.apply assumes signed values, while the coordinates
+      // we provide must always be signless.
+      Value trans = rewriter.create<affine::AffineApplyOp>(
+          op.getLoc(), AffineMap::get(map.getNumDims(), 0, result),
+          op.getInCrds());
+      outCrds.push_back(trans);
+    }
+    rewriter.replaceOp(op, outCrds);
+    return success();
+  }
+};
+
 struct ConcatenateRewriter : public OpRewritePattern<ConcatenateOp> {
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(ConcatenateOp op,
@@ -999,13 +1022,9 @@ struct DirectConvertRewriter : public OpRewritePattern<ConvertOp> {
             ValueRange reduc) {
           // Enters the loop, update the SSA value for insertion chain.
           dstBuf.val = reduc.front();
-          const Dimension dimRank = dstStt.getDimRank();
-          const Level lvlRank = dstStt.getLvlRank();
-          SmallVector<Value> lcvs(lvlRank);
-          for (Dimension d = 0; d < dimRank; d++) {
-            // FIXME: `toStoredDim` is deprecated
-            lcvs[toStoredDim(dstStt.getEncoding(), d)] = dcvs[d];
-          }
+
+          ValueRange lcvs = dstStt.translateCrds(
+              builder, loc, dcvs, CrdTransDirectionKind::dim2lvl);
 
           if (!skipZeroCheck) {
             Value cond = genIsNonzero(builder, loc, v);
@@ -1101,12 +1120,9 @@ struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
     Block *srcBlock = op.getBody();
 
     // Remap coordinates.
-    SmallVector<Value> args;
-    for (Dimension d = 0; d < dimRank; d++) {
-      // FIXME: `toStoredDim` is deprecated
-      Value dimCrd = lcvs[toStoredDim(enc, d)];
-      args.push_back(dimCrd);
-    }
+    SmallVector<Value> args =
+        enc.translateCrds(rewriter, loc, lcvs, CrdTransDirectionKind::lvl2dim);
+
     // Remap value.
     args.push_back(val);
     // Remap reduction variables.
@@ -1249,7 +1265,8 @@ void mlir::populatePostSparsificationRewriting(RewritePatternSet &patterns,
                                                bool enableRT,
                                                bool enableForeach,
                                                bool enableConvert) {
-  patterns.add<ConcatenateRewriter, ReshapeRewriter<tensor::ExpandShapeOp>,
+  patterns.add<ConcatenateRewriter, CrdTranslateRewriter,
+               ReshapeRewriter<tensor::ExpandShapeOp>,
                ReshapeRewriter<tensor::CollapseShapeOp>,
                Sparse2SparseReshapeRewriter<tensor::ExpandShapeOp>,
                Sparse2SparseReshapeRewriter<tensor::CollapseShapeOp>,
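
Finally, a hedged sketch (not part of the commit; the driver function is hypothetical) of exercising the new pattern outside the pass: populate the post-sparsification set, which now includes CrdTranslateRewriter, and apply it with MLIR's greedy rewrite driver so any remaining crd_translate ops are expanded into affine.apply ops.

#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Hypothetical driver that runs the post-sparsification rewrites on a module.
static LogicalResult runPostSparsificationRewrites(ModuleOp module) {
  RewritePatternSet patterns(module.getContext());
  populatePostSparsificationRewriting(patterns, /*enableRT=*/false,
                                      /*enableForeach=*/true,
                                      /*enableConvert=*/true);
  return applyPatternsAndFoldGreedily(module, std::move(patterns));
}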
