Skip to content

[mlir][sparse_tensor] Migrate SparseIterationToScf.cpp to dialect conversion #121054

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 73 additions & 50 deletions mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,17 @@
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
#include "mlir/Transforms/OneToNTypeConversion.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;
using namespace mlir::sparse_tensor;

/// Assert that the given value range contains a single value and return it.
static Value getSingleValue(ValueRange values) {
assert(values.size() == 1 && "expected single value");
return values.front();
}

static void convertLevelType(SparseTensorEncodingAttr enc, Level lvl,
SmallVectorImpl<Type> &fields) {
// Position and coordinate buffer in the sparse structure.
Expand Down Expand Up @@ -54,14 +60,17 @@ static ValueRange
genCoIterateBranchNest(PatternRewriter &rewriter, Location loc, CoIterateOp op,
Value loopCrd,
ArrayRef<std::unique_ptr<SparseIterator>> iters,
ArrayRef<Region *> subCases, ArrayRef<Value> userReduc) {
if (subCases.empty())
ArrayRef<Block *> newBlocks, ArrayRef<Block *> oldBlocks,
ArrayRef<Value> userReduc) {
if (newBlocks.empty())
return userReduc;

// The current branch that we are handling.
Region *b = subCases.front();
Block *newBlock = newBlocks.front();
Block *oldBlock = oldBlocks.front();
Value casePred = constantI1(rewriter, loc, true);
I64BitSet caseBits = op.getRegionDefinedSpace(b->getRegionNumber());
I64BitSet caseBits =
op.getRegionDefinedSpace(newBlock->getParent()->getRegionNumber());
for (unsigned i : caseBits.bits()) {
SparseIterator *it = iters[i].get();
Value pred = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
Expand All @@ -80,16 +89,20 @@ genCoIterateBranchNest(PatternRewriter &rewriter, Location loc, CoIterateOp op,
for (unsigned idx : caseBits.bits())
llvm::append_range(blockArgs, iters[idx]->getCursor());

// Map the old block arguments, because the dialect conversion driver does
// not immediately perform SSA value replacements. This function is still
// seeing the old uses.
IRMapping mapping;
for (auto [from, to] :
llvm::zip_equal(b->front().getArguments(), blockArgs)) {
for (auto [from, to] : llvm::zip_equal(oldBlock->getArguments(), blockArgs)) {
mapping.map(from, to);
}

// Clone the region, we can not erase the region now because the same region
// might be a subcase for multiple lattice point.
rewriter.cloneRegionBefore(*b, ifOp.getThenRegion(),
rewriter.cloneRegionBefore(*newBlock->getParent(), ifOp.getThenRegion(),
ifOp.getThenRegion().begin(), mapping);
// Remove the block arguments, they were already replaced via `mapping`.
ifOp.getThenRegion().front().eraseArguments(0, blockArgs.size());

// replace sparse_tensor::YieldOp -> scf::YieldOp
auto spY = cast<sparse_tensor::YieldOp>(&ifOp.getThenRegion().front().back());
Expand All @@ -101,7 +114,8 @@ genCoIterateBranchNest(PatternRewriter &rewriter, Location loc, CoIterateOp op,
// Generates remaining case recursively.
rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
ValueRange res = genCoIterateBranchNest(rewriter, loc, op, loopCrd, iters,
subCases.drop_front(), userReduc);
newBlocks.drop_front(),
oldBlocks.drop_front(), userReduc);
if (!res.empty())
rewriter.create<scf::YieldOp>(loc, res);

Expand All @@ -119,15 +133,13 @@ static ValueRange genLoopWithIterator(
if (it->iteratableByFor()) {
auto [lo, hi] = it->genForCond(rewriter, loc);
Value step = constantIndex(rewriter, loc, 1);
scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, lo, hi, step, reduc);
scf::ForOp forOp = rewriter.create<scf::ForOp>(
loc, lo, hi, step, reduc,
[&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {
// Empty builder function to ensure that no terminator is created.
});
{
OpBuilder::InsertionGuard guard(rewriter);
// Erase the implicit yield operation created by ForOp when there is no
// yielding values.
if (!forOp.getBody()->empty())
rewriter.eraseOp(&forOp.getBody()->front());
assert(forOp.getBody()->empty());

it->linkNewScope(forOp.getInductionVar());
rewriter.setInsertionPointToStart(forOp.getBody());
SmallVector<Value> ret = bodyBuilder(rewriter, loc, forOp.getBodyRegion(),
Expand Down Expand Up @@ -178,46 +190,47 @@ namespace {

/// Sparse codegen rule for number of entries operator.
class ExtractIterSpaceConverter
: public OneToNOpConversionPattern<ExtractIterSpaceOp> {
: public OpConversionPattern<ExtractIterSpaceOp> {
public:
using OneToNOpConversionPattern::OneToNOpConversionPattern;
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(ExtractIterSpaceOp op, OpAdaptor adaptor,
OneToNPatternRewriter &rewriter) const override {
matchAndRewrite(ExtractIterSpaceOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();

// Construct the iteration space.
SparseIterationSpace space(loc, rewriter, op.getTensor(), 0,
SparseIterationSpace space(loc, rewriter,
getSingleValue(adaptor.getTensor()), 0,
op.getLvlRange(), adaptor.getParentIter());

SmallVector<Value> result = space.toValues();
rewriter.replaceOp(op, result, resultMapping);
rewriter.replaceOpWithMultiple(op, {result});
return success();
}
};

/// Sparse codegen rule for number of entries operator.
class ExtractValOpConverter : public OneToNOpConversionPattern<ExtractValOp> {
class ExtractValOpConverter : public OpConversionPattern<ExtractValOp> {
public:
using OneToNOpConversionPattern::OneToNOpConversionPattern;
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(ExtractValOp op, OpAdaptor adaptor,
OneToNPatternRewriter &rewriter) const override {
matchAndRewrite(ExtractValOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value pos = adaptor.getIterator().back();
Value valBuf = rewriter.create<ToValuesOp>(loc, op.getTensor());
Value valBuf =
rewriter.create<ToValuesOp>(loc, getSingleValue(adaptor.getTensor()));
rewriter.replaceOpWithNewOp<memref::LoadOp>(op, valBuf, pos);
return success();
}
};

class SparseIterateOpConverter : public OneToNOpConversionPattern<IterateOp> {
class SparseIterateOpConverter : public OpConversionPattern<IterateOp> {
public:
using OneToNOpConversionPattern::OneToNOpConversionPattern;
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(IterateOp op, OpAdaptor adaptor,
OneToNPatternRewriter &rewriter) const override {
matchAndRewrite(IterateOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (!op.getCrdUsedLvls().empty())
return rewriter.notifyMatchFailure(
op, "non-empty coordinates list not implemented.");
Expand All @@ -235,14 +248,15 @@ class SparseIterateOpConverter : public OneToNOpConversionPattern<IterateOp> {
llvm::append_range(ivs, inits);

// Type conversion on iterate op block.
OneToNTypeMapping blockTypeMapping(op.getBody()->getArgumentTypes());
unsigned numOrigArgs = op.getBody()->getArgumentTypes().size();
TypeConverter::SignatureConversion signatureConversion(numOrigArgs);
if (failed(typeConverter->convertSignatureArgs(
op.getBody()->getArgumentTypes(), blockTypeMapping)))
op.getBody()->getArgumentTypes(), signatureConversion)))
return rewriter.notifyMatchFailure(
op, "failed to convert iterate region argurment types");
rewriter.applySignatureConversion(op.getBody(), blockTypeMapping);

Block *block = op.getBody();
Block *block = rewriter.applySignatureConversion(
op.getBody(), signatureConversion, getTypeConverter());
ValueRange ret = genLoopWithIterator(
rewriter, loc, it.get(), ivs,
[block](PatternRewriter &rewriter, Location loc, Region &loopBody,
Expand All @@ -263,19 +277,17 @@ class SparseIterateOpConverter : public OneToNOpConversionPattern<IterateOp> {
return result;
});

const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
rewriter.replaceOp(op, ret, resultMapping);
rewriter.replaceOp(op, ret);
return success();
}
};

class SparseCoIterateOpConverter
: public OneToNOpConversionPattern<CoIterateOp> {
using OneToNOpConversionPattern::OneToNOpConversionPattern;
class SparseCoIterateOpConverter : public OpConversionPattern<CoIterateOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(CoIterateOp op, OpAdaptor adaptor,
OneToNPatternRewriter &rewriter) const override {
matchAndRewrite(CoIterateOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
assert(op.getSpaceDim() == 1 && "Not implemented");
Location loc = op.getLoc();

Expand All @@ -299,18 +311,23 @@ class SparseCoIterateOpConverter
assert(!needUniv && "Not implemented");
(void)needUniv;

SmallVector<Block *> newBlocks;
DenseMap<Block *, Block *> newToOldBlockMap;
for (Region &region : op.getCaseRegions()) {
// Do a one-shot type conversion on all region blocks, since the same
// region might be used multiple time.
Block *block = &region.getBlocks().front();
OneToNTypeMapping blockTypeMapping(block->getArgumentTypes());
TypeConverter::SignatureConversion blockTypeMapping(
block->getArgumentTypes().size());
if (failed(typeConverter->convertSignatureArgs(block->getArgumentTypes(),
blockTypeMapping))) {
return rewriter.notifyMatchFailure(
op, "failed to convert coiterate region argurment types");
}

rewriter.applySignatureConversion(block, blockTypeMapping);
newBlocks.push_back(rewriter.applySignatureConversion(
block, blockTypeMapping, getTypeConverter()));
newToOldBlockMap[newBlocks.back()] = block;
}

SmallVector<SparseIterationSpace> spaces;
Expand Down Expand Up @@ -343,7 +360,7 @@ class SparseCoIterateOpConverter

// Generates a loop sequence, one loop per case.
for (auto [r, caseBits] :
llvm::zip_equal(op.getCaseRegions(), op.getRegionDefinedSpaces())) {
llvm::zip_equal(newBlocks, op.getRegionDefinedSpaces())) {
assert(caseBits.count() > 0 && "Complement space not implemented");

// Retrives a vector of pointers to the iterators used in the case.
Expand All @@ -359,11 +376,17 @@ class SparseCoIterateOpConverter
// The subcases are never empty, it must contains at least the current
// region itself.
// TODO: these cases should be sorted.
SmallVector<Region *> subCases = op.getSubCasesOf(r.getRegionNumber());
SmallVector<Region *> subCases =
op.getSubCasesOf(r->getParent()->getRegionNumber());
SmallVector<Block *> newBlocks, oldBlocks;
for (Region *r : subCases) {
newBlocks.push_back(&r->front());
oldBlocks.push_back(newToOldBlockMap[newBlocks.back()]);
}
assert(!subCases.empty());

ValueRange res = genCoIterateBranchNest(rewriter, loc, op, loopCrd,
iters, subCases, userReduc);
ValueRange res = genCoIterateBranchNest(
rewriter, loc, op, loopCrd, iters, newBlocks, oldBlocks, userReduc);

SmallVector<Value> nextIterYields(res);
// 2nd. foward the loop.
Expand All @@ -388,7 +411,7 @@ class SparseCoIterateOpConverter
// This is a simple iteration loop.
assert(caseBits.count() == 1);

Block *block = &r.getBlocks().front();
Block *block = r;
ValueRange curResult = genLoopWithIterator(
rewriter, loc, validIters.front(), userReduc,
/*bodyBuilder=*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -172,11 +172,16 @@ struct LowerSparseIterationToSCFPass
ConversionTarget target(*ctx);

// The actual conversion.
target.addIllegalOp<ExtractIterSpaceOp, IterateOp>();
target.addLegalDialect<arith::ArithDialect, linalg::LinalgDialect,
memref::MemRefDialect, scf::SCFDialect,
sparse_tensor::SparseTensorDialect>();
target.addIllegalOp<CoIterateOp, ExtractIterSpaceOp, ExtractValOp,
IterateOp>();
target.addLegalOp<UnrealizedConversionCastOp>();
populateLowerSparseIterationToSCFPatterns(converter, patterns);

if (failed(applyPartialOneToNConversion(getOperation(), converter,
std::move(patterns))))
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
signalPassFailure();
}
};
Expand Down
Loading