@@ -34,6 +34,7 @@
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/TensorEncoding.h"
 #include "llvm/ADT/SmallBitVector.h"
+
 #include <optional>
 
 using namespace mlir;
@@ -43,11 +44,6 @@ using namespace mlir::sparse_tensor;
 // Sparsifier analysis methods.
 //===----------------------------------------------------------------------===//
 
-// TODO: the "idx"-vs-"ldx" naming convention is not self-explanatory,
-// and those letters are too easy to confuse visually. We should switch
-// to a more self-explanatory naming convention like "curLoop"-vs-"prevLoop"
-// (assuming that's the actual meaning behind the "idx"-vs-"ldx" convention).
-
 /// Determines if affine expression is invariant.
 static bool isInvariantAffine(AffineExpr a, unsigned loopDepth, LoopId ldx,
                               bool &isAtLoop) {
@@ -56,11 +52,9 @@ static bool isInvariantAffine(AffineExpr a, unsigned loopDepth, LoopId ldx,
    const LoopId i = cast<AffineDimExpr>(a).getPosition();
    if (i == ldx) {
      isAtLoop = true;
-      // Must be invariant if we are at the given loop.
-      return true;
+      return true; // invariant at given loop
    }
-    // The DimExpr is invariant the loop has already been generated.
-    return i < loopDepth;
+    return i < loopDepth; // invariant when already generated
  }
  case AffineExprKind::Add:
  case AffineExprKind::Mul: {
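The hunk ends right at the binary cases. For readers following along, the Add/Mul handling presumably recurses into both operands, along these lines (a sketch inferred from the function's contract, not the file's verbatim body):

  case AffineExprKind::Add:
  case AffineExprKind::Mul: {
    // Sketch (assumed): a binary affine expression is invariant only if
    // both of its operands are invariant.
    auto binOp = cast<AffineBinaryOpExpr>(a);
    return isInvariantAffine(binOp.getLHS(), loopDepth, ldx, isAtLoop) &&
           isInvariantAffine(binOp.getRHS(), loopDepth, ldx, isAtLoop);
  }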
@@ -85,7 +79,6 @@ static bool findAffine(Merger &merger, TensorId tid, Level lvl, AffineExpr a,
    const LoopId idx = merger.makeLoopId(cast<AffineDimExpr>(a).getPosition());
    if (!isUndefLT(merger.getLvlType(tid, idx)))
      return false; // used more than once
-
    if (setLvlFormat)
      merger.setLevelAndType(tid, idx, lvl, lt);
    return true;
@@ -195,7 +188,7 @@ static bool findDepIdxSet(Merger &merger, TensorId tensor, Level lvl,
  }
 }
 
-/// Get the total number of compound affine expressions in the
+/// Gets the total number of compound affine expressions in the
 /// `getMatchingIndexingMap` for the given tensor. For the following inputs:
 ///
 /// map = (d0, d1, d2) => (d0 + d1 : compressed, d2 : compressed)
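To finish the example the hunk cuts off: for this map, `d0 + d1` is the only compound expression that sits on a compressed level (`d2` is a trivial dim expression), so the count would be 1. The counting loop could look roughly like this (an illustrative sketch; `stt` is assumed to be the operand's SparseTensorType, and the level-type query is an assumption, not necessarily the file's exact code):

  unsigned num = 0;
  for (Level l = 0, rank = map.getNumResults(); l < rank; l++)
    // Count results that are not plain dim expressions on non-dense levels.
    if (!isa<AffineDimExpr>(map.getResult(l)) && !stt.isDenseLvl(l)) // assumed API
      num++;
  return num;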
@@ -225,7 +218,7 @@ static unsigned getNumNonTrivialIdxExpOnSparseLvls(AffineMap map,
  return num;
 }
 
-/// Get the total number of sparse levels with compound affine
+/// Gets the total number of sparse levels with compound affine
 /// expressions, summed over all operands of the `GenericOp`.
 static unsigned getNumNonTrivialIdxExpOnSparseLvls(linalg::GenericOp op) {
  unsigned num = 0;
@@ -235,6 +228,7 @@ static unsigned getNumNonTrivialIdxExpOnSparseLvls(linalg::GenericOp op) {
  return num;
 }
 
+// Returns true iff output has nontrivial affine indices.
 static bool hasNonTrivialAffineOnSparseOut(linalg::GenericOp op) {
  OpOperand *out = op.getDpsInitOperand(0);
  if (getSparseTensorType(out->get()).isAllDense())
@@ -260,11 +254,9 @@ static bool findSparseAnnotations(CodegenEnv &env, bool idxReducBased) {
    const auto enc = getSparseTensorEncoding(t.get().getType());
    if (enc)
      annotated = true;
-
    const Level lvlRank = map.getNumResults();
    assert(!enc || lvlRank == enc.getLvlRank());
    assert(static_cast<Level>(env.op().getRank(&t)) == lvlRank);
-
    // We only need to do index reduction if there is at least one non-trivial
    // index expression on sparse levels.
    // If all non-trivial index expression is on dense levels, we can
@@ -343,9 +335,6 @@ static void genBuffers(CodegenEnv &env, OpBuilder &builder) {
 }
 
 /// Generates index for load/store on sparse tensor.
-// FIXME: It's not entirely clear what "index" means here (i.e., is it
-// a "coordinate", or "Ldx", or what). So the function should be renamed
-// and/or the documentation expanded in order to clarify.
 static Value genIndex(CodegenEnv &env, OpOperand *t) {
  const auto map = env.op().getMatchingIndexingMap(t);
  const auto stt = getSparseTensorType(t->get());
@@ -495,7 +484,6 @@ static Value genTensorLoad(CodegenEnv &env, OpBuilder &builder, ExprId exp) {
  Value val = env.exp(exp).val;
  if (val)
    return val;
-
  // Load during insertion.
  linalg::GenericOp op = env.op();
  OpOperand *t = &op->getOpOperand(env.exp(exp).tensor);
@@ -574,7 +562,7 @@ inline static Value genInvariantValue(CodegenEnv &env, ExprId exp) {
 /// exception of index computations, which need to be relinked to actual
 /// inlined cloned code.
 static Value relinkBranch(CodegenEnv &env, RewriterBase &rewriter, Block *block,
-                          Value e, LoopId ldx) {
+                          Value e) {
  if (auto arg = dyn_cast<BlockArgument>(e)) {
    // Direct arguments of the original linalg op must be converted
    // into dense tensor loads. Note that we should not encounter
@@ -598,7 +586,7 @@ static Value relinkBranch(CodegenEnv &env, RewriterBase &rewriter, Block *block,
    for (unsigned i = 0, n = def->getNumOperands(); i < n; i++) {
      rewriter.updateRootInPlace(def, [&]() {
        def->setOperand(
-            i, relinkBranch(env, rewriter, block, def->getOperand(i), ldx));
+            i, relinkBranch(env, rewriter, block, def->getOperand(i)));
      });
    }
  }
@@ -607,8 +595,7 @@ static Value relinkBranch(CodegenEnv &env, RewriterBase &rewriter, Block *block,
 }
 
 /// Recursively generates tensor expression.
-static Value genExp(CodegenEnv &env, RewriterBase &rewriter, ExprId e,
-                    LoopId ldx) {
+static Value genExp(CodegenEnv &env, RewriterBase &rewriter, ExprId e) {
  if (e == ::mlir::sparse_tensor::detail::kInvalidId)
    return Value();
 
@@ -631,15 +618,15 @@ static Value genExp(CodegenEnv &env, RewriterBase &rewriter, ExprId e,
    // based on the type of the other operand.
    if (exp.children.e0 != ::mlir::sparse_tensor::detail::kInvalidId &&
        env.exp(exp.children.e0).kind == TensorExp::Kind::kSynZero) {
-      v1 = genExp(env, rewriter, exp.children.e1, ldx);
+      v1 = genExp(env, rewriter, exp.children.e1);
      v0 = constantZero(rewriter, loc, v1.getType());
    } else if (exp.children.e1 != ::mlir::sparse_tensor::detail::kInvalidId &&
               env.exp(exp.children.e1).kind == TensorExp::Kind::kSynZero) {
-      v0 = genExp(env, rewriter, exp.children.e0, ldx);
+      v0 = genExp(env, rewriter, exp.children.e0);
      v1 = constantZero(rewriter, loc, v0.getType());
    } else {
-      v0 = genExp(env, rewriter, exp.children.e0, ldx);
-      v1 = genExp(env, rewriter, exp.children.e1, ldx);
+      v0 = genExp(env, rewriter, exp.children.e0);
+      v1 = genExp(env, rewriter, exp.children.e1);
    }
 
    Value ee;
@@ -653,7 +640,7 @@ static Value genExp(CodegenEnv &env, RewriterBase &rewriter, ExprId e,
                     kind == TensorExp::Kind::kReduce ||
                     kind == TensorExp::Kind::kSelect)) {
      OpBuilder::InsertionGuard guard(rewriter);
-      ee = relinkBranch(env, rewriter, ee.getParentBlock(), ee, ldx);
+      ee = relinkBranch(env, rewriter, ee.getParentBlock(), ee);
    }
  }
 
@@ -806,7 +793,6 @@ static bool shouldTryParallize(CodegenEnv &env, LoopId ldx, bool isOuter,
    const auto lt = env.lt(env.unpackTensorLevel(tidLvl).first, ldx);
    return isCompressedLT(lt) || isSingletonLT(lt);
  });
-
  return isParallelFor(env, isOuter, isSparse);
 }
 
@@ -1112,11 +1098,6 @@ static bool translateBitsToTidLvlPairs(
        // level. We need to generate the address according to the
        // affine expression. This is also the best place we can do it
        // to avoid putting it inside inner loops.
-        // NOTE: It assumes that the levels of the input tensor are
-        // initialized in order (and it is also currently guaranteed by
-        // computeIterationGraph), another more admissible approach
-        // might be accepting out-of-order access between consecutive
-        // dense levels.
        affineTidLvls.emplace_back(env.makeTensorLevel(tid, l), exp);
      }
    }
@@ -1221,7 +1202,7 @@ static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp,
                    LoopOrd at) {
  // At each leaf, assign remaining tensor (sub)expression to output tensor.
  if (at == env.getLoopNum()) {
-    Value rhs = genExp(env, rewriter, exp, at - 1);
+    Value rhs = genExp(env, rewriter, exp);
    genTensorStore(env, rewriter, exp, rhs);
    return;
  }
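Read together with the hunks below, the overall recursion has this rough shape (a pseudocode sketch, not the file's verbatim code):

  // genStmt(env, rewriter, exp, at):
  //   if (at == env.getLoopNum())            // leaf: all loops emitted
  //     genTensorStore(env, rewriter, exp, genExp(env, rewriter, exp));
  //   else
  //     startLoopSeq(...); for each lattice point: emit the loop,
  //     recurse with at + 1, close the loop; end the loop sequence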
@@ -1235,8 +1216,7 @@ static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp,
  bool needsUniv = startLoopSeq(env, rewriter, exp, at, ldx, lts);
 
  // Emit a loop for every lattice point L0 >= Li in this loop sequence.
-  //
-  // NOTE: We cannot change this to `for (const LatPointId li : env.set(lts))`
+  // We cannot change this to `for (const LatPointId li : env.set(lts))`
  // because the loop body causes data-movement which invalidates
  // the iterator.
  const unsigned lsize = env.set(lts).size();
@@ -1251,7 +1231,7 @@ static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp,
    Value cntInput = env.getExpandCount();
    Value insInput = env.getInsertionChain();
    Value validIns = env.getValidLexInsert();
-    // NOTE: We cannot change this to `for (const LatPointId lj : env.set(lts))`
+    // We cannot change this to `for (const LatPointId lj : env.set(lts))`
    // because the loop body causes data-movement which invalidates the
    // iterator.
    for (unsigned j = 0; j < lsize; j++) {
@@ -1323,6 +1303,7 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
    if (hasNonTrivialAffineOnSparseOut(op))
      return failure();
 
+    // Only accept scheduled loops.
    if (!op->hasAttr("sorted")) {
      return rewriter.notifyMatchFailure(
          op, "Loops not yet scheduled, try run --sparse-reinterpret-map "
@@ -1348,9 +1329,9 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
      }
    }
 
-    CodegenEnv env(op, options, numTensors, numLoops, maxLvlRank);
    // Detects sparse annotations and translates the per-level sparsity
    // information for all tensors to loop indices in the kernel.
+    CodegenEnv env(op, options, numTensors, numLoops, maxLvlRank);
    if (!findSparseAnnotations(env, needIdxRed))
      return failure();
 
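As the `sorted` check above implies, kernels must be loop-scheduled before this pattern applies; with the standard mlir-opt pipeline that would presumably be invoked along these lines (flag spelling taken from the failure message above; treat the exact pipeline as an assumption):

  mlir-opt input.mlir --sparse-reinterpret-map --sparsification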