Skip to content

Commit 7186704

Browse files
author
Peiming Liu
authored
[mlir][sparse] refactoring sparse_tensor.iterate lowering pattern implementation. (#105566)
1 parent a968ae6 commit 7186704

File tree

1 file changed

+36
-82
lines changed

1 file changed

+36
-82
lines changed

mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp

Lines changed: 36 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -244,88 +244,41 @@ class SparseIterateOpConverter : public OneToNOpConversionPattern<IterateOp> {
244244
std::unique_ptr<SparseIterator> it =
245245
iterSpace.extractIterator(rewriter, loc);
246246

247-
if (it->iteratableByFor()) {
248-
auto [lo, hi] = it->genForCond(rewriter, loc);
249-
Value step = constantIndex(rewriter, loc, 1);
250-
SmallVector<Value> ivs;
251-
for (ValueRange inits : adaptor.getInitArgs())
252-
llvm::append_range(ivs, inits);
253-
scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, lo, hi, step, ivs);
254-
255-
Block *loopBody = op.getBody();
256-
OneToNTypeMapping bodyTypeMapping(loopBody->getArgumentTypes());
257-
if (failed(typeConverter->convertSignatureArgs(
258-
loopBody->getArgumentTypes(), bodyTypeMapping)))
259-
return failure();
260-
rewriter.applySignatureConversion(loopBody, bodyTypeMapping);
261-
262-
rewriter.eraseBlock(forOp.getBody());
263-
Region &dstRegion = forOp.getRegion();
264-
rewriter.inlineRegionBefore(op.getRegion(), dstRegion, dstRegion.end());
265-
266-
auto yieldOp =
267-
llvm::cast<sparse_tensor::YieldOp>(forOp.getBody()->getTerminator());
268-
269-
rewriter.setInsertionPointToEnd(forOp.getBody());
270-
// replace sparse_tensor.yield with scf.yield.
271-
rewriter.create<scf::YieldOp>(loc, yieldOp.getResults());
272-
rewriter.eraseOp(yieldOp);
273-
274-
const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
275-
rewriter.replaceOp(op, forOp.getResults(), resultMapping);
276-
} else {
277-
SmallVector<Value> ivs;
278-
// TODO: put iterator at the end of argument list to be consistent with
279-
// coiterate operation.
280-
llvm::append_range(ivs, it->getCursor());
281-
for (ValueRange inits : adaptor.getInitArgs())
282-
llvm::append_range(ivs, inits);
283-
284-
assert(llvm::all_of(ivs, [](Value v) { return v != nullptr; }));
285-
286-
TypeRange types = ValueRange(ivs).getTypes();
287-
auto whileOp = rewriter.create<scf::WhileOp>(loc, types, ivs);
288-
SmallVector<Location> l(types.size(), op.getIterator().getLoc());
289-
290-
// Generates loop conditions.
291-
Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, types, l);
292-
rewriter.setInsertionPointToStart(before);
293-
ValueRange bArgs = before->getArguments();
294-
auto [whileCond, remArgs] = it->genWhileCond(rewriter, loc, bArgs);
295-
assert(remArgs.size() == adaptor.getInitArgs().size());
296-
rewriter.create<scf::ConditionOp>(loc, whileCond, before->getArguments());
297-
298-
// Generates loop body.
299-
Block *loopBody = op.getBody();
300-
OneToNTypeMapping bodyTypeMapping(loopBody->getArgumentTypes());
301-
if (failed(typeConverter->convertSignatureArgs(
302-
loopBody->getArgumentTypes(), bodyTypeMapping)))
303-
return failure();
304-
rewriter.applySignatureConversion(loopBody, bodyTypeMapping);
305-
306-
Region &dstRegion = whileOp.getAfter();
307-
// TODO: handle uses of coordinate!
308-
rewriter.inlineRegionBefore(op.getRegion(), dstRegion, dstRegion.end());
309-
ValueRange aArgs = whileOp.getAfterArguments();
310-
auto yieldOp = llvm::cast<sparse_tensor::YieldOp>(
311-
whileOp.getAfterBody()->getTerminator());
312-
313-
rewriter.setInsertionPointToEnd(whileOp.getAfterBody());
247+
SmallVector<Value> ivs;
248+
for (ValueRange inits : adaptor.getInitArgs())
249+
llvm::append_range(ivs, inits);
250+
251+
// Type conversion on iterate op block.
252+
OneToNTypeMapping blockTypeMapping(op.getBody()->getArgumentTypes());
253+
if (failed(typeConverter->convertSignatureArgs(
254+
op.getBody()->getArgumentTypes(), blockTypeMapping)))
255+
return rewriter.notifyMatchFailure(
256+
op, "failed to convert iterate region argurment types");
257+
rewriter.applySignatureConversion(op.getBody(), blockTypeMapping);
258+
259+
Block *block = op.getBody();
260+
ValueRange ret = genLoopWithIterator(
261+
rewriter, loc, it.get(), ivs, /*iterFirst=*/true,
262+
[block](PatternRewriter &rewriter, Location loc, Region &loopBody,
263+
SparseIterator *it, ValueRange reduc) -> SmallVector<Value> {
264+
SmallVector<Value> blockArgs(it->getCursor());
265+
// TODO: Also appends coordinates if used.
266+
// blockArgs.push_back(it->deref(rewriter, loc));
267+
llvm::append_range(blockArgs, reduc);
268+
269+
Block *dstBlock = &loopBody.getBlocks().front();
270+
rewriter.inlineBlockBefore(block, dstBlock, dstBlock->end(),
271+
blockArgs);
272+
auto yield = llvm::cast<sparse_tensor::YieldOp>(dstBlock->back());
273+
// We can not use ValueRange as the operation holding the values will
274+
// be destoryed.
275+
SmallVector<Value> result(yield.getResults());
276+
rewriter.eraseOp(yield);
277+
return result;
278+
});
314279

315-
aArgs = it->linkNewScope(aArgs);
316-
ValueRange nx = it->forward(rewriter, loc);
317-
SmallVector<Value> yields;
318-
llvm::append_range(yields, nx);
319-
llvm::append_range(yields, yieldOp.getResults());
320-
321-
// replace sparse_tensor.yield with scf.yield.
322-
rewriter.eraseOp(yieldOp);
323-
rewriter.create<scf::YieldOp>(loc, yields);
324-
const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
325-
rewriter.replaceOp(
326-
op, whileOp.getResults().drop_front(it->getCursor().size()),
327-
resultMapping);
328-
}
280+
const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
281+
rewriter.replaceOp(op, ret, resultMapping);
329282
return success();
330283
}
331284
};
@@ -366,9 +319,10 @@ class SparseCoIterateOpConverter
366319
Block *block = &region.getBlocks().front();
367320
OneToNTypeMapping blockTypeMapping(block->getArgumentTypes());
368321
if (failed(typeConverter->convertSignatureArgs(block->getArgumentTypes(),
369-
blockTypeMapping)))
322+
blockTypeMapping))) {
370323
return rewriter.notifyMatchFailure(
371324
op, "failed to convert coiterate region argurment types");
325+
}
372326

373327
rewriter.applySignatureConversion(block, blockTypeMapping);
374328
}

0 commit comments

Comments
 (0)