
[mlir][sparse] minor refactoring of sparsification file #74403


Merged 1 commit on Dec 5, 2023
57 changes: 19 additions & 38 deletions mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -34,6 +34,7 @@
#include "mlir/IR/Matchers.h"
#include "mlir/IR/TensorEncoding.h"
#include "llvm/ADT/SmallBitVector.h"

#include <optional>

using namespace mlir;
@@ -43,11 +44,6 @@ using namespace mlir::sparse_tensor;
// Sparsifier analysis methods.
//===----------------------------------------------------------------------===//

-// TODO: the "idx"-vs-"ldx" naming convention is not self-explanatory,
-// and those letters are too easy to confuse visually. We should switch
-// to a more self-explanatory naming convention like "curLoop"-vs-"prevLoop"
-// (assuming that's the actual meaning behind the "idx"-vs-"ldx" convention).
-
/// Determines if affine expression is invariant.
static bool isInvariantAffine(AffineExpr a, unsigned loopDepth, LoopId ldx,
bool &isAtLoop) {
@@ -56,11 +52,9 @@ static bool isInvariantAffine(AffineExpr a, unsigned loopDepth, LoopId ldx,
const LoopId i = cast<AffineDimExpr>(a).getPosition();
if (i == ldx) {
isAtLoop = true;
-// Must be invariant if we are at the given loop.
-return true;
+return true; // invariant at given loop
}
-// The DimExpr is invariant the loop has already been generated.
-return i < loopDepth;
+return i < loopDepth; // invariant when already generated
}
case AffineExprKind::Add:
case AffineExprKind::Mul: {
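To make the dim-expression rule above concrete, here is a minimal standalone sketch (plain C++, not the MLIR `AffineExpr` API; `isInvariantDim` is a hypothetical mirror of the predicate): a dimension is invariant when it is the loop currently being generated, or when its loop has already been emitted.

```cpp
#include <cassert>

// Hypothetical mirror of the AffineDimExpr case above: dim i is invariant
// at loop ldx when it *is* that loop (also setting isAtLoop), or when its
// loop index is below the current loop depth (already generated).
static bool isInvariantDim(unsigned i, unsigned loopDepth, unsigned ldx,
                           bool &isAtLoop) {
  if (i == ldx) {
    isAtLoop = true;
    return true; // invariant at given loop
  }
  return i < loopDepth; // invariant when already generated
}

int main() {
  bool isAtLoop = false;
  assert(isInvariantDim(2, /*loopDepth=*/3, /*ldx=*/2, isAtLoop) && isAtLoop);
  isAtLoop = false;
  assert(isInvariantDim(1, /*loopDepth=*/3, /*ldx=*/2, isAtLoop) && !isAtLoop);
  assert(!isInvariantDim(5, /*loopDepth=*/3, /*ldx=*/2, isAtLoop));
  return 0;
}
```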
@@ -85,7 +79,6 @@ static bool findAffine(Merger &merger, TensorId tid, Level lvl, AffineExpr a,
const LoopId idx = merger.makeLoopId(cast<AffineDimExpr>(a).getPosition());
if (!isUndefLT(merger.getLvlType(tid, idx)))
return false; // used more than once
-
if (setLvlFormat)
merger.setLevelAndType(tid, idx, lvl, lt);
return true;
@@ -195,7 +188,7 @@ static bool findDepIdxSet(Merger &merger, TensorId tensor, Level lvl,
}
}

-/// Get the total number of compound affine expressions in the
+/// Gets the total number of compound affine expressions in the
/// `getMatchingIndexingMap` for the given tensor. For the following inputs:
///
/// map = (d0, d1, d2) => (d0 + d1 : compressed, d2 : compressed)
@@ -225,7 +218,7 @@ static unsigned getNumNonTrivialIdxExpOnSparseLvls(AffineMap map,
return num;
}
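As a hedged illustration of the counting rule documented above (toy types, not the actual `Merger`/`AffineMap` API; all names here are hypothetical): an index expression counts only when it is compound, i.e. not a bare dimension, and lands on a sparse level.

```cpp
#include <cassert>
#include <vector>

// Toy stand-in for one level of a tensor's indexing map.
struct ToyLevel {
  bool isSparse;   // level is compressed/singleton rather than dense
  bool isCompound; // index expression is not a bare dim (e.g. d0 + d1)
};

// Hypothetical analogue of getNumNonTrivialIdxExpOnSparseLvls for one map:
// count compound index expressions that land on sparse levels.
static unsigned countNonTrivialOnSparse(const std::vector<ToyLevel> &levels) {
  unsigned num = 0;
  for (const ToyLevel &l : levels)
    if (l.isSparse && l.isCompound)
      num++;
  return num;
}

int main() {
  // map = (d0, d1, d2) => (d0 + d1 : compressed, d2 : compressed)
  std::vector<ToyLevel> map = {{true, true}, {true, false}};
  assert(countNonTrivialOnSparse(map) == 1); // only d0 + d1 is compound
  return 0;
}
```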

-/// Get the total number of sparse levels with compound affine
+/// Gets the total number of sparse levels with compound affine
/// expressions, summed over all operands of the `GenericOp`.
static unsigned getNumNonTrivialIdxExpOnSparseLvls(linalg::GenericOp op) {
unsigned num = 0;
@@ -235,6 +228,7 @@ static unsigned getNumNonTrivialIdxExpOnSparseLvls(linalg::GenericOp op) {
return num;
}

+// Returns true iff output has nontrivial affine indices.
static bool hasNonTrivialAffineOnSparseOut(linalg::GenericOp op) {
OpOperand *out = op.getDpsInitOperand(0);
if (getSparseTensorType(out->get()).isAllDense())
@@ -260,11 +254,9 @@ static bool findSparseAnnotations(CodegenEnv &env, bool idxReducBased) {
const auto enc = getSparseTensorEncoding(t.get().getType());
if (enc)
annotated = true;
-
const Level lvlRank = map.getNumResults();
assert(!enc || lvlRank == enc.getLvlRank());
assert(static_cast<Level>(env.op().getRank(&t)) == lvlRank);
-
// We only need to do index reduction if there is at least one non-trivial
// index expression on sparse levels.
// If all non-trivial index expression is on dense levels, we can
@@ -343,9 +335,6 @@ static void genBuffers(CodegenEnv &env, OpBuilder &builder) {
}

/// Generates index for load/store on sparse tensor.
-// FIXME: It's not entirely clear what "index" means here (i.e., is it
-// a "coordinate", or "Ldx", or what). So the function should be renamed
-// and/or the documentation expanded in order to clarify.
static Value genIndex(CodegenEnv &env, OpOperand *t) {
const auto map = env.op().getMatchingIndexingMap(t);
const auto stt = getSparseTensorType(t->get());
@@ -495,7 +484,6 @@ static Value genTensorLoad(CodegenEnv &env, OpBuilder &builder, ExprId exp) {
Value val = env.exp(exp).val;
if (val)
return val;
-
// Load during insertion.
linalg::GenericOp op = env.op();
OpOperand *t = &op->getOpOperand(env.exp(exp).tensor);
@@ -574,7 +562,7 @@ inline static Value genInvariantValue(CodegenEnv &env, ExprId exp) {
/// exception of index computations, which need to be relinked to actual
/// inlined cloned code.
static Value relinkBranch(CodegenEnv &env, RewriterBase &rewriter, Block *block,
-Value e, LoopId ldx) {
+Value e) {
if (auto arg = dyn_cast<BlockArgument>(e)) {
// Direct arguments of the original linalg op must be converted
// into dense tensor loads. Note that we should not encounter
@@ -598,7 +586,7 @@ static Value relinkBranch(CodegenEnv &env, RewriterBase &rewriter, Block *block,
for (unsigned i = 0, n = def->getNumOperands(); i < n; i++) {
rewriter.updateRootInPlace(def, [&]() {
def->setOperand(
-i, relinkBranch(env, rewriter, block, def->getOperand(i), ldx));
+i, relinkBranch(env, rewriter, block, def->getOperand(i)));
});
}
}
@@ -607,8 +595,7 @@ static Value relinkBranch(CodegenEnv &env, RewriterBase &rewriter, Block *block,
}
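A rough sketch of the relinking recursion above, on a toy expression graph rather than the MLIR rewriter (all names hypothetical): leaves that still reference the original op's block arguments are substituted, and the walk continues through the operands of ops cloned into the new block.

```cpp
#include <vector>

// Toy node: either a block argument of the original op (a leaf that must be
// rewritten) or a cloned op whose operands may themselves need relinking.
struct ToyNode {
  bool isBlockArg = false;
  bool relinked = false;
  std::vector<ToyNode *> operands;
};

// Hypothetical analogue of relinkBranch: rewrite block-argument leaves
// (the real code replaces them with dense tensor loads) and recurse into
// the operands of cloned ops, exactly like the updateRootInPlace loop above.
static void relink(ToyNode &e) {
  if (e.isBlockArg) {
    e.relinked = true;
    return;
  }
  for (ToyNode *op : e.operands)
    relink(*op);
}

int main() {
  ToyNode arg{true, false, {}};
  ToyNode add{false, false, {&arg}};
  relink(add);
  return arg.relinked ? 0 : 1;
}
```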

/// Recursively generates tensor expression.
-static Value genExp(CodegenEnv &env, RewriterBase &rewriter, ExprId e,
-LoopId ldx) {
+static Value genExp(CodegenEnv &env, RewriterBase &rewriter, ExprId e) {
if (e == ::mlir::sparse_tensor::detail::kInvalidId)
return Value();

@@ -631,15 +618,15 @@ static Value genExp(CodegenEnv &env, RewriterBase &rewriter, ExprId e,
// based on the type of the other operand.
if (exp.children.e0 != ::mlir::sparse_tensor::detail::kInvalidId &&
env.exp(exp.children.e0).kind == TensorExp::Kind::kSynZero) {
-v1 = genExp(env, rewriter, exp.children.e1, ldx);
+v1 = genExp(env, rewriter, exp.children.e1);
v0 = constantZero(rewriter, loc, v1.getType());
} else if (exp.children.e1 != ::mlir::sparse_tensor::detail::kInvalidId &&
env.exp(exp.children.e1).kind == TensorExp::Kind::kSynZero) {
-v0 = genExp(env, rewriter, exp.children.e0, ldx);
+v0 = genExp(env, rewriter, exp.children.e0);
v1 = constantZero(rewriter, loc, v0.getType());
} else {
-v0 = genExp(env, rewriter, exp.children.e0, ldx);
-v1 = genExp(env, rewriter, exp.children.e1, ldx);
+v0 = genExp(env, rewriter, exp.children.e0);
+v1 = genExp(env, rewriter, exp.children.e1);
}

Value ee;
@@ -653,7 +640,7 @@ static Value genExp(CodegenEnv &env, RewriterBase &rewriter, ExprId e,
kind == TensorExp::Kind::kReduce ||
kind == TensorExp::Kind::kSelect)) {
OpBuilder::InsertionGuard guard(rewriter);
-ee = relinkBranch(env, rewriter, ee.getParentBlock(), ee, ldx);
+ee = relinkBranch(env, rewriter, ee.getParentBlock(), ee);
}
}
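The synthetic-zero branches in genExp above pick an order deliberately: the non-zero child is generated first so the zero constant can be created with that operand's type. A minimal sketch of just that ordering decision (hypothetical names, not the sparsifier's `TensorExp`):

```cpp
#include <cassert>

enum class Kind { kSynZero, kOther };

// Returns the child to generate first, mirroring the branch order above:
// if one side is a synthetic zero, codegen the other side first so that
// constantZero can be typed from its result.
static int firstChildToGen(Kind e0, Kind e1) {
  if (e0 == Kind::kSynZero)
    return 1; // generate e1; v0 becomes zero of v1's type
  if (e1 == Kind::kSynZero)
    return 0; // generate e0; v1 becomes zero of v0's type
  return 0;   // no synthetic zero: generate e0, then e1
}

int main() {
  assert(firstChildToGen(Kind::kSynZero, Kind::kOther) == 1);
  assert(firstChildToGen(Kind::kOther, Kind::kSynZero) == 0);
  return 0;
}
```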

@@ -806,7 +793,6 @@ static bool shouldTryParallize(CodegenEnv &env, LoopId ldx, bool isOuter,
const auto lt = env.lt(env.unpackTensorLevel(tidLvl).first, ldx);
return isCompressedLT(lt) || isSingletonLT(lt);
});
-
return isParallelFor(env, isOuter, isSparse);
}
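For the `any_of` above, a toy restatement (hypothetical enum, not the real level-type API): the loop counts as sparse when any tensor level it iterates is compressed or singleton, and that flag then feeds the parallelization heuristic in isParallelFor.

```cpp
#include <algorithm>
#include <vector>

enum class LevelType { Dense, Compressed, Singleton };

// Hypothetical analogue of the any_of above: does this loop touch a
// compressed or singleton level of any tensor?
static bool loopIsSparse(const std::vector<LevelType> &lvlTypes) {
  return std::any_of(lvlTypes.begin(), lvlTypes.end(), [](LevelType lt) {
    return lt == LevelType::Compressed || lt == LevelType::Singleton;
  });
}

int main() {
  return loopIsSparse({LevelType::Dense, LevelType::Compressed}) ? 0 : 1;
}
```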

@@ -1112,11 +1098,6 @@ static bool translateBitsToTidLvlPairs(
// level. We need to generate the address according to the
// affine expression. This is also the best place we can do it
// to avoid putting it inside inner loops.
-// NOTE: It assumes that the levels of the input tensor are
-// initialized in order (and it is also currently guaranteed by
-// computeIterationGraph), another more admissible approach
-// might be accepting out-of-order access between consecutive
-// dense levels.
affineTidLvls.emplace_back(env.makeTensorLevel(tid, l), exp);
}
}
@@ -1221,7 +1202,7 @@ static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp,
LoopOrd at) {
// At each leaf, assign remaining tensor (sub)expression to output tensor.
if (at == env.getLoopNum()) {
-Value rhs = genExp(env, rewriter, exp, at - 1);
+Value rhs = genExp(env, rewriter, exp);
genTensorStore(env, rewriter, exp, rhs);
return;
}
@@ -1235,8 +1216,7 @@ static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp,
bool needsUniv = startLoopSeq(env, rewriter, exp, at, ldx, lts);

// Emit a loop for every lattice point L0 >= Li in this loop sequence.
-//
-// NOTE: We cannot change this to `for (const LatPointId li : env.set(lts))`
+// We cannot change this to `for (const LatPointId li : env.set(lts))`
// because the loop body causes data-movement which invalidates
// the iterator.
const unsigned lsize = env.set(lts).size();
@@ -1251,7 +1231,7 @@ static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp,
Value cntInput = env.getExpandCount();
Value insInput = env.getInsertionChain();
Value validIns = env.getValidLexInsert();
-// NOTE: We cannot change this to `for (const LatPointId lj : env.set(lts))`
+// We cannot change this to `for (const LatPointId lj : env.set(lts))`
// because the loop body causes data-movement which invalidates the
// iterator.
for (unsigned j = 0; j < lsize; j++) {
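The index-based loop matters here. A minimal self-contained illustration of the invalidation hazard the comment warns about (plain `std::vector`, standing in for `env.set(lts)`):

```cpp
#include <vector>

int main() {
  std::vector<int> set = {1, 2, 3};
  // UNSAFE: a range-for holds iterators; push_back may reallocate and
  // invalidate them, which is undefined behavior:
  //   for (int v : set) if (v == 2) set.push_back(4);
  // SAFE: snapshot the size and index into the container, as genStmt does.
  const unsigned lsize = set.size();
  for (unsigned j = 0; j < lsize; j++)
    if (set[j] == 2)
      set.push_back(4);
  return 0;
}
```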
@@ -1323,6 +1303,7 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
if (hasNonTrivialAffineOnSparseOut(op))
return failure();

+// Only accept scheduled loops.
if (!op->hasAttr("sorted")) {
return rewriter.notifyMatchFailure(
op, "Loops not yet scheduled, try run --sparse-reinterpret-map "
@@ -1348,9 +1329,9 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
}
}

+CodegenEnv env(op, options, numTensors, numLoops, maxLvlRank);
// Detects sparse annotations and translates the per-level sparsity
// information for all tensors to loop indices in the kernel.
-CodegenEnv env(op, options, numTensors, numLoops, maxLvlRank);
if (!findSparseAnnotations(env, needIdxRed))
return failure();
