Skip to content

Commit 42102ed

Browse files
author
Peiming Liu
committed
[mlir][sparse] partially support lowering sparse coiteration loops to scf.while/for.
1 parent dd90c72 commit 42102ed

File tree

9 files changed

+549
-91
lines changed

9 files changed

+549
-91
lines changed

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -96,24 +96,32 @@ class I64BitSet {
9696
return *this;
9797
}
9898

99+
bool isSubSetOf(const I64BitSet p) const {
100+
I64BitSet tmp = *this;
101+
tmp |= p;
102+
return tmp == p;
103+
}
104+
99105
// Needed by `llvm::const_set_bits_iterator_impl`.
100106
int find_first() const { return min(); }
101107
int find_next(unsigned prev) const {
102-
if (prev >= max())
108+
if (prev >= max() - 1)
103109
return -1;
104110

105-
uint64_t b = storage >> (prev + 1);
106-
if (b == 0)
107-
return -1;
111+
uint64_t b = storage >> (prev + 1ULL);
112+
assert(b != 0);
108113

109-
return llvm::countr_zero(b) + prev + 1;
114+
return llvm::countr_zero(b) + prev + 1ULL;
110115
}
111116

112117
bool operator[](unsigned i) const {
113118
assert(i < 64);
114-
return (storage & (1 << i)) != 0;
119+
return (storage & (static_cast<int64_t>(1) << i)) != 0;
120+
}
121+
unsigned min() const {
122+
unsigned m = llvm::countr_zero(storage);
123+
return m == 64 ? -1 : m;
115124
}
116-
unsigned min() const { return llvm::countr_zero(storage); }
117125
unsigned max() const { return 64 - llvm::countl_zero(storage); }
118126
unsigned count() const { return llvm::popcount(storage); }
119127
bool empty() const { return storage == 0; }

mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1787,6 +1787,10 @@ def SparseTensor_CoIterateOp : SparseTensor_Op<"coiterate",
17871787
.take_back(getRegionDefinedSpace(regionIdx).count());
17881788
}
17891789
ValueRange getYieldedValues(unsigned regionIdx);
1790+
1791+
// Returns a vector of regions that are the `sub-cases` of the given case region.
1792+
// E.g., `case %it1, _, %it3` is a subcase of `case %it1, %it2, %it3`.
1793+
SmallVector<Region *> getSubCasesOf(unsigned regionIdx);
17901794
}];
17911795

17921796
let hasVerifier = 1;

mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2745,6 +2745,16 @@ LogicalResult CoIterateOp::verifyRegions() {
27452745
return success();
27462746
}
27472747

2748+
SmallVector<Region *> CoIterateOp::getSubCasesOf(unsigned regionIdx) {
2749+
SmallVector<Region *> ret;
2750+
I64BitSet caseBit = getRegionDefinedSpace(regionIdx);
2751+
for (Region &r : getCaseRegions())
2752+
if (getRegionDefinedSpace(r.getRegionNumber()).isSubSetOf(caseBit))
2753+
ret.push_back(&r);
2754+
2755+
return ret;
2756+
}
2757+
27482758
//===----------------------------------------------------------------------===//
27492759
// Sparse Tensor Dialect Setups.
27502760
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp

Lines changed: 290 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11

22
#include "Utils/CodegenUtils.h"
3+
#include "Utils/LoopEmitter.h"
34
#include "Utils/SparseTensorIterator.h"
45

56
#include "mlir/Dialect/MemRef/IR/MemRef.h"
@@ -49,6 +50,144 @@ convertIteratorType(IteratorType itTp, SmallVectorImpl<Type> &fields) {
4950
return success();
5051
}
5152

