[mlir][sparse] partially support lowering sparse coiteration loops to scf.while/for. #105565
Conversation
@llvm/pr-subscribers-mlir

Author: Peiming Liu (PeimingLiu)

Changes

Stacked PRs:

[mlir][sparse] partially support lowering sparse coiteration loops to scf.while/for.

Patch is 35.78 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/105565.diff

9 Files Affected:
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
index 388efd1c454b1e..915a0cd8d92973 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -96,24 +96,32 @@ class I64BitSet {
return *this;
}
+ bool isSubSetOf(const I64BitSet p) const {
+ I64BitSet tmp = *this;
+ tmp |= p;
+ return tmp == p;
+ }
+
// Needed by `llvm::const_set_bits_iterator_impl`.
int find_first() const { return min(); }
int find_next(unsigned prev) const {
- if (prev >= max())
+ if (prev >= max() - 1)
return -1;
- uint64_t b = storage >> (prev + 1);
- if (b == 0)
- return -1;
+ uint64_t b = storage >> (prev + 1ULL);
+ assert(b != 0);
- return llvm::countr_zero(b) + prev + 1;
+ return llvm::countr_zero(b) + prev + 1ULL;
}
bool operator[](unsigned i) const {
assert(i < 64);
- return (storage & (1 << i)) != 0;
+ return (storage & (static_cast<int64_t>(1) << i)) != 0;
+ }
+ unsigned min() const {
+ unsigned m = llvm::countr_zero(storage);
+ return m == 64 ? -1 : m;
}
- unsigned min() const { return llvm::countr_zero(storage); }
unsigned max() const { return 64 - llvm::countl_zero(storage); }
unsigned count() const { return llvm::popcount(storage); }
bool empty() const { return storage == 0; }
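
For intuition about the bitset changes above, here is a minimal standalone sketch (a hypothetical TinyBitSet, not the real header) of the semantics the patch establishes: shifts are performed in 64 bits so testing bit 31 and above is well-defined, and the subset check is simply `(a | b) == b`.

#include <cassert>
#include <cstdint>

// Hypothetical mirror of the patched I64BitSet behavior.
struct TinyBitSet {
  uint64_t storage = 0;
  // 64-bit shift: `uint64_t(1) << i` is well-defined for i in [0, 63],
  // unlike the old `1 << i`, which overflows int for i >= 31.
  void set(unsigned i) { storage |= (uint64_t(1) << i); }
  bool test(unsigned i) const { return (storage & (uint64_t(1) << i)) != 0; }
  // `this` is a subset of `p` iff or-ing in p's bits yields exactly p.
  bool isSubSetOf(TinyBitSet p) const { return (storage | p.storage) == p.storage; }
};

int main() {
  TinyBitSet a, b;
  a.set(0); a.set(62);            // bit 62 requires the 64-bit shift
  b.set(0); b.set(1); b.set(62);
  assert(a.test(62) && a.isSubSetOf(b) && !b.isSubSetOf(a));
}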
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 2803223354d5ee..20512f972e67cd 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -1787,6 +1787,10 @@ def SparseTensor_CoIterateOp : SparseTensor_Op<"coiterate",
.take_back(getRegionDefinedSpace(regionIdx).count());
}
ValueRange getYieldedValues(unsigned regionIdx);
+
+ // Returns a vector of regions that are the `sub-cases` of the given case region.
+ // E.g., `case %it1, _, %it3` is a subcase of `case %it1, %it2, %it3`.
+ SmallVector<Region *> getSubCasesOf(unsigned regionIdx);
}];
let hasVerifier = 1;
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index a143189c301a43..16856b958d4f13 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -2745,6 +2745,16 @@ LogicalResult CoIterateOp::verifyRegions() {
return success();
}
+SmallVector<Region *> CoIterateOp::getSubCasesOf(unsigned regionIdx) {
+ SmallVector<Region *> ret;
+ I64BitSet caseBit = getRegionDefinedSpace(regionIdx);
+ for (Region &r : getCaseRegions())
+ if (getRegionDefinedSpace(r.getRegionNumber()).isSubSetOf(caseBit))
+ ret.push_back(&r);
+
+ return ret;
+}
+
//===----------------------------------------------------------------------===//
// Sparse Tensor Dialect Setups.
//===----------------------------------------------------------------------===//
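
As a hedged illustration of what getSubCasesOf computes (toy code, not the MLIR API): each case region reduces to a bitmask of the iterators it binds, and a case is a subcase of another when its mask is a subset. E.g., with iterators (%it1, %it2, %it3), `case %it1, _, %it3` has mask 0b101 and is a subcase of `case %it1, %it2, %it3` (mask 0b111).

#include <cstddef>
#include <cstdint>
#include <vector>

// Toy stand-in for CoIterateOp::getSubCasesOf over case bitmasks.
std::vector<size_t> getSubCasesOf(const std::vector<uint64_t> &cases, size_t idx) {
  std::vector<size_t> ret;
  uint64_t caseBits = cases[idx];
  for (size_t i = 0; i < cases.size(); ++i)
    if ((cases[i] | caseBits) == caseBits) // cases[i] is a subset of caseBits
      ret.push_back(i);
  return ret;
}
// With cases = {0b111, 0b101, 0b010}, getSubCasesOf(cases, 0) returns
// {0, 1, 2}: every mask here is a subset of 0b111, including 0b111 itself.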
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
index b1451dee738ac3..d6c0da4a9e4573 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
@@ -1,5 +1,6 @@
#include "Utils/CodegenUtils.h"
+#include "Utils/LoopEmitter.h"
#include "Utils/SparseTensorIterator.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
@@ -49,6 +50,144 @@ convertIteratorType(IteratorType itTp, SmallVectorImpl<Type> &fields) {
return success();
}
+static ValueRange
+genCoIterateBranchNest(PatternRewriter &rewriter, Location loc, CoIterateOp op,
+ Value loopCrd,
+ ArrayRef<std::unique_ptr<SparseIterator>> iters,
+ ArrayRef<Region *> subCases, ArrayRef<Value> userReduc) {
+ if (subCases.empty())
+ return userReduc;
+
+ // The current branch that we are handling.
+ Region *b = subCases.front();
+ Value casePred = constantI1(rewriter, loc, true);
+ I64BitSet caseBits = op.getRegionDefinedSpace(b->getRegionNumber());
+ for (unsigned i : caseBits.bits()) {
+ SparseIterator *it = iters[i].get();
+ Value pred = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
+ it->getCrd(), loopCrd);
+ casePred = rewriter.create<arith::AndIOp>(loc, casePred, pred);
+ }
+ scf::IfOp ifOp = rewriter.create<scf::IfOp>(
+ loc, ValueRange(userReduc).getTypes(), casePred, /*else=*/true);
+ rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
+
+ // Erase the empty block.
+ rewriter.eraseBlock(&ifOp.getThenRegion().front());
+ // Set up block arguments: user-provided values -> loop coord -> iterators.
+ SmallVector<Value> blockArgs(userReduc);
+ blockArgs.push_back(loopCrd);
+ for (unsigned idx : caseBits.bits())
+ llvm::append_range(blockArgs, iters[idx]->getCursor());
+
+ IRMapping mapping;
+ for (auto [from, to] :
+ llvm::zip_equal(b->front().getArguments(), blockArgs)) {
+ mapping.map(from, to);
+ }
+
+  // Clone the region; we cannot erase the region yet because the same region
+  // might be a subcase for multiple lattice points.
+ rewriter.cloneRegionBefore(*b, ifOp.getThenRegion(),
+ ifOp.getThenRegion().begin(), mapping);
+
+  // Replace the terminating sparse_tensor::YieldOp with an scf::YieldOp.
+ auto spY = cast<sparse_tensor::YieldOp>(&ifOp.getThenRegion().front().back());
+ ValueRange yields = spY.getResults();
+ rewriter.eraseOp(spY);
+ rewriter.setInsertionPointToEnd(&ifOp.getThenRegion().front());
+ rewriter.create<scf::YieldOp>(loc, yields);
+
+  // Generates the remaining cases recursively.
+ rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
+ ValueRange res = genCoIterateBranchNest(rewriter, loc, op, loopCrd, iters,
+ subCases.drop_front(), userReduc);
+ if (!res.empty())
+ rewriter.create<scf::YieldOp>(loc, res);
+
+ rewriter.setInsertionPointAfter(ifOp);
+ return ifOp.getResults();
+}
+
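
A rough picture of the control flow the function above emits, as a hedged host-side analogue (the real code builds nested scf.if ops by recursing over the subcase list; the constants standing in for case bodies below are placeholders): each subcase is guarded by the conjunction of crd == loopCrd tests for its iterators, the next subcase lives in the else branch, and an unmatched coordinate passes the reduction values through unchanged.

#include <cstdint>

// Toy analogue of the emitted branch nest for subcases {(it0, it1), (it0)}.
int64_t dispatch(int64_t loopCrd, int64_t crd0, int64_t crd1, int64_t reduc) {
  if (crd0 == loopCrd && crd1 == loopCrd) { // case %it0, %it1
    return reduc + 2;                       // placeholder for the cloned case body
  } else {
    if (crd0 == loopCrd) {                  // subcase %it0, _
      return reduc + 1;                     // placeholder for the cloned case body
    } else {
      return reduc;                         // base case: yield reductions unchanged
    }
  }
}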
+static ValueRange genLoopWithIterator(
+ PatternRewriter &rewriter, Location loc, SparseIterator *it,
+ ValueRange reduc, bool iterFirst,
+ function_ref<SmallVector<Value>(PatternRewriter &rewriter, Location loc,
+ Region &loopBody, SparseIterator *it,
+ ValueRange reduc)>
+ bodyBuilder) {
+ if (it->iteratableByFor()) {
+ auto [lo, hi] = it->genForCond(rewriter, loc);
+ Value step = constantIndex(rewriter, loc, 1);
+ scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, lo, hi, step, reduc);
+ {
+ OpBuilder::InsertionGuard guard(rewriter);
+      // Erase the implicit yield operation created by ForOp when there are no
+      // yielded values.
+ if (!forOp.getBody()->empty())
+ rewriter.eraseOp(&forOp.getBody()->front());
+ assert(forOp.getBody()->empty());
+
+ it->linkNewScope(forOp.getInductionVar());
+ rewriter.setInsertionPointToStart(forOp.getBody());
+ SmallVector<Value> ret = bodyBuilder(rewriter, loc, forOp.getBodyRegion(),
+ it, forOp.getRegionIterArgs());
+
+ rewriter.setInsertionPointToEnd(forOp.getBody());
+ rewriter.create<scf::YieldOp>(loc, ret);
+ }
+ return forOp.getResults();
+ }
+ SmallVector<Value> ivs;
+ // TODO: always put iterator SSA values at the end of argument list to be
+ // consistent with coiterate operation.
+ if (!iterFirst)
+ llvm::append_range(ivs, it->getCursor());
+ // Appends the user-provided values.
+ llvm::append_range(ivs, reduc);
+ if (iterFirst)
+ llvm::append_range(ivs, it->getCursor());
+
+ TypeRange types = ValueRange(ivs).getTypes();
+ auto whileOp = rewriter.create<scf::WhileOp>(loc, types, ivs);
+ {
+ OpBuilder::InsertionGuard guard(rewriter);
+ // Generates loop conditions.
+ SmallVector<Location> l(types.size(), loc);
+ Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, types, l);
+ rewriter.setInsertionPointToStart(before);
+ ValueRange bArgs = before->getArguments();
+ auto [whileCond, remArgs] = it->genWhileCond(rewriter, loc, bArgs);
+ rewriter.create<scf::ConditionOp>(loc, whileCond, before->getArguments());
+
+ // Delegates loop body generation.
+ Region &dstRegion = whileOp.getAfter();
+ Block *after = rewriter.createBlock(&dstRegion, {}, types, l);
+ ValueRange aArgs = whileOp.getAfterArguments();
+ if (iterFirst) {
+ aArgs = it->linkNewScope(aArgs);
+ } else {
+ aArgs = aArgs.take_front(reduc.size());
+ it->linkNewScope(aArgs.drop_front(reduc.size()));
+ }
+
+ rewriter.setInsertionPointToStart(after);
+ SmallVector<Value> ret = bodyBuilder(rewriter, loc, dstRegion, it, aArgs);
+ rewriter.setInsertionPointToEnd(after);
+
+    // Forward the iterator and assemble the values to yield.
+ SmallVector<Value> yields;
+ ValueRange nx = it->forward(rewriter, loc);
+ if (iterFirst)
+ llvm::append_range(yields, nx);
+ llvm::append_range(yields, ret);
+ if (!iterFirst)
+ llvm::append_range(yields, nx);
+ rewriter.create<scf::YieldOp>(loc, yields);
+ }
+ return whileOp.getResults().drop_front(it->getCursor().size());
+}
+
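
How the function above picks between the two loop forms, as a hedged sketch with hypothetical host-side signatures: an iterator over a contiguous range lowers to a unit-step scf.for over [lo, hi); otherwise an scf.while carries the cursor and the reduction values as iteration arguments.

#include <cstdint>
#include <functional>

// Hypothetical host-side analogue of the two loop shapes.
int64_t loopWithIterator(bool iteratableByFor, int64_t lo, int64_t hi,
                         const std::function<bool(int64_t)> &hasNext,
                         const std::function<int64_t(int64_t)> &forward,
                         const std::function<int64_t(int64_t, int64_t)> &body,
                         int64_t reduc) {
  if (iteratableByFor) {
    for (int64_t i = lo; i < hi; ++i)  // scf.for: unit step over [lo, hi)
      reduc = body(i, reduc);
    return reduc;
  }
  int64_t cursor = lo;                 // scf.while: cursor is an iteration argument
  while (hasNext(cursor)) {            // before-region: genWhileCond
    reduc = body(cursor, reduc);       // after-region: delegated to bodyBuilder
    cursor = forward(cursor);          // it->forward(...) yields the next cursor
  }
  return reduc;
}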
namespace {
/// Sparse codegen rule for number of entries operator.
@@ -136,6 +275,8 @@ class SparseIterateOpConverter : public OneToNOpConversionPattern<IterateOp> {
rewriter.replaceOp(op, forOp.getResults(), resultMapping);
} else {
SmallVector<Value> ivs;
+ // TODO: put iterator at the end of argument list to be consistent with
+ // coiterate operation.
llvm::append_range(ivs, it->getCursor());
for (ValueRange inits : adaptor.getInitArgs())
llvm::append_range(ivs, inits);
@@ -189,6 +330,153 @@ class SparseIterateOpConverter : public OneToNOpConversionPattern<IterateOp> {
}
};
+class SparseCoIterateOpConverter
+ : public OneToNOpConversionPattern<CoIterateOp> {
+ using OneToNOpConversionPattern::OneToNOpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(CoIterateOp op, OpAdaptor adaptor,
+ OneToNPatternRewriter &rewriter) const override {
+ assert(op.getSpaceDim() == 1 && "Not implemented");
+ Location loc = op.getLoc();
+
+ I64BitSet denseBits(0);
+ for (auto [idx, spaceTp] : llvm::enumerate(op.getIterSpaces().getTypes()))
+ if (all_of(cast<IterSpaceType>(spaceTp).getLvlTypes(), isDenseLT))
+ denseBits.set(idx);
+
+    // If there exists a case that contains only dense spaces (i.e., its case
+    // bits are a subset of the dense bits), or a fully empty case (due to
+    // complements), we need a universal index to forward the coiteration
+    // loop.
+ bool needUniv =
+ any_of(op.getRegionDefinedSpaces(), [denseBits](I64BitSet caseBits) {
+ // A case for complement.
+ if (caseBits.count() == 0)
+ return true;
+ // An all-dense case.
+ return caseBits.isSubSetOf(denseBits);
+ });
+ assert(!needUniv && "Not implemented");
+ (void)needUniv;
+
+    for (Region &region : op.getCaseRegions()) {
+      // Do a one-shot type conversion on all region blocks, since the same
+      // region might be used multiple times.
+      Block *block = &region.getBlocks().front();
+ OneToNTypeMapping blockTypeMapping(block->getArgumentTypes());
+ if (failed(typeConverter->convertSignatureArgs(block->getArgumentTypes(),
+ blockTypeMapping)))
+ return rewriter.notifyMatchFailure(
+            op, "failed to convert coiterate region argument types");
+
+ rewriter.applySignatureConversion(block, blockTypeMapping);
+ }
+
+ SmallVector<SparseIterationSpace> spaces;
+ SmallVector<std::unique_ptr<SparseIterator>> iters;
+ for (auto [spaceTp, spaceVals] : llvm::zip_equal(
+ op.getIterSpaces().getTypes(), adaptor.getIterSpaces())) {
+ // TODO: do we really need tid?
+ spaces.push_back(SparseIterationSpace::fromValues(
+ cast<IterSpaceType>(spaceTp), spaceVals, /*tid=*/0));
+ // Extract the iterator.
+ iters.push_back(spaces.back().extractIterator(rewriter, loc));
+ }
+
+ auto getFilteredIters = [&iters](I64BitSet caseBits) {
+      // Retrieves a vector of pointers to the iterators used in the case.
+ SmallVector<SparseIterator *> validIters;
+ for (auto idx : caseBits.bits())
+ validIters.push_back(iters[idx].get());
+ return validIters;
+ };
+
+    // Get the flattened user-provided loop reduction values.
+ SmallVector<Value> userReduc;
+ for (ValueRange r : adaptor.getInitArgs())
+ llvm::append_range(userReduc, r);
+
+    // TODO: we need to sort the cases such that they appear in lexical order.
+ // Although sparsification always generates cases in that order, it might
+ // not be the case for human-written code.
+
+ // Generates a loop sequence, one loop per case.
+ for (auto [r, caseBits] :
+ llvm::zip_equal(op.getCaseRegions(), op.getRegionDefinedSpaces())) {
+ assert(caseBits.count() > 0 && "Complement space not implemented");
+
+      // Retrieves a vector of pointers to the iterators used in the case.
+ SmallVector<SparseIterator *> validIters = getFilteredIters(caseBits);
+
+ if (validIters.size() > 1) {
+ auto [loop, loopCrd] =
+ genCoIteration(rewriter, loc, validIters, userReduc,
+ /*uniIdx=*/nullptr, /*userReducFirst=*/true);
+
+        // 1st, find all the cases that are subsets of the current case
+        // condition, for which we generate one branch per case inside the loop.
+        // The set of subcases is never empty; it contains at least the current
+        // region itself.
+ // TODO: these cases should be sorted.
+ SmallVector<Region *> subCases = op.getSubCasesOf(r.getRegionNumber());
+ assert(!subCases.empty());
+
+ ValueRange res = genCoIterateBranchNest(rewriter, loc, op, loopCrd,
+ iters, subCases, userReduc);
+
+ SmallVector<Value> nextIterYields(res);
+        // 2nd, forward the loop.
+ for (SparseIterator *it : validIters) {
+ Value cmp = rewriter.create<arith::CmpIOp>(
+ loc, arith::CmpIPredicate::eq, it->getCrd(), loopCrd);
+ it->forwardIf(rewriter, loc, cmp);
+ llvm::append_range(nextIterYields, it->getCursor());
+ }
+ rewriter.create<scf::YieldOp>(loc, nextIterYields);
+
+        // Exit the loop, relink the iterator SSA values.
+ rewriter.setInsertionPointAfter(loop);
+ ValueRange iterVals = loop->getResults().drop_front(userReduc.size());
+ for (SparseIterator *it : validIters)
+ iterVals = it->linkNewScope(iterVals);
+ assert(iterVals.empty());
+
+ ValueRange curResult = loop->getResults().take_front(userReduc.size());
+ userReduc.assign(curResult.begin(), curResult.end());
+ } else {
+ // This is a simple iteration loop.
+ assert(caseBits.count() == 1);
+
+ Block *block = &r.getBlocks().front();
+ ValueRange curResult = genLoopWithIterator(
+ rewriter, loc, validIters.front(), userReduc, /*iterFirst=*/false,
+ /*bodyBuilder=*/
+ [block](PatternRewriter &rewriter, Location loc, Region &dstRegion,
+ SparseIterator *it,
+ ValueRange reduc) -> SmallVector<Value> {
+ SmallVector<Value> blockArgs(reduc);
+ blockArgs.push_back(it->deref(rewriter, loc));
+ llvm::append_range(blockArgs, it->getCursor());
+
+ Block *dstBlock = &dstRegion.getBlocks().front();
+ rewriter.inlineBlockBefore(
+ block, dstBlock, rewriter.getInsertionPoint(), blockArgs);
+ auto yield = llvm::cast<sparse_tensor::YieldOp>(dstBlock->back());
+ SmallVector<Value> result(yield.getResults());
+ rewriter.eraseOp(yield);
+ return result;
+ });
+
+ userReduc.assign(curResult.begin(), curResult.end());
+ }
+ }
+
+ rewriter.replaceOp(op, userReduc);
+ return success();
+ }
+};
+
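
Putting the pieces together, the multi-iterator path is the classic sorted-merge pattern. Below is a hedged C++ model of the loop the converter builds (genCoIteration emits the scf.while over all valid iterators, genCoIterateBranchNest the inner dispatch, and the placeholder increments stand in for user case bodies); the single-iterator cases become the simple cleanup loops that follow in the generated loop sequence.

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

// Toy model: co-iterate two sorted coordinate lists, mirroring the
// generated scf.while for the (it0 && it1) case and its subcases.
int64_t coiterate(const std::vector<int64_t> &a, const std::vector<int64_t> &b) {
  size_t i = 0, j = 0;
  int64_t reduc = 0;
  while (i < a.size() && j < b.size()) {    // loop while all iterators are valid
    int64_t loopCrd = std::min(a[i], b[j]); // coordinate visited this iteration
    if (a[i] == loopCrd && b[j] == loopCrd)
      reduc += 2;                           // case %it0, %it1 (placeholder body)
    else if (a[i] == loopCrd)
      reduc += 1;                           // subcase %it0, _ (placeholder body)
    // forwardIf: advance exactly the iterators sitting at loopCrd.
    if (a[i] == loopCrd) ++i;
    if (b[j] == loopCrd) ++j;
  }
  // The remaining tail of `a` or `b` is handled by the subsequent
  // single-iterator loops in the generated loop sequence.
  return reduc;
}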
} // namespace
mlir::SparseIterationTypeConverter::SparseIterationTypeConverter() {
@@ -210,5 +498,6 @@ void mlir::populateLowerSparseIterationToSCFPatterns(
IterateOp::getCanonicalizationPatterns(patterns, patterns.getContext());
patterns.add<ExtractIterSpaceConverter, ExtractValOpConverter,
- SparseIterateOpConverter>(converter, patterns.getContext());
+ SparseIterateOpConverter, SparseCoIterateOpConverter>(
+ converter, patterns.getContext());
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
index efb3295fb2a4bf..cb5874ff45068e 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/LoopEmitter.cpp
@@ -524,84 +524,8 @@ std::pair<Operation *, Value> LoopEmitter::emitForLoopOverTensorAtLvl(
std::pair<Operation *, Value> LoopEmitter::emitWhileLoopOverTensorsAtLvls(
OpBuilder &builder, Location loc, ArrayRef<SparseIterator *> spIters,
MutableArrayRef<Value> reduc, bool needsUniv) {
- // NOTE: the slice driven tensor-related reduction variable must
- // appear before normal tensors.
-
- // The set of induction variables for the while loop.
- SmallVector<Value> ivs;
-
- // Construct the while-loop with a parameter for each coordinate.
- for (SparseIterator *it : spIters) {
- ValueRange itVals = it->getCursor();
- ivs.append(itVals.begin(), itVals.end());
- }
-
- // The position where user-supplied reduction variable starts.
- ivs.append(reduc.begin(), reduc.end());
- // Update universal index.
- if (needsUniv)
- ivs.push_back(loopSeqStack.back().first);
-
- // Ensures all operands are valid.
- assert(llvm::all_of(ivs, [](Value v) { return v != nullptr; }));
- TypeRange types = ValueRange(ivs).getTypes();
- auto whileOp = builder.create<scf::WhileOp>(loc, types, ivs);
-
- SmallVector<Location> locs(types.size(), loc);
- Block *before = builder.createBlock(&whileOp.getBefore(), {}, types, locs);
- Block *after = builder.createBlock(&whileOp.getAfter(), {}, types, locs);
-
- // Generates loop conditions.
- builder.setInsertionPointToStart(before);
- ValueRange bArgs = before->getArguments();
- Value whileCond = nullptr; // bool values for loop condition.
-
- for (SparseIterator *it : spIters) {
- auto [cond, remArgs] = it->genWhileCond(builder, loc, bArgs);
- whileCond = !whileCond ? cond : ANDI(whileCond, cond);
- bArgs = remArgs;
- }
- // The remaining block arguments are user-provided reduction values and an
- // optional universal index. Make sure their sizes match.
- assert(bArgs.size() == reduc.size() + needsUniv);
- builder.create<scf::ConditionOp>(loc, whileCond, before->getArguments());
-
- // Generates loop body.
- builder.setInsertionPointToStart(after);
- ValueRange aArgs = after->getArguments();
- // Since some LoopCondKind might need extra checks to filter out invalid
- // iterations, we maintains another array to hold the iteration arguments to
- // yield if the checks fails.
- SmallVector<Value> nextArgs(aArgs.begin(), aArgs.end());
-
- for (SparseIterator *it : spIters) {
- aArgs = it->linkNewScope(aArgs);
- // Dereference the iterator to cache the coordinate.
- it->deref(builder, loc);
- }
-
- // In-place update on reduction variable.
- assert(aArgs.size() == reduc.size() + needsUniv);
- for (unsigned i = 0, e = reduc.size(); i < e; i++)
- reduc[i] = aArgs[i];
-
- Value min;
- // Finds the minimum coordinate
- if (!needsUniv) {
- for (SparseIterator *it : spIters) {
- if (min) {
- Value cmp = CMPI(ult, it->getCrd(), min);
- min = SELECT(cmp, it->getCrd(), min);
- } else {
- min = it->getCrd();
- }
- }
- } else {
- // Otherwise, universal index is the minimal pos.
- min = whileOp.getAfterAr...
[truncated]
LLVM Buildbot has detected a new failure on a builder. Full details are available at: https://lab.llvm.org/buildbot/#/builders/116/builds/2611

Here is the relevant piece of the build log for reference:
Stacked PRs:
[mlir][sparse] partially support lowering sparse coiteration loops to scf.while/for.