Skip to content

[mlir][sparse] support sparsification to coiterate operations. #102546

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1749,6 +1749,10 @@ def SparseTensor_CoIterateOp : SparseTensor_Op<"coiterate",
let results = (outs Variadic<AnyType>:$results);
let regions = (region VariadicRegion<SizedRegion<1>>:$caseRegions);

let builders = [
OpBuilder<(ins "ValueRange":$iterSpace, "ValueRange":$initArgs, "unsigned":$numCases)>,
];

let extraClassDeclaration = [{
unsigned getSpaceDim() {
return llvm::cast<::mlir::sparse_tensor::IterSpaceType>(
Expand All @@ -1765,18 +1769,18 @@ def SparseTensor_CoIterateOp : SparseTensor_Op<"coiterate",
});
}

// The block arguments starts with referenced coordinates, follows by
// user-provided iteration arguments and ends with iterators.
// The block arguments start with user-provided iteration arguments,
// followed by referenced coordinates and ending with iterators.
Block::BlockArgListType getCrds(unsigned regionIdx) {
return getRegion(regionIdx).getArguments()
.take_front(getCrdUsedLvls().count());
.slice(getNumRegionIterArgs(), getCrdUsedLvls().count());
}
unsigned getNumRegionIterArgs(unsigned regionIdx) {
unsigned getNumRegionIterArgs() {
return getInitArgs().size();
}
Block::BlockArgListType getRegionIterArgs(unsigned regionIdx) {
return getRegion(regionIdx).getArguments()
.slice(getCrdUsedLvls().count(), getNumRegionIterArgs(regionIdx));
.take_front(getNumRegionIterArgs());
}
Block::BlockArgListType getRegionIterators(unsigned regionIdx) {
return getRegion(regionIdx).getArguments()
Expand Down
30 changes: 25 additions & 5 deletions mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2293,16 +2293,18 @@ parseSparseCoIterateLoop(OpAsmParser &parser, OperationState &state,
if (parser.parseOperandList(spaces, OpAsmParser::Delimiter::Paren))
return failure();

if (failed(parseUsedCoordList(parser, state, blockArgs)))
SmallVector<OpAsmParser::Argument> coords;
if (failed(parseUsedCoordList(parser, state, coords)))
return failure();
size_t numCrds = blockArgs.size();
size_t numCrds = coords.size();

// Parse "iter_args(%arg = %init, ...)"
SmallVector<OpAsmParser::UnresolvedOperand> initArgs;
bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
if (hasIterArgs)
if (parser.parseAssignmentList(blockArgs, initArgs))
return failure();
blockArgs.append(coords);

SmallVector<Type> iterSpaceTps;
// parse ": (sparse_tensor.iter_space, ...) -> ret"
Expand All @@ -2326,8 +2328,8 @@ parseSparseCoIterateLoop(OpAsmParser &parser, OperationState &state,
state.operands.append(spacesVals);

if (hasIterArgs) {
// Strip off leading args that used for coordinates.
MutableArrayRef args = MutableArrayRef(blockArgs).drop_front(numCrds);
// Strip off trailing args that are used for coordinates.
MutableArrayRef args = MutableArrayRef(blockArgs).drop_back(numCrds);
if (args.size() != initArgs.size() || args.size() != state.types.size()) {
return parser.emitError(
parser.getNameLoc(),
Expand Down Expand Up @@ -2602,6 +2604,24 @@ void IterateOp::getSuccessorRegions(RegionBranchPoint point,
regions.push_back(RegionSuccessor(getResults()));
}

/// Convenience builder for a `coiterate` operation: iterates over all the
/// given iteration spaces, carrying `initArgs` as loop-carried values, with
/// `numCases` pre-allocated (placeholder) case regions.
void CoIterateOp::build(OpBuilder &builder, OperationState &odsState,
                        ValueRange iterSpaces, ValueRange initArgs,
                        unsigned numCases) {
  assert(!iterSpaces.empty() && "expected at least one iteration space");
  unsigned rank =
      cast<IterSpaceType>(iterSpaces.front().getType()).getSpaceDim();
  // All ones: every coordinate in the space is referenced. Use a 64-bit
  // unsigned shift so ranks >= 31 do not overflow a signed int.
  assert(rank < 64 && "space dimension out of range");
  I64BitSet set((uint64_t(1) << rank) - 1);
  // Generates all-zero case bits (they only serve as placeholders), which are
  // supposed to be overridden later. We need to preallocate all the regions as
  // mlir::Region cannot be dynamically added later after the operation is
  // created.
  SmallVector<int64_t> caseBits(numCases, 0);
  ArrayAttr cases = builder.getI64ArrayAttr(caseBits);
  return CoIterateOp::build(builder, odsState, initArgs.getTypes(), iterSpaces,
                            initArgs, set, cases,
                            /*caseRegionsCount=*/numCases);
}

ParseResult CoIterateOp::parse(OpAsmParser &parser, OperationState &result) {

SmallVector<Value> spaces;
Expand Down Expand Up @@ -2685,7 +2705,7 @@ ValueRange CoIterateOp::getYieldedValues(unsigned regionIdx) {

LogicalResult CoIterateOp::verifyRegions() {
for (unsigned r = 0, e = getNumRegions(); r < e; r++) {
if (getNumRegionIterArgs(r) != getNumResults())
if (getNumRegionIterArgs() != getNumResults())
return emitOpError(
"mismatch in number of basic block args and defined values");

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1395,7 +1395,7 @@ struct ForeachRewriter : public OpRewritePattern<ForeachOp> {
loopEmitter.enterNewLoopSeq(rewriter, loc, tidLvls);
// Note that reduc will be taken care of by loop emitter and get updated
// in place.
loopEmitter.enterCoIterationOverTensorsAtLvls(rewriter, loc, tidLvls,
loopEmitter.enterCoIterationOverTensorsAtLvls(rewriter, loc, tidLvls, 1,
reduc);
}

Expand Down
124 changes: 91 additions & 33 deletions mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -842,11 +842,13 @@ static bool shouldTryParallize(CodegenEnv &env, LoopId curr,
/// one sparse level in the list.
static Operation *genCoIteration(CodegenEnv &env, OpBuilder &builder,
ArrayRef<TensorLevel> tidLvls,
bool tryParallel, bool needsUniv) {
unsigned numCases, bool tryParallel,
bool needsUniv) {
Operation *loop = *env.genLoopBoundary([&](MutableArrayRef<Value> reduc) {
// Construct while-loop with a parameter for each index.
return env.emitter().enterCoIterationOverTensorsAtLvls(
builder, env.op().getLoc(), tidLvls, reduc, tryParallel, needsUniv);
builder, env.op().getLoc(), tidLvls, numCases, reduc, tryParallel,
needsUniv);
});
assert(loop);
return loop;
Expand All @@ -855,9 +857,11 @@ static Operation *genCoIteration(CodegenEnv &env, OpBuilder &builder,
/// Generates a for-loop or a while-loop, depending on whether it implements
/// singleton iteration or co-iteration over the given conjunction.
static Operation *genLoop(CodegenEnv &env, OpBuilder &builder, LoopId curr,
bool needsUniv, ArrayRef<TensorLevel> tidLvls) {
unsigned numCases, bool needsUniv,
ArrayRef<TensorLevel> tidLvls) {
bool tryParallel = shouldTryParallize(env, curr, tidLvls);
return genCoIteration(env, builder, tidLvls, tryParallel, needsUniv);
return genCoIteration(env, builder, tidLvls, numCases, tryParallel,
needsUniv);
}

/// Generates the induction structure for a while-loop.
Expand Down Expand Up @@ -900,6 +904,26 @@ static void finalizeWhileOp(CodegenEnv &env, OpBuilder &builder,
// basic block where scf::Yield should be inserted.
}

/// Generates a case region in the coiterate operation.
static void genCoIterationCase(CodegenEnv &env, OpBuilder &builder,
                               unsigned caseIdx, LatPointId allCase,
                               LatPointId curCase,
                               MutableArrayRef<Value> reduc) {
  // The current case must be covered by the all-inclusive case.
  assert(allCase == curCase || env.merger().latGT(allCase, curCase));
  const BitVector &allBits = env.merger().lat(allCase).simple;
  const BitVector &curBits = env.merger().lat(curCase).simple;

  // Build the mask of iterators valid in this case: the i-th set bit of the
  // all-inclusive lattice point is included iff the same bit is also set in
  // the current lattice point.
  I64BitSet validIters(0);
  unsigned iterIdx = 0;
  for (unsigned bit : allBits.set_bits()) {
    if (curBits.test(bit))
      validIters.set(iterIdx);
    ++iterIdx;
  }

  env.emitter().enterCurrentCoIterationCase(builder, env.op().getLoc(),
                                            validIters, caseIdx, reduc);
}

/// Generates a single if-statement within a while-loop.
static scf::IfOp genIf(CodegenEnv &env, OpBuilder &builder, LoopId curr,
LatPointId p) {
Expand Down Expand Up @@ -1175,7 +1199,10 @@ static bool translateBitsToTidLvlPairs(
/// Starts a single loop in current sequence.
static std::pair<Operation *, bool> startLoop(CodegenEnv &env,
OpBuilder &builder, LoopId curr,
LatPointId li, bool needsUniv) {
LatPointId li, unsigned numCases,
bool needsUniv) {
// TODO: numCases is only used when generating iterator-based loops. Clean up
// after the migration is complete.
// The set of tensors + lvls to generate loops on
SmallVector<TensorLevel> tidLvls;

Expand All @@ -1186,7 +1213,7 @@ static std::pair<Operation *, bool> startLoop(CodegenEnv &env,
translateBitsToTidLvlPairs(env, li, curr, tidLvls, affineTidLvls);

// Emit the for/while-loop control.
Operation *loop = genLoop(env, builder, curr, needsUniv, tidLvls);
Operation *loop = genLoop(env, builder, curr, numCases, needsUniv, tidLvls);
Location loc = env.op().getLoc();
for (auto [tidLvl, exp] : affineTidLvls) {
env.emitter().locateLvlAtAffineAddress(builder, loc, tidLvl, exp);
Expand Down Expand Up @@ -1259,42 +1286,73 @@ static void genStmt(CodegenEnv &env, RewriterBase &rewriter, ExprId exp,
// Start a loop sequence.
bool needsUniv = startLoopSeq(env, rewriter, exp, curr, lts);

// Emit a loop for every lattice point L0 >= Li in this loop sequence.
// We cannot change this to `for (const LatPointId li : env.set(lts))`
// because the loop body causes data-movement which invalidates
// the iterator.
// When using sparse-iterator-based loops, we only need one loop, as
// opposed to a loop sequence, to cover all the iterator spaces.
const unsigned lsize = env.set(lts).size();
for (unsigned i = 0; i < lsize; i++) {
const LatPointId li = env.set(lts)[i];
// Start a loop.
auto [loop, isSingleCond] = startLoop(env, rewriter, curr, li, needsUniv);

// Visit all lattices points with Li >= Lj to generate the
// loop-body, possibly with if statements for coiteration.
Value redInput = env.getReduc();
Value cntInput = env.getExpandCount();
Value insInput = env.getInsertionChain();
Value validIns = env.getValidLexInsert();
// We cannot change this to `for (const LatPointId lj : env.set(lts))`
// because the loop body causes data-movement which invalidates the
// iterator.
if (env.generatingSparseIterator()) {
// Get the largest lattice point and start a loop.
const LatPointId li = env.set(lts)[0];
auto [loop, isSingleCond] =
startLoop(env, rewriter, curr, li, lsize, needsUniv);
assert(isSingleCond == llvm::isa<IterateOp>(loop));
// We cannot change this to `for (const LatPointId li : env.set(lts))`
// because the loop body causes data-movement which invalidates
// the iterator.
for (unsigned j = 0; j < lsize; j++) {
const LatPointId lj = env.set(lts)[j];
const ExprId ej = env.lat(lj).exp;
if (li == lj || env.merger().latGT(li, lj)) {
// Recurse into body of each branch.
if (!isSingleCond) {
scf::IfOp ifOp = genIf(env, rewriter, curr, lj);
genStmt(env, rewriter, ej, curr + 1);
endIf(env, rewriter, ifOp, redInput, cntInput, insInput, validIns);
} else {
// Recurse into body of each branch.
if (!isSingleCond) {
env.genLoopBoundary([&, curr, j, li, lj](MutableArrayRef<Value> reduc) {
genCoIterationCase(env, rewriter, /*caseIdx*/ j, li, lj, reduc);
genStmt(env, rewriter, ej, curr + 1);
}
// TODO: handle yield values.
assert(reduc.empty() && "Not Implemented");
rewriter.create<sparse_tensor::YieldOp>(env.op().getLoc());
return std::nullopt;
});
// endIf(env, rewriter, ifOp, redInput, cntInput, insInput, validIns);
} else {
genStmt(env, rewriter, ej, curr + 1);
}
}

// End a loop.
needsUniv = endLoop(env, rewriter, loop, curr, needsUniv, isSingleCond);
} else {
// Emit a loop for every lattice point L0 >= Li in this loop sequence.
for (unsigned i = 0; i < lsize; i++) {
const LatPointId li = env.set(lts)[i];
// Start a loop.
auto [loop, isSingleCond] =
startLoop(env, rewriter, curr, li, lsize, needsUniv);

// Visit all lattices points with Li >= Lj to generate the
// loop-body, possibly with if statements for coiteration.
Value redInput = env.getReduc();
Value cntInput = env.getExpandCount();
Value insInput = env.getInsertionChain();
Value validIns = env.getValidLexInsert();
// We cannot change this to `for (const LatPointId lj : env.set(lts))`
// because the loop body causes data-movement which invalidates the
// iterator.
for (unsigned j = 0; j < lsize; j++) {
const LatPointId lj = env.set(lts)[j];
const ExprId ej = env.lat(lj).exp;
if (li == lj || env.merger().latGT(li, lj)) {
// Recurse into body of each branch.
if (!isSingleCond) {
scf::IfOp ifOp = genIf(env, rewriter, curr, lj);
genStmt(env, rewriter, ej, curr + 1);
endIf(env, rewriter, ifOp, redInput, cntInput, insInput, validIns);
} else {
genStmt(env, rewriter, ej, curr + 1);
}
}
}

// End a loop.
needsUniv = endLoop(env, rewriter, loop, curr, needsUniv, isSingleCond);
}
}

// End a loop sequence.
Expand Down
4 changes: 4 additions & 0 deletions mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenEnv.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ class CodegenEnv {

linalg::GenericOp op() const { return linalgOp; }
const SparsificationOptions &options() const { return sparseOptions; }
bool generatingSparseIterator() const {
return sparseOptions.sparseEmitStrategy ==
SparseEmitStrategy::kSparseIterator;
}
Merger &merger() { return latticeMerger; }
LoopEmitter &emitter() { return loopEmitter; }

Expand Down
Loading
Loading