Skip to content

Commit c442025

Browse files
author
Peiming Liu
authored
[mlir][sparse] support sparsification to coiterate operations. (#102546)
1 parent 3031840 commit c442025

File tree

8 files changed

+302
-73
lines changed

8 files changed

+302
-73
lines changed

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1749,6 +1749,10 @@ def SparseTensor_CoIterateOp : SparseTensor_Op<"coiterate",
17491749
let results = (outs Variadic<AnyType>:$results);
17501750
let regions = (region VariadicRegion<SizedRegion<1>>:$caseRegions);
17511751

1752+
let builders = [
1753+
OpBuilder<(ins "ValueRange":$iterSpace, "ValueRange":$initArgs, "unsigned":$numCases)>,
1754+
];
1755+
17521756
let extraClassDeclaration = [{
17531757
unsigned getSpaceDim() {
17541758
return llvm::cast<::mlir::sparse_tensor::IterSpaceType>(
@@ -1765,18 +1769,18 @@ def SparseTensor_CoIterateOp : SparseTensor_Op<"coiterate",
17651769
});
17661770
}
17671771

1768-
// The block arguments starts with referenced coordinates, follows by
1769-
// user-provided iteration arguments and ends with iterators.
1772+
// The block arguments starts with user-provided iteration arguments,
1773+
// follows by referenced coordinates and ends with iterators.
17701774
Block::BlockArgListType getCrds(unsigned regionIdx) {
17711775
return getRegion(regionIdx).getArguments()
1772-
.take_front(getCrdUsedLvls().count());
1776+
.slice(getNumRegionIterArgs(), getCrdUsedLvls().count());
17731777
}
1774-
unsigned getNumRegionIterArgs(unsigned regionIdx) {
1778+
unsigned getNumRegionIterArgs() {
17751779
return getInitArgs().size();
17761780
}
17771781
Block::BlockArgListType getRegionIterArgs(unsigned regionIdx) {
17781782
return getRegion(regionIdx).getArguments()
1779-
.slice(getCrdUsedLvls().count(), getNumRegionIterArgs(regionIdx));
1783+
.take_front(getNumRegionIterArgs());
17801784
}
17811785
Block::BlockArgListType getRegionIterators(unsigned regionIdx) {
17821786
return getRegion(regionIdx).getArguments()

mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2293,16 +2293,18 @@ parseSparseCoIterateLoop(OpAsmParser &parser, OperationState &state,
22932293
if (parser.parseOperandList(spaces, OpAsmParser::Delimiter::Paren))
22942294
return failure();
22952295

2296-
if (failed(parseUsedCoordList(parser, state, blockArgs)))
2296+
SmallVector<OpAsmParser::Argument> coords;
2297+
if (failed(parseUsedCoordList(parser, state, coords)))
22972298
return failure();
2298-
size_t numCrds = blockArgs.size();
2299+
size_t numCrds = coords.size();
22992300

23002301
// Parse "iter_args(%arg = %init, ...)"
23012302
SmallVector<OpAsmParser::UnresolvedOperand> initArgs;
23022303
bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
23032304
if (hasIterArgs)
23042305
if (parser.parseAssignmentList(blockArgs, initArgs))
23052306
return failure();
2307+
blockArgs.append(coords);
23062308

23072309
SmallVector<Type> iterSpaceTps;
23082310
// parse ": (sparse_tensor.iter_space, ...) -> ret"
@@ -2326,8 +2328,8 @@ parseSparseCoIterateLoop(OpAsmParser &parser, OperationState &state,
23262328
state.operands.append(spacesVals);
23272329

23282330
if (hasIterArgs) {
2329-
// Strip off leading args that used for coordinates.
2330-
MutableArrayRef args = MutableArrayRef(blockArgs).drop_front(numCrds);
2331+
// Strip off trailing args that used for coordinates.
2332+
MutableArrayRef args = MutableArrayRef(blockArgs).drop_back(numCrds);
23312333
if (args.size() != initArgs.size() || args.size() != state.types.size()) {
23322334
return parser.emitError(
23332335
parser.getNameLoc(),
@@ -2602,6 +2604,24 @@ void IterateOp::getSuccessorRegions(RegionBranchPoint point,
26022604
regions.push_back(RegionSuccessor(getResults()));
26032605
}
26042606

2607+
void CoIterateOp::build(OpBuilder &builder, OperationState &odsState,
2608+
ValueRange iterSpaces, ValueRange initArgs,
2609+
unsigned numCases) {
2610+
unsigned rank =
2611+
cast<IterSpaceType>(iterSpaces.front().getType()).getSpaceDim();
2612+
// All ones.
2613+
I64BitSet set((1 << rank) - 1);
2614+
// Generates all-zero case bits (they only serve as placeholders), which are
2615+
// supposed to be overriden later. We need to preallocate all the regions as
2616+
// mlir::Region cannot be dynamically added later after the operation is
2617+
// created.
2618+
SmallVector<int64_t> caseBits(numCases, 0);
2619+
ArrayAttr cases = builder.getI64ArrayAttr(caseBits);
2620+
return CoIterateOp::build(builder, odsState, initArgs.getTypes(), iterSpaces,
2621+
initArgs, set, cases,
2622+
/*caseRegionsCount=*/numCases);
2623+
}
2624+
26052625
ParseResult CoIterateOp::parse(OpAsmParser &parser, OperationState &result) {
26062626

26072627
SmallVector<Value> spaces;
@@ -2685,7 +2705,7 @@ ValueRange CoIterateOp::getYieldedValues(unsigned regionIdx) {
26852705

26862706
LogicalResult CoIterateOp::verifyRegions() {
26872707
for (unsigned r = 0, e = getNumRegions(); r < e; r++) {
2688-
if (getNumRegionIterArgs(r) != getNumResults())
2708+
if (getNumRegionIterArgs() != getNumResults())
26892709
return emitOpError(
26902710
"mismatch in number of basic block args and defined values");
26912711

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1395,7 +1395,7 @@ struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
13951395
loopEmitter.enterNewLoopSeq(rewriter, loc, tidLvls);
13961396
// Note that reduc will be taken care of by loop emitter and get updated
13971397
// in place.
1398-
loopEmitter.enterCoIterationOverTensorsAtLvls(rewriter, loc, tidLvls,
1398+
loopEmitter.enterCoIterationOverTensorsAtLvls(rewriter, loc, tidLvls, 1,
13991399
reduc);
14001400
}
14011401

mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp

Lines changed: 91 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -842,11 +842,13 @@ static bool shouldTryParallize(CodegenEnv &env, LoopId curr,
842842
/// one sparse level in the list.
843843
static Operation *genCoIteration(CodegenEnv &env, OpBuilder &builder,
844844
ArrayRef<TensorLevel> tidLvls,
845-
bool tryParallel, bool needsUniv) {
845+
unsigned numCases, bool tryParallel,
846+
bool needsUniv) {
846847
Operation *loop = *env.genLoopBoundary([&](MutableArrayRef<Value> reduc) {
847848
// Construct while-loop with a parameter for each index.
848849
return env.emitter().enterCoIterationOverTensorsAtLvls(
849-
builder, env.op().getLoc(), tidLvls, reduc, tryParallel, needsUniv);
850+
builder, env.op().getLoc(), tidLvls, numCases, reduc, tryParallel,
851+
needsUniv);
850852
});
851853
assert(loop);
852854
return loop;
@@ -855,9 +857,11 @@ static Operation *genCoIteration(CodegenEnv &env, OpBuilder &builder,
855857
/// Generates a for-loop or a while-loop, depending on whether it implements
856858
/// singleton iteration or co-iteration over the given conjunction.
857859
static Operation *genLoop(CodegenEnv &env, OpBuilder &builder, LoopId curr,
858-
bool needsUniv, ArrayRef<TensorLevel> tidLvls) {
860+
unsigned numCases, bool needsUniv,
861+
ArrayRef<TensorLevel> tidLvls) {
859862
bool tryParallel = shouldTryParallize(env, curr, tidLvls);
860-
return genCoIteration(env, builder, tidLvls, tryParallel, needsUniv);
863+
return genCoIteration(env, builder, tidLvls, numCases, tryParallel,
864+
needsUniv);
861865
}
862866

863867
/// Generates the induction structure for a while-loop.
@@ -900,6 +904,26 @@ static void finalizeWhileOp(CodegenEnv &env, OpBuilder &builder,
900904
// basic block where scf::Yield should be inserted.
901905
}
902906

907+
/// Generates a case region in the coiterate operation.
908+
static void genCoIterationCase(CodegenEnv &env, OpBuilder &builder,
909+
unsigned caseIdx, LatPointId allCase,
910+
LatPointId curCase,
911+
MutableArrayRef<Value> reduc) {
912+
assert(allCase == curCase || env.merger().latGT(allCase, curCase));
913+
const BitVector &allCaseBits = env.merger().lat(allCase).simple;
914+
const BitVector &curCaseBits = env.merger().lat(curCase).simple;
915+
916+
/// Computes the subset of iterators that are valid in the current case being
917+
/// generated.
918+
I64BitSet caseBit(0);
919+
for (auto [idx, set] : llvm::enumerate(allCaseBits.set_bits()))
920+
if (curCaseBits.test(set))
921+
caseBit.set(idx);
922+
923+
env.emitter().enterCurrentCoIterationCase(builder, env.op().getLoc(), caseBit,
924+
caseIdx, reduc);
925+
}
926+
903927
/// Generates a single if-statement within a while-loop.
904928
static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, LoopId curr,
905929
LatPointId p) {
@@ -1175,7 +1199,10 @@ static bool translateBitsToTidLvlPairs(
11751199
/// Starts a single loop in current sequence.
11761200
static std::pair<Operation *, bool> startLoop(CodegenEnv &env,
11771201
OpBuilder &builder, LoopId curr,
1178-
LatPointId li, bool needsUniv) {
1202+
LatPointId li, unsigned numCases,
1203+
bool needsUniv) {
1204+
// TODO: numCases only used when generating iterator-based loops. Cleanup
1205+
// after fully migration.
11791206
// The set of tensors + lvls to generate loops on
11801207
SmallVector<TensorLevel> tidLvls;
11811208

@@ -1186,7 +1213,7 @@ static std::pair<Operation *, bool> startLoop(CodegenEnv &env,
11861213
translateBitsToTidLvlPairs(env, li, curr, tidLvls, affineTidLvls);
11871214

11881215
// Emit the for/while-loop control.
1189-
Operation *loop = genLoop(env, builder, curr, needsUniv, tidLvls);
1216+
Operation *loop = genLoop(env, builder, curr, numCases, needsUniv, tidLvls);
11901217
Location loc = env.op().getLoc();
11911218
for (auto [tidLvl, exp] : affineTidLvls) {
11921219
env.emitter().locateLvlAtAffineAddress(builder, loc, tidLvl, exp);
@@ -1259,42 +1286,73 @@ static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp,
12591286
// Start a loop sequence.
12601287
bool needsUniv = startLoopSeq(env, rewriter, exp, curr, lts);
12611288

1262-
// Emit a loop for every lattice point L0 >= Li in this loop sequence.
1263-
// We cannot change this to `for (const LatPointId li : env.set(lts))`
1264-
// because the loop body causes data-movement which invalidates
1265-
// the iterator.
1289+
// When using sparse-iterator-based loops, we only need one loops, as
1290+
// opposed to a loop sequence, to cover all the iterator spaces.
12661291
const unsigned lsize = env.set(lts).size();
1267-
for (unsigned i = 0; i < lsize; i++) {
1268-
const LatPointId li = env.set(lts)[i];
1269-
// Start a loop.
1270-
auto [loop, isSingleCond] = startLoop(env, rewriter, curr, li, needsUniv);
1271-
1272-
// Visit all lattices points with Li >= Lj to generate the
1273-
// loop-body, possibly with if statements for coiteration.
1274-
Value redInput = env.getReduc();
1275-
Value cntInput = env.getExpandCount();
1276-
Value insInput = env.getInsertionChain();
1277-
Value validIns = env.getValidLexInsert();
1278-
// We cannot change this to `for (const LatPointId lj : env.set(lts))`
1279-
// because the loop body causes data-movement which invalidates the
1280-
// iterator.
1292+
if (env.generatingSparseIterator()) {
1293+
// Get the largest lattice point and start a loop.
1294+
const LatPointId li = env.set(lts)[0];
1295+
auto [loop, isSingleCond] =
1296+
startLoop(env, rewriter, curr, li, lsize, needsUniv);
1297+
assert(isSingleCond == llvm::isa<IterateOp>(loop));
1298+
// We cannot change this to `for (const LatPointId li : env.set(lts))`
1299+
// because the loop body causes data-movement which invalidates
1300+
// the iterator.
12811301
for (unsigned j = 0; j < lsize; j++) {
12821302
const LatPointId lj = env.set(lts)[j];
12831303
const ExprId ej = env.lat(lj).exp;
1284-
if (li == lj || env.merger().latGT(li, lj)) {
1285-
// Recurse into body of each branch.
1286-
if (!isSingleCond) {
1287-
scf::IfOp ifOp = genIf(env, rewriter, curr, lj);
1288-
genStmt(env, rewriter, ej, curr + 1);
1289-
endIf(env, rewriter, ifOp, redInput, cntInput, insInput, validIns);
1290-
} else {
1304+
// Recurse into body of each branch.
1305+
if (!isSingleCond) {
1306+
env.genLoopBoundary([&, curr, j, li, lj](MutableArrayRef<Value> reduc) {
1307+
genCoIterationCase(env, rewriter, /*caseIdx*/ j, li, lj, reduc);
12911308
genStmt(env, rewriter, ej, curr + 1);
1292-
}
1309+
// TODO: handle yield values.
1310+
assert(reduc.empty() && "Not Implemented");
1311+
rewriter.create<sparse_tensor::YieldOp>(env.op().getLoc());
1312+
return std::nullopt;
1313+
});
1314+
// endIf(env, rewriter, ifOp, redInput, cntInput, insInput, validIns);
1315+
} else {
1316+
genStmt(env, rewriter, ej, curr + 1);
12931317
}
12941318
}
1295-
12961319
// End a loop.
12971320
needsUniv = endLoop(env, rewriter, loop, curr, needsUniv, isSingleCond);
1321+
} else {
1322+
// Emit a loop for every lattice point L0 >= Li in this loop sequence.
1323+
for (unsigned i = 0; i < lsize; i++) {
1324+
const LatPointId li = env.set(lts)[i];
1325+
// Start a loop.
1326+
auto [loop, isSingleCond] =
1327+
startLoop(env, rewriter, curr, li, lsize, needsUniv);
1328+
1329+
// Visit all lattices points with Li >= Lj to generate the
1330+
// loop-body, possibly with if statements for coiteration.
1331+
Value redInput = env.getReduc();
1332+
Value cntInput = env.getExpandCount();
1333+
Value insInput = env.getInsertionChain();
1334+
Value validIns = env.getValidLexInsert();
1335+
// We cannot change this to `for (const LatPointId lj : env.set(lts))`
1336+
// because the loop body causes data-movement which invalidates the
1337+
// iterator.
1338+
for (unsigned j = 0; j < lsize; j++) {
1339+
const LatPointId lj = env.set(lts)[j];
1340+
const ExprId ej = env.lat(lj).exp;
1341+
if (li == lj || env.merger().latGT(li, lj)) {
1342+
// Recurse into body of each branch.
1343+
if (!isSingleCond) {
1344+
scf::IfOp ifOp = genIf(env, rewriter, curr, lj);
1345+
genStmt(env, rewriter, ej, curr + 1);
1346+
endIf(env, rewriter, ifOp, redInput, cntInput, insInput, validIns);
1347+
} else {
1348+
genStmt(env, rewriter, ej, curr + 1);
1349+
}
1350+
}
1351+
}
1352+
1353+
// End a loop.
1354+
needsUniv = endLoop(env, rewriter, loop, curr, needsUniv, isSingleCond);
1355+
}
12981356
}
12991357

13001358
// End a loop sequence.

mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,10 @@ class CodegenEnv {
4949

5050
linalg::GenericOp op() const { return linalgOp; }
5151
const SparsificationOptions &options() const { return sparseOptions; }
52+
bool generatingSparseIterator() const {
53+
return sparseOptions.sparseEmitStrategy ==
54+
SparseEmitStrategy::kSparseIterator;
55+
}
5256
Merger &merger() { return latticeMerger; }
5357
LoopEmitter &emitter() { return loopEmitter; }
5458

0 commit comments

Comments
 (0)