53+
static ValueRange
54+
genCoIterateBranchNest(PatternRewriter &rewriter, Location loc, CoIterateOp op,
55+
Value loopCrd,
56+
ArrayRef<std::unique_ptr<SparseIterator>> iters,
57+
ArrayRef<Region *> subCases, ArrayRef<Value> userReduc) {
58+
if (subCases.empty())
59+
return userReduc;
60+
61+
// The current branch that we are handling.
62+
Region *b = subCases.front();
63+
Value casePred = constantI1(rewriter, loc, true);
64+
I64BitSet caseBits = op.getRegionDefinedSpace(b->getRegionNumber());
65+
for (unsigned i : caseBits.bits()) {
66+
SparseIterator *it = iters[i].get();
67+
Value pred = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
68+
it->getCrd(), loopCrd);
69+
casePred = rewriter.create<arith::AndIOp>(loc, casePred, pred);
70+
}
71+
scf::IfOp ifOp = rewriter.create<scf::IfOp>(
72+
loc, ValueRange(userReduc).getTypes(), casePred, /*else=*/true);
73+
rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
74+
75+
// Erase the empty block.
76+
rewriter.eraseBlock(&ifOp.getThenRegion().front());
77+
// Set up block arguments: user-provided values -> loop coord -> iterators.
78+
SmallVector<Value> blockArgs(userReduc);
79+
blockArgs.push_back(loopCrd);
80+
for (unsigned idx : caseBits.bits())
81+
llvm::append_range(blockArgs, iters[idx]->getCursor());
82+
83+
IRMapping mapping;
84+
for (auto [from, to] :
85+
llvm::zip_equal(b->front().getArguments(), blockArgs)) {
86+
mapping.map(from, to);
87+
}
88+
89+
// Clone the region; we cannot erase the region now because the same region
90+
// might be a subcase for multiple lattice points.
91+
rewriter.cloneRegionBefore(*b, ifOp.getThenRegion(),
92+
ifOp.getThenRegion().begin(), mapping);
93+
94+
// replace sparse_tensor::YieldOp -> scf::YieldOp
95+
auto spY = cast<sparse_tensor::YieldOp>(&ifOp.getThenRegion().front().back());
96+
ValueRange yields = spY.getResults();
97+
rewriter.eraseOp(spY);
98+
rewriter.setInsertionPointToEnd(&ifOp.getThenRegion().front());
99+
rewriter.create<scf::YieldOp>(loc, yields);
100+
101+
// Generates remaining case recursively.
102+
rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
103+
ValueRange res = genCoIterateBranchNest(rewriter, loc, op, loopCrd, iters,
104+
subCases.drop_front(), userReduc);
105+
if (!res.empty())
106+
rewriter.create<scf::YieldOp>(loc, res);
107+
108+
rewriter.setInsertionPointAfter(ifOp);
109+
return ifOp.getResults();
110+
}
111+
112+
static ValueRange genLoopWithIterator(
113+
PatternRewriter &rewriter, Location loc, SparseIterator *it,
114+
ValueRange reduc, bool iterFirst,
115+
function_ref<SmallVector<Value>(PatternRewriter &rewriter, Location loc,
116+
Region &loopBody, SparseIterator *it,
117+
ValueRange reduc)>
118+
bodyBuilder) {
119+
if (it->iteratableByFor()) {
120+
auto [lo, hi] = it->genForCond(rewriter, loc);
121+
Value step = constantIndex(rewriter, loc, 1);
122+
scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, lo, hi, step, reduc);
123+
{
124+
OpBuilder::InsertionGuard guard(rewriter);
125+
// Erase the implicit yield operation created by ForOp when there are no
126+
// yielded values.
127+
if (!forOp.getBody()->empty())
128+
rewriter.eraseOp(&forOp.getBody()->front());
129+
assert(forOp.getBody()->empty());
130+
131+
it->linkNewScope(forOp.getInductionVar());
132+
rewriter.setInsertionPointToStart(forOp.getBody());
133+
SmallVector<Value> ret = bodyBuilder(rewriter, loc, forOp.getBodyRegion(),
134+
it, forOp.getRegionIterArgs());
135+
136+
rewriter.setInsertionPointToEnd(forOp.getBody());
137+
rewriter.create<scf::YieldOp>(loc, ret);
138+
}
139+
return forOp.getResults();
140+
}
141+
SmallVector<Value> ivs;
142+
// TODO: always put iterator SSA values at the end of argument list to be
143+
// consistent with coiterate operation.
144+
if (!iterFirst)
145+
llvm::append_range(ivs, it->getCursor());
146+
// Appends the user-provided values.
147+
llvm::append_range(ivs, reduc);
148+
if (iterFirst)
149+
llvm::append_range(ivs, it->getCursor());
150+
151+
TypeRange types = ValueRange(ivs).getTypes();
152+
auto whileOp = rewriter.create<scf::WhileOp>(loc, types, ivs);
153+
{
154+
OpBuilder::InsertionGuard guard(rewriter);
155+
// Generates loop conditions.
156+
SmallVector<Location> l(types.size(), loc);
157+
Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, types, l);
158+
rewriter.setInsertionPointToStart(before);
159+
ValueRange bArgs = before->getArguments();
160+
auto [whileCond, remArgs] = it->genWhileCond(rewriter, loc, bArgs);
161+
rewriter.create<scf::ConditionOp>(loc, whileCond, before->getArguments());
162+
163+
// Delegates loop body generation.
164+
Region &dstRegion = whileOp.getAfter();
165+
Block *after = rewriter.createBlock(&dstRegion, {}, types, l);
166+
ValueRange aArgs = whileOp.getAfterArguments();
167+
if (iterFirst) {
168+
aArgs = it->linkNewScope(aArgs);
169+
} else {
170+
aArgs = aArgs.take_front(reduc.size());
171+
it->linkNewScope(aArgs.drop_front(reduc.size()));
172+
}
173+
174+
rewriter.setInsertionPointToStart(after);
175+
SmallVector<Value> ret = bodyBuilder(rewriter, loc, dstRegion, it, aArgs);
176+
rewriter.setInsertionPointToEnd(after);
177+
178+
// Forward loops
179+
SmallVector<Value> yields;
180+
ValueRange nx = it->forward(rewriter, loc);
181+
if (iterFirst)
182+
llvm::append_range(yields, nx);
183+
llvm::append_range(yields, ret);
184+
if (!iterFirst)
185+
llvm::append_range(yields, nx);
186+
rewriter.create<scf::YieldOp>(loc, yields);
187+
}
188+
return whileOp.getResults().drop_front(it->getCursor().size());
189+
}
190+
52191
namespace {
53192

54193
/// Sparse codegen rule for number of entries operator.
@@ -136,6 +275,8 @@ class SparseIterateOpConverter : public OneToNOpConversionPattern<IterateOp> {
136275
rewriter.replaceOp(op, forOp.getResults(), resultMapping);
137276
} else {
138277
SmallVector<Value> ivs;
278+
// TODO: put iterator at the end of argument list to be consistent with
279+
// coiterate operation.
139280
llvm::append_range(ivs, it->getCursor());
140281
for (ValueRange inits : adaptor.getInitArgs())
141282
llvm::append_range(ivs, inits);
@@ -189,6 +330,153 @@ class SparseIterateOpConverter : public OneToNOpConversionPattern<IterateOp> {
189330
}
190331
};
191332

333+
class SparseCoIterateOpConverter
334+
: public OneToNOpConversionPattern<CoIterateOp> {
335+
using OneToNOpConversionPattern::OneToNOpConversionPattern;
336+
337+
LogicalResult
338+
matchAndRewrite(CoIterateOp op, OpAdaptor adaptor,
339+
OneToNPatternRewriter &rewriter) const override {
340+
assert(op.getSpaceDim() == 1 && "Not implemented");
341+
Location loc = op.getLoc();
342+
343+
I64BitSet denseBits(0);
344+
for (auto [idx, spaceTp] : llvm::enumerate(op.getIterSpaces().getTypes()))
345+
if (all_of(cast<IterSpaceType>(spaceTp).getLvlTypes(), isDenseLT))
346+
denseBits.set(idx);
347+
348+
// If there exists a case that only contains dense spaces (i.e., the case
349+
// bits are a subset of the dense bits), or when there is a fully empty case
350+
// (due to complements), we need a universal pointer to forward the
351+
// coiteration loop.
352+
bool needUniv =
353+
any_of(op.getRegionDefinedSpaces(), [denseBits](I64BitSet caseBits) {
354+
// A case for complement.
355+
if (caseBits.count() == 0)
356+
return true;
357+
// An all-dense case.
358+
return caseBits.isSubSetOf(denseBits);
359+
});
360+
assert(!needUniv && "Not implemented");
361+
(void)needUniv;
362+
363+
for (Region &region : op.getCaseRegions()) {
364+
// Do a one-shot type conversion on all region blocks, since the same
365+
// region might be used multiple times.
366+
Block *block = &region.getBlocks().front();
367+
OneToNTypeMapping blockTypeMapping(block->getArgumentTypes());
368+
if (failed(typeConverter->convertSignatureArgs(block->getArgumentTypes(),
369+
blockTypeMapping)))
370+
return rewriter.notifyMatchFailure(
371+
op, "failed to convert coiterate region argument types");
372+
373+
rewriter.applySignatureConversion(block, blockTypeMapping);
374+
}
375+
376+
SmallVector<SparseIterationSpace> spaces;
377+
SmallVector<std::unique_ptr<SparseIterator>> iters;
378+
for (auto [spaceTp, spaceVals] : llvm::zip_equal(
379+
op.getIterSpaces().getTypes(), adaptor.getIterSpaces())) {
380+
// TODO: do we really need tid?
381+
spaces.push_back(SparseIterationSpace::fromValues(
382+
cast<IterSpaceType>(spaceTp), spaceVals, /*tid=*/0));
383+
// Extract the iterator.
384+
iters.push_back(spaces.back().extractIterator(rewriter, loc));
385+
}
386+
387+
auto getFilteredIters = [&iters](I64BitSet caseBits) {
388+
// Retrieves a vector of pointers to the iterators used in the case.
389+
SmallVector<SparseIterator *> validIters;
390+
for (auto idx : caseBits.bits())
391+
validIters.push_back(iters[idx].get());
392+
return validIters;
393+
};
394+
395+
// Get a flattened user-provided loop reduction values.
396+
SmallVector<Value> userReduc;
397+
for (ValueRange r : adaptor.getInitArgs())
398+
llvm::append_range(userReduc, r);
399+
400+
// TODO: we need to sort the cases such that they appear in lexical order.
401+
// Although sparsification always generates cases in that order, it might
402+
// not be the case for human-written code.
403+
404+
// Generates a loop sequence, one loop per case.
405+
for (auto [r, caseBits] :
406+
llvm::zip_equal(op.getCaseRegions(), op.getRegionDefinedSpaces())) {
407+
assert(caseBits.count() > 0 && "Complement space not implemented");
408+
409+
// Retrieves a vector of pointers to the iterators used in the case.
410+
SmallVector<SparseIterator *> validIters = getFilteredIters(caseBits);
411+
412+
if (validIters.size() > 1) {
413+
auto [loop, loopCrd] =
414+
genCoIteration(rewriter, loc, validIters, userReduc,
415+
/*uniIdx=*/nullptr, /*userReducFirst=*/true);
416+
417+
// 1st. find all the cases that are a strict subset of the current case
418+
// condition, for which we generate one branch per case inside the loop.
419+
// The subcases are never empty; they must contain at least the current
420+
// region itself.
421+
// TODO: these cases should be sorted.
422+
SmallVector<Region *> subCases = op.getSubCasesOf(r.getRegionNumber());
423+
assert(!subCases.empty());
424+
425+
ValueRange res = genCoIterateBranchNest(rewriter, loc, op, loopCrd,
426+
iters, subCases, userReduc);
427+
428+
SmallVector<Value> nextIterYields(res);
429+
// 2nd. forward the loop.
430+
for (SparseIterator *it : validIters) {
431+
Value cmp = rewriter.create<arith::CmpIOp>(
432+
loc, arith::CmpIPredicate::eq, it->getCrd(), loopCrd);
433+
it->forwardIf(rewriter, loc, cmp);
434+
llvm::append_range(nextIterYields, it->getCursor());
435+
}
436+
rewriter.create<scf::YieldOp>(loc, nextIterYields);
437+
438+
// Exit the loop, relink the iterator SSA value.
439+
rewriter.setInsertionPointAfter(loop);
440+
ValueRange iterVals = loop->getResults().drop_front(userReduc.size());
441+
for (SparseIterator *it : validIters)
442+
iterVals = it->linkNewScope(iterVals);
443+
assert(iterVals.empty());
444+
445+
ValueRange curResult = loop->getResults().take_front(userReduc.size());
446+
userReduc.assign(curResult.begin(), curResult.end());
447+
} else {
448+
// This is a simple iteration loop.
449+
assert(caseBits.count() == 1);
450+
451+
Block *block = &r.getBlocks().front();
452+
ValueRange curResult = genLoopWithIterator(
453+
rewriter, loc, validIters.front(), userReduc, /*iterFirst=*/false,
454+
/*bodyBuilder=*/
455+
[block](PatternRewriter &rewriter, Location loc, Region &dstRegion,
456+
SparseIterator *it,
457+
ValueRange reduc) -> SmallVector<Value> {
458+
SmallVector<Value> blockArgs(reduc);
459+
blockArgs.push_back(it->deref(rewriter, loc));
460+
llvm::append_range(blockArgs, it->getCursor());
461+
462+
Block *dstBlock = &dstRegion.getBlocks().front();
463+
rewriter.inlineBlockBefore(
464+
block, dstBlock, rewriter.getInsertionPoint(), blockArgs);
465+
auto yield = llvm::cast<sparse_tensor::YieldOp>(dstBlock->back());
466+
SmallVector<Value> result(yield.getResults());
467+
rewriter.eraseOp(yield);
468+
return result;
469+
});
470+
471+
userReduc.assign(curResult.begin(), curResult.end());
472+
}
473+
}
474+
475+
rewriter.replaceOp(op, userReduc);
476+
return success();
477+
}
478+
};
479+
192480
} // namespace
193481

194482
mlir::SparseIterationTypeConverter::SparseIterationTypeConverter() {
@@ -210,5 +498,6 @@ void mlir::populateLowerSparseIterationToSCFPatterns(
210498

211499
IterateOp::getCanonicalizationPatterns(patterns, patterns.getContext());
212500
patterns.add<ExtractIterSpaceConverter, ExtractValOpConverter,
213-
SparseIterateOpConverter>(converter, patterns.getContext());
501+
SparseIterateOpConverter, SparseCoIterateOpConverter>(
502+
converter, patterns.getContext());
214503
}

0 commit comments

Comments
 (0)