Skip to content

[mlir][sparse] refactoring sparse_tensor.iterate lowering pattern implementation. #105566

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 23, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 36 additions & 82 deletions mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -244,88 +244,41 @@ class SparseIterateOpConverter : public OneToNOpConversionPattern<IterateOp> {
std::unique_ptr<SparseIterator> it =
iterSpace.extractIterator(rewriter, loc);

if (it->iteratableByFor()) {
auto [lo, hi] = it->genForCond(rewriter, loc);
Value step = constantIndex(rewriter, loc, 1);
SmallVector<Value> ivs;
for (ValueRange inits : adaptor.getInitArgs())
llvm::append_range(ivs, inits);
scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, lo, hi, step, ivs);

Block *loopBody = op.getBody();
OneToNTypeMapping bodyTypeMapping(loopBody->getArgumentTypes());
if (failed(typeConverter->convertSignatureArgs(
loopBody->getArgumentTypes(), bodyTypeMapping)))
return failure();
rewriter.applySignatureConversion(loopBody, bodyTypeMapping);

rewriter.eraseBlock(forOp.getBody());
Region &dstRegion = forOp.getRegion();
rewriter.inlineRegionBefore(op.getRegion(), dstRegion, dstRegion.end());

auto yieldOp =
llvm::cast<sparse_tensor::YieldOp>(forOp.getBody()->getTerminator());

rewriter.setInsertionPointToEnd(forOp.getBody());
// replace sparse_tensor.yield with scf.yield.
rewriter.create<scf::YieldOp>(loc, yieldOp.getResults());
rewriter.eraseOp(yieldOp);

const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
rewriter.replaceOp(op, forOp.getResults(), resultMapping);
} else {
SmallVector<Value> ivs;
// TODO: put iterator at the end of argument list to be consistent with
// coiterate operation.
llvm::append_range(ivs, it->getCursor());
for (ValueRange inits : adaptor.getInitArgs())
llvm::append_range(ivs, inits);

assert(llvm::all_of(ivs, [](Value v) { return v != nullptr; }));

TypeRange types = ValueRange(ivs).getTypes();
auto whileOp = rewriter.create<scf::WhileOp>(loc, types, ivs);
SmallVector<Location> l(types.size(), op.getIterator().getLoc());

// Generates loop conditions.
Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, types, l);
rewriter.setInsertionPointToStart(before);
ValueRange bArgs = before->getArguments();
auto [whileCond, remArgs] = it->genWhileCond(rewriter, loc, bArgs);
assert(remArgs.size() == adaptor.getInitArgs().size());
rewriter.create<scf::ConditionOp>(loc, whileCond, before->getArguments());

// Generates loop body.
Block *loopBody = op.getBody();
OneToNTypeMapping bodyTypeMapping(loopBody->getArgumentTypes());
if (failed(typeConverter->convertSignatureArgs(
loopBody->getArgumentTypes(), bodyTypeMapping)))
return failure();
rewriter.applySignatureConversion(loopBody, bodyTypeMapping);

Region &dstRegion = whileOp.getAfter();
// TODO: handle uses of coordinate!
rewriter.inlineRegionBefore(op.getRegion(), dstRegion, dstRegion.end());
ValueRange aArgs = whileOp.getAfterArguments();
auto yieldOp = llvm::cast<sparse_tensor::YieldOp>(
whileOp.getAfterBody()->getTerminator());

rewriter.setInsertionPointToEnd(whileOp.getAfterBody());
SmallVector<Value> ivs;
for (ValueRange inits : adaptor.getInitArgs())
llvm::append_range(ivs, inits);

// Type conversion on iterate op block.
OneToNTypeMapping blockTypeMapping(op.getBody()->getArgumentTypes());
if (failed(typeConverter->convertSignatureArgs(
op.getBody()->getArgumentTypes(), blockTypeMapping)))
return rewriter.notifyMatchFailure(
op, "failed to convert iterate region argurment types");
rewriter.applySignatureConversion(op.getBody(), blockTypeMapping);

Block *block = op.getBody();
ValueRange ret = genLoopWithIterator(
rewriter, loc, it.get(), ivs, /*iterFirst=*/true,
[block](PatternRewriter &rewriter, Location loc, Region &loopBody,
SparseIterator *it, ValueRange reduc) -> SmallVector<Value> {
SmallVector<Value> blockArgs(it->getCursor());
// TODO: Also appends coordinates if used.
// blockArgs.push_back(it->deref(rewriter, loc));
llvm::append_range(blockArgs, reduc);

Block *dstBlock = &loopBody.getBlocks().front();
rewriter.inlineBlockBefore(block, dstBlock, dstBlock->end(),
blockArgs);
auto yield = llvm::cast<sparse_tensor::YieldOp>(dstBlock->back());
// We can not use ValueRange as the operation holding the values will
// be destoryed.
SmallVector<Value> result(yield.getResults());
rewriter.eraseOp(yield);
return result;
});

aArgs = it->linkNewScope(aArgs);
ValueRange nx = it->forward(rewriter, loc);
SmallVector<Value> yields;
llvm::append_range(yields, nx);
llvm::append_range(yields, yieldOp.getResults());

// replace sparse_tensor.yield with scf.yield.
rewriter.eraseOp(yieldOp);
rewriter.create<scf::YieldOp>(loc, yields);
const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
rewriter.replaceOp(
op, whileOp.getResults().drop_front(it->getCursor().size()),
resultMapping);
}
const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
rewriter.replaceOp(op, ret, resultMapping);
return success();
}
};
Expand Down Expand Up @@ -366,9 +319,10 @@ class SparseCoIterateOpConverter
Block *block = &region.getBlocks().front();
OneToNTypeMapping blockTypeMapping(block->getArgumentTypes());
if (failed(typeConverter->convertSignatureArgs(block->getArgumentTypes(),
blockTypeMapping)))
blockTypeMapping))) {
return rewriter.notifyMatchFailure(
op, "failed to convert coiterate region argurment types");
}

rewriter.applySignatureConversion(block, blockTypeMapping);
}
Expand Down
Loading