Skip to content

Commit 067bebb

Browse files
authored
[mlir][sparse] minor refactoring of sparsification file (#74403)
Removed obsolete TODOs and NOTEs, fixed formatting, and removed an unused parameter
1 parent 030b8cb commit 067bebb

File tree

1 file changed

+19
-38
lines changed

1 file changed

+19
-38
lines changed

mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp

Lines changed: 19 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
#include "mlir/IR/Matchers.h"
3535
#include "mlir/IR/TensorEncoding.h"
3636
#include "llvm/ADT/SmallBitVector.h"
37+
3738
#include <optional>
3839

3940
using namespace mlir;
@@ -43,11 +44,6 @@ using namespace mlir::sparse_tensor;
4344
// Sparsifier analysis methods.
4445
//===----------------------------------------------------------------------===//
4546

46-
// TODO: the "idx"-vs-"ldx" naming convention is not self-explanatory,
47-
// and those letters are too easy to confuse visually. We should switch
48-
// to a more self-explanatory naming convention like "curLoop"-vs-"prevLoop"
49-
// (assuming that's the actual meaning behind the "idx"-vs-"ldx" convention).
50-
5147
/// Determines if affine expression is invariant.
5248
static bool isInvariantAffine(AffineExpr a, unsigned loopDepth, LoopId ldx,
5349
bool &isAtLoop) {
@@ -56,11 +52,9 @@ static bool isInvariantAffine(AffineExpr a, unsigned loopDepth, LoopId ldx,
5652
const LoopId i = cast<AffineDimExpr>(a).getPosition();
5753
if (i == ldx) {
5854
isAtLoop = true;
59-
// Must be invariant if we are at the given loop.
60-
return true;
55+
return true; // invariant at given loop
6156
}
62-
// The DimExpr is invariant the loop has already been generated.
63-
return i < loopDepth;
57+
return i < loopDepth; // invariant when already generated
6458
}
6559
case AffineExprKind::Add:
6660
case AffineExprKind::Mul: {
@@ -85,7 +79,6 @@ static bool findAffine(Merger &merger, TensorId tid, Level lvl, AffineExpr a,
8579
const LoopId idx = merger.makeLoopId(cast<AffineDimExpr>(a).getPosition());
8680
if (!isUndefLT(merger.getLvlType(tid, idx)))
8781
return false; // used more than once
88-
8982
if (setLvlFormat)
9083
merger.setLevelAndType(tid, idx, lvl, lt);
9184
return true;
@@ -195,7 +188,7 @@ static bool findDepIdxSet(Merger &merger, TensorId tensor, Level lvl,
195188
}
196189
}
197190

198-
/// Get the total number of compound affine expressions in the
191+
/// Gets the total number of compound affine expressions in the
199192
/// `getMatchingIndexingMap` for the given tensor. For the following inputs:
200193
///
201194
/// map = (d0, d1, d2) => (d0 + d1 : compressed, d2 : compressed)
@@ -225,7 +218,7 @@ static unsigned getNumNonTrivialIdxExpOnSparseLvls(AffineMap map,
225218
return num;
226219
}
227220

228-
/// Get the total number of sparse levels with compound affine
221+
/// Gets the total number of sparse levels with compound affine
229222
/// expressions, summed over all operands of the `GenericOp`.
230223
static unsigned getNumNonTrivialIdxExpOnSparseLvls(linalg::GenericOp op) {
231224
unsigned num = 0;
@@ -235,6 +228,7 @@ static unsigned getNumNonTrivialIdxExpOnSparseLvls(linalg::GenericOp op) {
235228
return num;
236229
}
237230

231+
// Returns true iff output has nontrivial affine indices.
238232
static bool hasNonTrivialAffineOnSparseOut(linalg::GenericOp op) {
239233
OpOperand *out = op.getDpsInitOperand(0);
240234
if (getSparseTensorType(out->get()).isAllDense())
@@ -260,11 +254,9 @@ static bool findSparseAnnotations(CodegenEnv &env, bool idxReducBased) {
260254
const auto enc = getSparseTensorEncoding(t.get().getType());
261255
if (enc)
262256
annotated = true;
263-
264257
const Level lvlRank = map.getNumResults();
265258
assert(!enc || lvlRank == enc.getLvlRank());
266259
assert(static_cast<Level>(env.op().getRank(&t)) == lvlRank);
267-
268260
// We only need to do index reduction if there is at least one non-trivial
269261
// index expression on sparse levels.
270262
// If all non-trivial index expression is on dense levels, we can
@@ -343,9 +335,6 @@ static void genBuffers(CodegenEnv &env, OpBuilder &builder) {
343335
}
344336

345337
/// Generates index for load/store on sparse tensor.
346-
// FIXME: It's not entirely clear what "index" means here (i.e., is it
347-
// a "coordinate", or "Ldx", or what). So the function should be renamed
348-
// and/or the documentation expanded in order to clarify.
349338
static Value genIndex(CodegenEnv &env, OpOperand *t) {
350339
const auto map = env.op().getMatchingIndexingMap(t);
351340
const auto stt = getSparseTensorType(t->get());
@@ -495,7 +484,6 @@ static Value genTensorLoad(CodegenEnv &env, OpBuilder &builder, ExprId exp) {
495484
Value val = env.exp(exp).val;
496485
if (val)
497486
return val;
498-
499487
// Load during insertion.
500488
linalg::GenericOp op = env.op();
501489
OpOperand *t = &op->getOpOperand(env.exp(exp).tensor);
@@ -574,7 +562,7 @@ inline static Value genInvariantValue(CodegenEnv &env, ExprId exp) {
574562
/// exception of index computations, which need to be relinked to actual
575563
/// inlined cloned code.
576564
static Value relinkBranch(CodegenEnv &env, RewriterBase &rewriter, Block *block,
577-
Value e, LoopId ldx) {
565+
Value e) {
578566
if (auto arg = dyn_cast<BlockArgument>(e)) {
579567
// Direct arguments of the original linalg op must be converted
580568
// into dense tensor loads. Note that we should not encounter
@@ -598,7 +586,7 @@ static Value relinkBranch(CodegenEnv &env, RewriterBase &rewriter, Block *block,
598586
for (unsigned i = 0, n = def->getNumOperands(); i < n; i++) {
599587
rewriter.updateRootInPlace(def, [&]() {
600588
def->setOperand(
601-
i, relinkBranch(env, rewriter, block, def->getOperand(i), ldx));
589+
i, relinkBranch(env, rewriter, block, def->getOperand(i)));
602590
});
603591
}
604592
}
@@ -607,8 +595,7 @@ static Value relinkBranch(CodegenEnv &env, RewriterBase &rewriter, Block *block,
607595
}
608596

609597
/// Recursively generates tensor expression.
610-
static Value genExp(CodegenEnv &env, RewriterBase &rewriter, ExprId e,
611-
LoopId ldx) {
598+
static Value genExp(CodegenEnv &env, RewriterBase &rewriter, ExprId e) {
612599
if (e == ::mlir::sparse_tensor::detail::kInvalidId)
613600
return Value();
614601

@@ -631,15 +618,15 @@ static Value genExp(CodegenEnv &env, RewriterBase &rewriter, ExprId e,
631618
// based on the type of the other operand.
632619
if (exp.children.e0 != ::mlir::sparse_tensor::detail::kInvalidId &&
633620
env.exp(exp.children.e0).kind == TensorExp::Kind::kSynZero) {
634-
v1 = genExp(env, rewriter, exp.children.e1, ldx);
621+
v1 = genExp(env, rewriter, exp.children.e1);
635622
v0 = constantZero(rewriter, loc, v1.getType());
636623
} else if (exp.children.e1 != ::mlir::sparse_tensor::detail::kInvalidId &&
637624
env.exp(exp.children.e1).kind == TensorExp::Kind::kSynZero) {
638-
v0 = genExp(env, rewriter, exp.children.e0, ldx);
625+
v0 = genExp(env, rewriter, exp.children.e0);
639626
v1 = constantZero(rewriter, loc, v0.getType());
640627
} else {
641-
v0 = genExp(env, rewriter, exp.children.e0, ldx);
642-
v1 = genExp(env, rewriter, exp.children.e1, ldx);
628+
v0 = genExp(env, rewriter, exp.children.e0);
629+
v1 = genExp(env, rewriter, exp.children.e1);
643630
}
644631

645632
Value ee;
@@ -653,7 +640,7 @@ static Value genExp(CodegenEnv &env, RewriterBase &rewriter, ExprId e,
653640
kind == TensorExp::Kind::kReduce ||
654641
kind == TensorExp::Kind::kSelect)) {
655642
OpBuilder::InsertionGuard guard(rewriter);
656-
ee = relinkBranch(env, rewriter, ee.getParentBlock(), ee, ldx);
643+
ee = relinkBranch(env, rewriter, ee.getParentBlock(), ee);
657644
}
658645
}
659646

@@ -806,7 +793,6 @@ static bool shouldTryParallize(CodegenEnv &env, LoopId ldx, bool isOuter,
806793
const auto lt = env.lt(env.unpackTensorLevel(tidLvl).first, ldx);
807794
return isCompressedLT(lt) || isSingletonLT(lt);
808795
});
809-
810796
return isParallelFor(env, isOuter, isSparse);
811797
}
812798

@@ -1112,11 +1098,6 @@ static bool translateBitsToTidLvlPairs(
11121098
// level. We need to generate the address according to the
11131099
// affine expression. This is also the best place we can do it
11141100
// to avoid putting it inside inner loops.
1115-
// NOTE: It assumes that the levels of the input tensor are
1116-
// initialized in order (and it is also currently guaranteed by
1117-
// computeIterationGraph), another more admissible approach
1118-
// might be accepting out-of-order access between consecutive
1119-
// dense levels.
11201101
affineTidLvls.emplace_back(env.makeTensorLevel(tid, l), exp);
11211102
}
11221103
}
@@ -1221,7 +1202,7 @@ static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp,
12211202
LoopOrd at) {
12221203
// At each leaf, assign remaining tensor (sub)expression to output tensor.
12231204
if (at == env.getLoopNum()) {
1224-
Value rhs = genExp(env, rewriter, exp, at - 1);
1205+
Value rhs = genExp(env, rewriter, exp);
12251206
genTensorStore(env, rewriter, exp, rhs);
12261207
return;
12271208
}
@@ -1235,8 +1216,7 @@ static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp,
12351216
bool needsUniv = startLoopSeq(env, rewriter, exp, at, ldx, lts);
12361217

12371218
// Emit a loop for every lattice point L0 >= Li in this loop sequence.
1238-
//
1239-
// NOTE: We cannot change this to `for (const LatPointId li : env.set(lts))`
1219+
// We cannot change this to `for (const LatPointId li : env.set(lts))`
12401220
// because the loop body causes data-movement which invalidates
12411221
// the iterator.
12421222
const unsigned lsize = env.set(lts).size();
@@ -1251,7 +1231,7 @@ static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp,
12511231
Value cntInput = env.getExpandCount();
12521232
Value insInput = env.getInsertionChain();
12531233
Value validIns = env.getValidLexInsert();
1254-
// NOTE: We cannot change this to `for (const LatPointId lj : env.set(lts))`
1234+
// We cannot change this to `for (const LatPointId lj : env.set(lts))`
12551235
// because the loop body causes data-movement which invalidates the
12561236
// iterator.
12571237
for (unsigned j = 0; j < lsize; j++) {
@@ -1323,6 +1303,7 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
13231303
if (hasNonTrivialAffineOnSparseOut(op))
13241304
return failure();
13251305

1306+
// Only accept scheduled loops.
13261307
if (!op->hasAttr("sorted")) {
13271308
return rewriter.notifyMatchFailure(
13281309
op, "Loops not yet scheduled, try run --sparse-reinterpret-map "
@@ -1348,9 +1329,9 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
13481329
}
13491330
}
13501331

1351-
CodegenEnv env(op, options, numTensors, numLoops, maxLvlRank);
13521332
// Detects sparse annotations and translates the per-level sparsity
13531333
// information for all tensors to loop indices in the kernel.
1334+
CodegenEnv env(op, options, numTensors, numLoops, maxLvlRank);
13541335
if (!findSparseAnnotations(env, needIdxRed))
13551336
return failure();
13561337

0 commit comments

Comments (0)