Skip to content

Commit dbe4d04

Browse files
author
Peiming Liu
committed
[mlir][sparse] support sparsification to coiterate operations.
1 parent 876ee11 commit dbe4d04

File tree

8 files changed

+283
-62
lines changed

8 files changed

+283
-62
lines changed

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1749,6 +1749,10 @@ def SparseTensor_CoIterateOp : SparseTensor_Op<"coiterate",
17491749
let results = (outs Variadic<AnyType>:$results);
17501750
let regions = (region VariadicRegion<SizedRegion<1>>:$caseRegions);
17511751

1752+
let builders = [
1753+
OpBuilder<(ins "ValueRange":$iterSpace, "ValueRange":$initArgs, "unsigned":$numCases)>,
1754+
];
1755+
17521756
let extraClassDeclaration = [{
17531757
unsigned getSpaceDim() {
17541758
return llvm::cast<::mlir::sparse_tensor::IterSpaceType>(

mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2594,6 +2594,22 @@ void IterateOp::getSuccessorRegions(RegionBranchPoint point,
25942594
regions.push_back(RegionSuccessor(getResults()));
25952595
}
25962596

2597+
void CoIterateOp::build(OpBuilder &builder, OperationState &odsState,
2598+
ValueRange iterSpaces, ValueRange initArgs,
2599+
unsigned numCases) {
2600+
unsigned rank =
2601+
cast<IterSpaceType>(iterSpaces.front().getType()).getSpaceDim();
2602+
// All ones.
2603+
I64BitSet set((1 << rank) - 1);
2604+
// Placeholder case bits. We need to preallocate all the regions, as a Region cannot
2605+
// be dynamically added later after the operation is created.
2606+
SmallVector<int64_t> caseBits(numCases, 0);
2607+
ArrayAttr cases = builder.getI64ArrayAttr(caseBits);
2608+
return CoIterateOp::build(builder, odsState, initArgs.getTypes(), iterSpaces,
2609+
initArgs, set, cases,
2610+
/*caseRegionsCount=*/numCases);
2611+
}
2612+
25972613
ParseResult CoIterateOp::parse(OpAsmParser &parser, OperationState &result) {
25982614

25992615
SmallVector<Value> spaces;

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1395,7 +1395,7 @@ struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
13951395
loopEmitter.enterNewLoopSeq(rewriter, loc, tidLvls);
13961396
// Note that reduc will be taken care of by loop emitter and get updated
13971397
// in place.
1398-
loopEmitter.enterCoIterationOverTensorsAtLvls(rewriter, loc, tidLvls,
1398+
loopEmitter.enterCoIterationOverTensorsAtLvls(rewriter, loc, tidLvls, 1,
13991399
reduc);
14001400
}
14011401

mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp

Lines changed: 91 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -842,11 +842,13 @@ static bool shouldTryParallize(CodegenEnv &env, LoopId curr,
842842
/// one sparse level in the list.
843843
static Operation *genCoIteration(CodegenEnv &env, OpBuilder &builder,
844844
ArrayRef<TensorLevel> tidLvls,
845-
bool tryParallel, bool needsUniv) {
845+
unsigned numCases, bool tryParallel,
846+
bool needsUniv) {
846847
Operation *loop = *env.genLoopBoundary([&](MutableArrayRef<Value> reduc) {
847848
// Construct while-loop with a parameter for each index.
848849
return env.emitter().enterCoIterationOverTensorsAtLvls(
849-
builder, env.op().getLoc(), tidLvls, reduc, tryParallel, needsUniv);
850+
builder, env.op().getLoc(), tidLvls, numCases, reduc, tryParallel,
851+
needsUniv);
850852
});
851853
assert(loop);
852854
return loop;
@@ -855,9 +857,11 @@ static Operation *genCoIteration(CodegenEnv &env, OpBuilder &builder,
855857
/// Generates a for-loop or a while-loop, depending on whether it implements
856858
/// singleton iteration or co-iteration over the given conjunction.
857859
static Operation *genLoop(CodegenEnv &env, OpBuilder &builder, LoopId curr,
858-
bool needsUniv, ArrayRef<TensorLevel> tidLvls) {
860+
unsigned numCases, bool needsUniv,
861+
ArrayRef<TensorLevel> tidLvls) {
859862
bool tryParallel = shouldTryParallize(env, curr, tidLvls);
860-
return genCoIteration(env, builder, tidLvls, tryParallel, needsUniv);
863+
return genCoIteration(env, builder, tidLvls, numCases, tryParallel,
864+
needsUniv);
861865
}
862866

863867
/// Generates the induction structure for a while-loop.
@@ -900,6 +904,26 @@ static void finalizeWhileOp(CodegenEnv &env, OpBuilder &builder,
900904
// basic block where scf::Yield should be inserted.
901905
}
902906

907+
/// Generate a case region in the coiterate operation.
908+
static void genCoIterationCase(CodegenEnv &env, OpBuilder &builder,
909+
unsigned caseIdx, LatPointId allCase,
910+
LatPointId curCase,
911+
MutableArrayRef<Value> reduc) {
912+
assert(allCase == curCase || env.merger().latGT(allCase, curCase));
913+
const BitVector &allCaseBits = env.merger().lat(allCase).simple;
914+
const BitVector &curCaseBits = env.merger().lat(curCase).simple;
915+
916+
/// Computes the subset of iterators that are valid in the current case being
917+
/// generated.
918+
I64BitSet caseBit(0);
919+
for (auto [idx, set] : llvm::enumerate(allCaseBits.set_bits()))
920+
if (curCaseBits.test(set))
921+
caseBit.set(idx);
922+
923+
env.emitter().enterCurCoIterationCase(builder, env.op().getLoc(), caseBit,
924+
caseIdx, reduc);
925+
}
926+
903927
/// Generates a single if-statement within a while-loop.
904928
static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, LoopId curr,
905929
LatPointId p) {
@@ -1175,7 +1199,10 @@ static bool translateBitsToTidLvlPairs(
11751199
/// Starts a single loop in current sequence.
11761200
static std::pair<Operation *, bool> startLoop(CodegenEnv &env,
11771201
OpBuilder &builder, LoopId curr,
1178-
LatPointId li, bool needsUniv) {
1202+
LatPointId li, unsigned numCases,
1203+
bool needsUniv) {
1204+
// TODO: numCases is only used when generating iterator-based loops. Clean up
1205+
// after full migration.
11791206
// The set of tensors + lvls to generate loops on
11801207
SmallVector<TensorLevel> tidLvls;
11811208

@@ -1186,7 +1213,7 @@ static std::pair<Operation *, bool> startLoop(CodegenEnv &env,
11861213
translateBitsToTidLvlPairs(env, li, curr, tidLvls, affineTidLvls);
11871214

11881215
// Emit the for/while-loop control.
1189-
Operation *loop = genLoop(env, builder, curr, needsUniv, tidLvls);
1216+
Operation *loop = genLoop(env, builder, curr, numCases, needsUniv, tidLvls);
11901217
Location loc = env.op().getLoc();
11911218
for (auto [tidLvl, exp] : affineTidLvls) {
11921219
env.emitter().locateLvlAtAffineAddress(builder, loc, tidLvl, exp);
@@ -1259,42 +1286,73 @@ static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp,
12591286
// Start a loop sequence.
12601287
bool needsUniv = startLoopSeq(env, rewriter, exp, curr, lts);
12611288

1262-
// Emit a loop for every lattice point L0 >= Li in this loop sequence.
1263-
// We cannot change this to `for (const LatPointId li : env.set(lts))`
1264-
// because the loop body causes data-movement which invalidates
1265-
// the iterator.
1289+
// When using sparse-iterator-based loops, we only need one loop, as
1290+
// opposed to a loop sequence, to cover all the iterator spaces.
12661291
const unsigned lsize = env.set(lts).size();
1267-
for (unsigned i = 0; i < lsize; i++) {
1268-
const LatPointId li = env.set(lts)[i];
1269-
// Start a loop.
1270-
auto [loop, isSingleCond] = startLoop(env, rewriter, curr, li, needsUniv);
1271-
1272-
// Visit all lattices points with Li >= Lj to generate the
1273-
// loop-body, possibly with if statements for coiteration.
1274-
Value redInput = env.getReduc();
1275-
Value cntInput = env.getExpandCount();
1276-
Value insInput = env.getInsertionChain();
1277-
Value validIns = env.getValidLexInsert();
1278-
// We cannot change this to `for (const LatPointId lj : env.set(lts))`
1279-
// because the loop body causes data-movement which invalidates the
1280-
// iterator.
1292+
if (env.generatingSparseIterator()) {
1293+
// Get the largest lattice point and start a loop.
1294+
const LatPointId li = env.set(lts)[0];
1295+
auto [loop, isSingleCond] =
1296+
startLoop(env, rewriter, curr, li, lsize, needsUniv);
1297+
assert(isSingleCond == llvm::isa<IterateOp>(loop));
1298+
// We cannot change this to `for (const LatPointId li : env.set(lts))`
1299+
// because the loop body causes data-movement which invalidates
1300+
// the iterator.
12811301
for (unsigned j = 0; j < lsize; j++) {
12821302
const LatPointId lj = env.set(lts)[j];
12831303
const ExprId ej = env.lat(lj).exp;
1284-
if (li == lj || env.merger().latGT(li, lj)) {
1285-
// Recurse into body of each branch.
1286-
if (!isSingleCond) {
1287-
scf::IfOp ifOp = genIf(env, rewriter, curr, lj);
1288-
genStmt(env, rewriter, ej, curr + 1);
1289-
endIf(env, rewriter, ifOp, redInput, cntInput, insInput, validIns);
1290-
} else {
1304+
// Recurse into body of each branch.
1305+
if (!isSingleCond) {
1306+
env.genLoopBoundary([&, curr, j, li, lj](MutableArrayRef<Value> reduc) {
1307+
genCoIterationCase(env, rewriter, /*caseIdx*/ j, li, lj, reduc);
12911308
genStmt(env, rewriter, ej, curr + 1);
1292-
}
1309+
// TODO: handle yield values.
1310+
assert(reduc.empty() && "Not Implemented");
1311+
rewriter.create<sparse_tensor::YieldOp>(env.op().getLoc());
1312+
return std::nullopt;
1313+
});
1314+
// endIf(env, rewriter, ifOp, redInput, cntInput, insInput, validIns);
1315+
} else {
1316+
genStmt(env, rewriter, ej, curr + 1);
12931317
}
12941318
}
1295-
12961319
// End a loop.
12971320
needsUniv = endLoop(env, rewriter, loop, curr, needsUniv, isSingleCond);
1321+
} else {
1322+
// Emit a loop for every lattice point L0 >= Li in this loop sequence.
1323+
for (unsigned i = 0; i < lsize; i++) {
1324+
const LatPointId li = env.set(lts)[i];
1325+
// Start a loop.
1326+
auto [loop, isSingleCond] =
1327+
startLoop(env, rewriter, curr, li, lsize, needsUniv);
1328+
1329+
// Visit all lattices points with Li >= Lj to generate the
1330+
// loop-body, possibly with if statements for coiteration.
1331+
Value redInput = env.getReduc();
1332+
Value cntInput = env.getExpandCount();
1333+
Value insInput = env.getInsertionChain();
1334+
Value validIns = env.getValidLexInsert();
1335+
// We cannot change this to `for (const LatPointId lj : env.set(lts))`
1336+
// because the loop body causes data-movement which invalidates the
1337+
// iterator.
1338+
for (unsigned j = 0; j < lsize; j++) {
1339+
const LatPointId lj = env.set(lts)[j];
1340+
const ExprId ej = env.lat(lj).exp;
1341+
if (li == lj || env.merger().latGT(li, lj)) {
1342+
// Recurse into body of each branch.
1343+
if (!isSingleCond) {
1344+
scf::IfOp ifOp = genIf(env, rewriter, curr, lj);
1345+
genStmt(env, rewriter, ej, curr + 1);
1346+
endIf(env, rewriter, ifOp, redInput, cntInput, insInput, validIns);
1347+
} else {
1348+
genStmt(env, rewriter, ej, curr + 1);
1349+
}
1350+
}
1351+
}
1352+
1353+
// End a loop.
1354+
needsUniv = endLoop(env, rewriter, loop, curr, needsUniv, isSingleCond);
1355+
}
12981356
}
12991357

13001358
// End a loop sequence.

mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,10 @@ class CodegenEnv {
4949

5050
linalg::GenericOp op() const { return linalgOp; }
5151
const SparsificationOptions &options() const { return sparseOptions; }
52+
bool generatingSparseIterator() const {
53+
return sparseOptions.sparseEmitStrategy ==
54+
SparseEmitStrategy::kSparseIterator;
55+
}
5256
Merger &merger() { return latticeMerger; }
5357
LoopEmitter &emitter() { return loopEmitter; }
5458

mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp

Lines changed: 105 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -615,33 +615,104 @@ bool LoopEmitter::shouldIteratedByForLoop(ArrayRef<SparseIterator *> spIters) {
615615
return true;
616616
}
617617

618+
Region *LoopEmitter::enterCurCoIterationCase(OpBuilder &builder, Location loc,
619+
I64BitSet caseBit,
620+
unsigned caseIdx,
621+
MutableArrayRef<Value> reduc) {
622+
auto coIterOp = cast<CoIterateOp>(loopStack.back().loop);
623+
SmallVector<Attribute> cases(coIterOp.getCases().getAsRange<Attribute>());
624+
cases[caseIdx] = builder.getI64IntegerAttr(caseBit);
625+
626+
coIterOp.setCasesAttr(builder.getArrayAttr(cases));
627+
Region &caseRegion = coIterOp.getRegion(caseIdx);
628+
assert(caseRegion.getBlocks().empty() &&
629+
"re-initialize the same coiteration case region.");
630+
631+
// Each block starts with a list of used coordinates of index type.
632+
SmallVector<Type> blockArgTps(coIterOp.getCrdUsedLvls().count(),
633+
builder.getIndexType());
634+
// Followed by a list of user-provided iteration arguments.
635+
TypeRange iterArgsTps = coIterOp.getInitArgs().getTypes();
636+
blockArgTps.append(iterArgsTps.begin(), iterArgsTps.end());
637+
// Ends with a set of iterators that defines the actual iteration space.
638+
for (auto i : caseBit.bits()) {
639+
blockArgTps.push_back(
640+
cast<IterSpaceType>(coIterOp.getIterSpaces()[i].getType())
641+
.getIteratorType());
642+
}
643+
SmallVector<Location> locs(blockArgTps.size(), loc);
644+
caseRegion.emplaceBlock().addArguments(blockArgTps, locs);
645+
646+
// Entering the new region scope, updating the SSA chain.
647+
builder.setInsertionPointToStart(&caseRegion.front());
648+
// Update the coordinates.
649+
loopStack.back().iv = coIterOp.getCrds(caseIdx).front();
650+
// Updates loop iteration arguments.
651+
ValueRange iterArgs = coIterOp.getRegionIterArgs(caseIdx);
652+
llvm::copy(iterArgs, reduc.begin());
653+
// Updates sparse iterator values.
654+
ValueRange iters = coIterOp.getRegionIterators(caseIdx);
655+
ArrayRef<TensorLevel> tidLvls = loopStack.back().tidLvls;
656+
for (auto [i, tl] : llvm::enumerate(unpackTensorLevelRange(tidLvls))) {
657+
if (caseBit[i]) {
658+
spIterVals[tl.first][tl.second] = iters.front();
659+
iters = iters.drop_front();
660+
} else {
661+
spIterVals[tl.first][tl.second] = nullptr;
662+
}
663+
}
664+
// Must have consumed all iterator SSA values.
665+
assert(iters.empty());
666+
return &caseRegion;
667+
}
668+
618669
Operation *LoopEmitter::enterCoIterationOverTensorsAtLvls(
619670
OpBuilder &builder, Location loc, ArrayRef<TensorLevel> tidLvls,
620-
MutableArrayRef<Value> reduc, bool tryParallel, bool needsUniv) {
621-
671+
unsigned numCases, MutableArrayRef<Value> reduc, bool tryParallel,
672+
bool needsUniv) {
673+
// TODO: Argument `numCases` is only used when generating iterator-based sparse
674+
// loops. Simplify the code once the feature is complete.
622675
// TODO: handle coiteration with sparse iterator.
623676
if (emitStrategy == SparseEmitStrategy::kSparseIterator) {
624-
assert(tidLvls.size() == 1);
625-
auto [tid, lvl] = unpackTensorLevel(tidLvls.front());
626-
Value t = tensors[tid];
627-
628-
// Extract and iterate over the iteration space.
629-
ExtractIterSpaceOp extractSpaceOp =
630-
lvl == 0 ? builder.create<ExtractIterSpaceOp>(loc, t)
631-
: builder.create<ExtractIterSpaceOp>(
632-
loc, t, spIterVals[tid][lvl - 1], lvl);
633-
634-
IterateOp iterOp = builder.create<IterateOp>(
635-
loc, extractSpaceOp.getExtractedSpace(), reduc);
636-
spIterVals[tid][lvl] = iterOp.getIterator();
677+
if (tidLvls.size() == 1) {
678+
auto [tid, lvl] = unpackTensorLevel(tidLvls.front());
679+
Value t = tensors[tid];
680+
681+
// Extract and iterate over the iteration space.
682+
ExtractIterSpaceOp extractSpaceOp =
683+
lvl == 0 ? builder.create<ExtractIterSpaceOp>(loc, t)
684+
: builder.create<ExtractIterSpaceOp>(
685+
loc, t, spIterVals[tid][lvl - 1], lvl);
686+
687+
IterateOp iterOp = builder.create<IterateOp>(
688+
loc, extractSpaceOp.getExtractedSpace(), reduc);
689+
spIterVals[tid][lvl] = iterOp.getIterator();
690+
691+
// Update the reduction variables.
692+
llvm::copy(iterOp.getRegionIterArgs(), reduc.begin());
693+
// Set the insertion point to loop body.
694+
builder.setInsertionPointToStart(iterOp.getBody());
695+
loopStack.emplace_back(tidLvls, iterOp, builder.getInsertionBlock(),
696+
iterOp.getCrds().front(), loopTag);
697+
return iterOp;
698+
}
637699

638-
// Update the reduction varaibles.
639-
llvm::copy(iterOp.getRegionIterArgs(), reduc.begin());
640-
// Set the insertion point to loop body.
641-
builder.setInsertionPointToStart(iterOp.getBody());
642-
loopStack.emplace_back(tidLvls, iterOp, builder.getInsertionBlock(),
643-
iterOp.getIterator(), loopTag);
644-
return iterOp;
700+
// CoIteration Loops.
701+
SmallVector<Value> spaces;
702+
for (auto [tid, lvl] : unpackTensorLevelRange(tidLvls)) {
703+
Value t = tensors[tid];
704+
ExtractIterSpaceOp extractSpaceOp =
705+
lvl == 0 ? builder.create<ExtractIterSpaceOp>(loc, t)
706+
: builder.create<ExtractIterSpaceOp>(
707+
loc, t, spIterVals[tid][lvl - 1], lvl);
708+
spaces.push_back(extractSpaceOp.getExtractedSpace());
709+
}
710+
auto coIterOp = builder.create<CoIterateOp>(loc, spaces, reduc, numCases);
711+
// The CoIterationOp has neither an insertion block nor an induction variable.
712+
// TODO: the `struct LoopInfo` should be simplified after full migration.
713+
loopStack.emplace_back(tidLvls, coIterOp, /*insertion block*/ nullptr,
714+
/*induction variable*/ nullptr, loopTag);
715+
return coIterOp;
645716
}
646717

647718
// TODO: support multiple return on parallel for?
@@ -866,6 +937,18 @@ void LoopEmitter::exitCurrentLoop(RewriterBase &rewriter, Location loc,
866937
// Clean up the values; it helps us discover potential bugs at an
867938
// earlier stage (instead of silently using a wrong value).
868939
const LoopInfo &loopInfo = loopStack.back();
940+
if (emitStrategy == SparseEmitStrategy::kSparseIterator) {
941+
Operation *p = loopInfo.loop;
942+
if (isa<IterateOp>(p))
943+
rewriter.create<sparse_tensor::YieldOp>(loc, reduc);
944+
945+
// Exit the loop.
946+
rewriter.setInsertionPointAfter(p);
947+
// In-place update reduction variables.
948+
llvm::copy(p->getResults(), reduc.begin());
949+
loopStack.pop_back();
950+
return;
951+
}
869952

870953
// Sets the insertion point to the right position.
871954
rewriter.setInsertionPointToEnd(loopInfo.userCodeBlock);

0 commit comments

Comments
 (0)