@@ -842,11 +842,13 @@ static bool shouldTryParallize(CodegenEnv &env, LoopId curr,
842
842
// / one sparse level in the list.
843
843
static Operation *genCoIteration (CodegenEnv &env, OpBuilder &builder,
844
844
ArrayRef<TensorLevel> tidLvls,
845
- bool tryParallel, bool needsUniv) {
845
+ unsigned numCases, bool tryParallel,
846
+ bool needsUniv) {
846
847
Operation *loop = *env.genLoopBoundary ([&](MutableArrayRef<Value> reduc) {
847
848
// Construct while-loop with a parameter for each index.
848
849
return env.emitter ().enterCoIterationOverTensorsAtLvls (
849
- builder, env.op ().getLoc (), tidLvls, reduc, tryParallel, needsUniv);
850
+ builder, env.op ().getLoc (), tidLvls, numCases, reduc, tryParallel,
851
+ needsUniv);
850
852
});
851
853
assert (loop);
852
854
return loop;
@@ -855,9 +857,11 @@ static Operation *genCoIteration(CodegenEnv &env, OpBuilder &builder,
855
857
// / Generates a for-loop or a while-loop, depending on whether it implements
856
858
// / singleton iteration or co-iteration over the given conjunction. Whether to attempt parallelization is decided by shouldTryParallize; the loop itself is emitted by genCoIteration, whose result is returned.
857
859
static Operation *genLoop (CodegenEnv &env, OpBuilder &builder, LoopId curr,
858
- bool needsUniv, ArrayRef<TensorLevel> tidLvls) {
860
+ unsigned numCases, bool needsUniv,
861
+ ArrayRef<TensorLevel> tidLvls) {
859
862
bool tryParallel = shouldTryParallize (env, curr, tidLvls);
860
- return genCoIteration (env, builder, tidLvls, tryParallel, needsUniv);
863
+ return genCoIteration (env, builder, tidLvls, numCases, tryParallel,
864
+ needsUniv);
861
865
}
862
866
863
867
// / Generates the induction structure for a while-loop.
@@ -900,6 +904,26 @@ static void finalizeWhileOp(CodegenEnv &env, OpBuilder &builder,
900
904
// basic block where scf::Yield should be inserted.
901
905
}
902
906
907
+ // / Generates the case region with index `caseIdx` in the coiterate operation: asserts that `curCase` equals `allCase` or that latGT(allCase, curCase) holds, builds the case bitmask, and enters the case through the loop emitter.
908
+ static void genCoIterationCase (CodegenEnv &env, OpBuilder &builder,
909
+ unsigned caseIdx, LatPointId allCase,
910
+ LatPointId curCase,
911
+ MutableArrayRef<Value> reduc) {
912
+ assert (allCase == curCase || env.merger ().latGT (allCase, curCase));
913
+ const BitVector &allCaseBits = env.merger ().lat (allCase).simple ;
914
+ const BitVector &curCaseBits = env.merger ().lat (curCase).simple ;
915
+
916
+ // / Computes the subset of iterators that are valid in the current case being
917
+ // / generated: bit `idx` of `caseBit` is set when the idx-th set bit of `allCaseBits` is also set in `curCaseBits`.
918
+ I64BitSet caseBit (0 );
919
+ for (auto [idx, set] : llvm::enumerate (allCaseBits.set_bits ()))
920
+ if (curCaseBits.test (set))
921
+ caseBit.set (idx);
922
+
923
+ env.emitter ().enterCurrentCoIterationCase (builder, env.op ().getLoc (), caseBit,
924
+ caseIdx, reduc);
925
+ }
926
+
903
927
// / Generates a single if-statement within a while-loop.
904
928
static scf::IfOp genIf (CodegenEnv &env, OpBuilder &builder, LoopId curr,
905
929
LatPointId p) {
@@ -1175,7 +1199,10 @@ static bool translateBitsToTidLvlPairs(
1175
1199
// / Starts a single loop in current sequence.
1176
1200
static std::pair<Operation *, bool > startLoop (CodegenEnv &env,
1177
1201
OpBuilder &builder, LoopId curr,
1178
- LatPointId li, bool needsUniv) {
1202
+ LatPointId li, unsigned numCases,
1203
+ bool needsUniv) {
1204
+ // TODO: numCases is only used when generating iterator-based loops. Clean up
1205
+ // after full migration.
1179
1206
// The set of tensors + lvls to generate loops on
1180
1207
SmallVector<TensorLevel> tidLvls;
1181
1208
@@ -1186,7 +1213,7 @@ static std::pair<Operation *, bool> startLoop(CodegenEnv &env,
1186
1213
translateBitsToTidLvlPairs (env, li, curr, tidLvls, affineTidLvls);
1187
1214
1188
1215
// Emit the for/while-loop control.
1189
- Operation *loop = genLoop (env, builder, curr, needsUniv, tidLvls);
1216
+ Operation *loop = genLoop (env, builder, curr, numCases, needsUniv, tidLvls);
1190
1217
Location loc = env.op ().getLoc ();
1191
1218
for (auto [tidLvl, exp] : affineTidLvls) {
1192
1219
env.emitter ().locateLvlAtAffineAddress (builder, loc, tidLvl, exp);
@@ -1259,42 +1286,73 @@ static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp,
1259
1286
// Start a loop sequence.
1260
1287
bool needsUniv = startLoopSeq (env, rewriter, exp, curr, lts);
1261
1288
1262
- // Emit a loop for every lattice point L0 >= Li in this loop sequence.
1263
- // We cannot change this to `for (const LatPointId li : env.set(lts))`
1264
- // because the loop body causes data-movement which invalidates
1265
- // the iterator.
1289
+ // When using sparse-iterator-based loops, we only need one loop, as
1290
+ // opposed to a loop sequence, to cover all the iterator spaces.
1266
1291
const unsigned lsize = env.set (lts).size ();
1267
- for (unsigned i = 0 ; i < lsize; i++) {
1268
- const LatPointId li = env.set (lts)[i];
1269
- // Start a loop.
1270
- auto [loop, isSingleCond] = startLoop (env, rewriter, curr, li, needsUniv);
1271
-
1272
- // Visit all lattices points with Li >= Lj to generate the
1273
- // loop-body, possibly with if statements for coiteration.
1274
- Value redInput = env.getReduc ();
1275
- Value cntInput = env.getExpandCount ();
1276
- Value insInput = env.getInsertionChain ();
1277
- Value validIns = env.getValidLexInsert ();
1278
- // We cannot change this to `for (const LatPointId lj : env.set(lts))`
1279
- // because the loop body causes data-movement which invalidates the
1280
- // iterator.
1292
+ if (env.generatingSparseIterator ()) {
1293
+ // Get the largest lattice point and start a loop.
1294
+ const LatPointId li = env.set (lts)[0 ];
1295
+ auto [loop, isSingleCond] =
1296
+ startLoop (env, rewriter, curr, li, lsize, needsUniv);
1297
+ assert (isSingleCond == llvm::isa<IterateOp>(loop));
1298
+ // We cannot change this to `for (const LatPointId lj : env.set(lts))`
1299
+ // because the loop body causes data-movement which invalidates
1300
+ // the iterator.
1281
1301
for (unsigned j = 0 ; j < lsize; j++) {
1282
1302
const LatPointId lj = env.set (lts)[j];
1283
1303
const ExprId ej = env.lat (lj).exp ;
1284
- if (li == lj || env.merger ().latGT (li, lj)) {
1285
- // Recurse into body of each branch.
1286
- if (!isSingleCond) {
1287
- scf::IfOp ifOp = genIf (env, rewriter, curr, lj);
1288
- genStmt (env, rewriter, ej, curr + 1 );
1289
- endIf (env, rewriter, ifOp, redInput, cntInput, insInput, validIns);
1290
- } else {
1304
+ // Recurse into body of each branch.
1305
+ if (!isSingleCond) {
1306
+ env.genLoopBoundary ([&, curr, j, li, lj](MutableArrayRef<Value> reduc) {
1307
+ genCoIterationCase (env, rewriter, /* caseIdx*/ j, li, lj, reduc);
1291
1308
genStmt (env, rewriter, ej, curr + 1 );
1292
- }
1309
+ // TODO: handle yield values.
1310
+ assert (reduc.empty () && " Not Implemented" );
1311
+ rewriter.create <sparse_tensor::YieldOp>(env.op ().getLoc ());
1312
+ return std::nullopt;
1313
+ });
1314
+ // endIf(env, rewriter, ifOp, redInput, cntInput, insInput, validIns);
1315
+ } else {
1316
+ genStmt (env, rewriter, ej, curr + 1 );
1293
1317
}
1294
1318
}
1295
-
1296
1319
// End a loop.
1297
1320
needsUniv = endLoop (env, rewriter, loop, curr, needsUniv, isSingleCond);
1321
+ } else {
1322
+ // Emit a loop for every lattice point L0 >= Li in this loop sequence.
1323
+ for (unsigned i = 0 ; i < lsize; i++) {
1324
+ const LatPointId li = env.set (lts)[i];
1325
+ // Start a loop.
1326
+ auto [loop, isSingleCond] =
1327
+ startLoop (env, rewriter, curr, li, lsize, needsUniv);
1328
+
1329
+ // Visit all lattice points with Li >= Lj to generate the
1330
+ // loop-body, possibly with if statements for coiteration.
1331
+ Value redInput = env.getReduc ();
1332
+ Value cntInput = env.getExpandCount ();
1333
+ Value insInput = env.getInsertionChain ();
1334
+ Value validIns = env.getValidLexInsert ();
1335
+ // We cannot change this to `for (const LatPointId lj : env.set(lts))`
1336
+ // because the loop body causes data-movement which invalidates the
1337
+ // iterator.
1338
+ for (unsigned j = 0 ; j < lsize; j++) {
1339
+ const LatPointId lj = env.set (lts)[j];
1340
+ const ExprId ej = env.lat (lj).exp ;
1341
+ if (li == lj || env.merger ().latGT (li, lj)) {
1342
+ // Recurse into body of each branch.
1343
+ if (!isSingleCond) {
1344
+ scf::IfOp ifOp = genIf (env, rewriter, curr, lj);
1345
+ genStmt (env, rewriter, ej, curr + 1 );
1346
+ endIf (env, rewriter, ifOp, redInput, cntInput, insInput, validIns);
1347
+ } else {
1348
+ genStmt (env, rewriter, ej, curr + 1 );
1349
+ }
1350
+ }
1351
+ }
1352
+
1353
+ // End a loop.
1354
+ needsUniv = endLoop (env, rewriter, loop, curr, needsUniv, isSingleCond);
1355
+ }
1298
1356
}
1299
1357
1300
1358
// End a loop sequence.
0 commit comments