Skip to content

Commit 2b5b3cf

Browse files
[mlir][sparse_tensor] Migrate SparseIterationToScf.cpp to dialect conversion (#121054)
Use the regular dialect conversion driver instead of the 1:N dialect conversion driver. The 1:N dialect conversion driver will be removed soon.
1 parent bca055f commit 2b5b3cf

File tree

2 files changed

+81
-53
lines changed

2 files changed

+81
-53
lines changed

mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp

Lines changed: 73 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,17 @@
77
#include "mlir/Dialect/SCF/IR/SCF.h"
88
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
99
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
10-
#include "mlir/Transforms/OneToNTypeConversion.h"
10+
#include "mlir/Transforms/DialectConversion.h"
1111

1212
using namespace mlir;
1313
using namespace mlir::sparse_tensor;
1414

15+
/// Assert that the given value range contains a single value and return it.
16+
static Value getSingleValue(ValueRange values) {
17+
assert(values.size() == 1 && "expected single value");
18+
return values.front();
19+
}
20+
1521
static void convertLevelType(SparseTensorEncodingAttr enc, Level lvl,
1622
SmallVectorImpl<Type> &fields) {
1723
// Position and coordinate buffer in the sparse structure.
@@ -54,14 +60,17 @@ static ValueRange
5460
genCoIterateBranchNest(PatternRewriter &rewriter, Location loc, CoIterateOp op,
5561
Value loopCrd,
5662
ArrayRef<std::unique_ptr<SparseIterator>> iters,
57-
ArrayRef<Region *> subCases, ArrayRef<Value> userReduc) {
58-
if (subCases.empty())
63+
ArrayRef<Block *> newBlocks, ArrayRef<Block *> oldBlocks,
64+
ArrayRef<Value> userReduc) {
65+
if (newBlocks.empty())
5966
return userReduc;
6067

6168
// The current branch that we are handling.
62-
Region *b = subCases.front();
69+
Block *newBlock = newBlocks.front();
70+
Block *oldBlock = oldBlocks.front();
6371
Value casePred = constantI1(rewriter, loc, true);
64-
I64BitSet caseBits = op.getRegionDefinedSpace(b->getRegionNumber());
72+
I64BitSet caseBits =
73+
op.getRegionDefinedSpace(newBlock->getParent()->getRegionNumber());
6574
for (unsigned i : caseBits.bits()) {
6675
SparseIterator *it = iters[i].get();
6776
Value pred = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
@@ -80,16 +89,20 @@ genCoIterateBranchNest(PatternRewriter &rewriter, Location loc, CoIterateOp op,
8089
for (unsigned idx : caseBits.bits())
8190
llvm::append_range(blockArgs, iters[idx]->getCursor());
8291

92+
// Map the old block arguments, because the dialect conversion driver does
93+
// not immediately perform SSA value replacements. This function is still
94+
// seeing the old uses.
8395
IRMapping mapping;
84-
for (auto [from, to] :
85-
llvm::zip_equal(b->front().getArguments(), blockArgs)) {
96+
for (auto [from, to] : llvm::zip_equal(oldBlock->getArguments(), blockArgs)) {
8697
mapping.map(from, to);
8798
}
8899

89100
// Clone the region, we can not erase the region now because the same region
90101
// might be a subcase for multiple lattice point.
91-
rewriter.cloneRegionBefore(*b, ifOp.getThenRegion(),
102+
rewriter.cloneRegionBefore(*newBlock->getParent(), ifOp.getThenRegion(),
92103
ifOp.getThenRegion().begin(), mapping);
104+
// Remove the block arguments, they were already replaced via `mapping`.
105+
ifOp.getThenRegion().front().eraseArguments(0, blockArgs.size());
93106

94107
// replace sparse_tensor::YieldOp -> scf::YieldOp
95108
auto spY = cast<sparse_tensor::YieldOp>(&ifOp.getThenRegion().front().back());
@@ -101,7 +114,8 @@ genCoIterateBranchNest(PatternRewriter &rewriter, Location loc, CoIterateOp op,
101114
// Generates remaining case recursively.
102115
rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
103116
ValueRange res = genCoIterateBranchNest(rewriter, loc, op, loopCrd, iters,
104-
subCases.drop_front(), userReduc);
117+
newBlocks.drop_front(),
118+
oldBlocks.drop_front(), userReduc);
105119
if (!res.empty())
106120
rewriter.create<scf::YieldOp>(loc, res);
107121

@@ -119,15 +133,13 @@ static ValueRange genLoopWithIterator(
119133
if (it->iteratableByFor()) {
120134
auto [lo, hi] = it->genForCond(rewriter, loc);
121135
Value step = constantIndex(rewriter, loc, 1);
122-
scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, lo, hi, step, reduc);
136+
scf::ForOp forOp = rewriter.create<scf::ForOp>(
137+
loc, lo, hi, step, reduc,
138+
[&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {
139+
// Empty builder function to ensure that no terminator is created.
140+
});
123141
{
124142
OpBuilder::InsertionGuard guard(rewriter);
125-
// Erase the implicit yield operation created by ForOp when there is no
126-
// yielding values.
127-
if (!forOp.getBody()->empty())
128-
rewriter.eraseOp(&forOp.getBody()->front());
129-
assert(forOp.getBody()->empty());
130-
131143
it->linkNewScope(forOp.getInductionVar());
132144
rewriter.setInsertionPointToStart(forOp.getBody());
133145
SmallVector<Value> ret = bodyBuilder(rewriter, loc, forOp.getBodyRegion(),
@@ -178,46 +190,47 @@ namespace {
178190

179191
/// Sparse codegen rule for number of entries operator.
180192
class ExtractIterSpaceConverter
181-
: public OneToNOpConversionPattern<ExtractIterSpaceOp> {
193+
: public OpConversionPattern<ExtractIterSpaceOp> {
182194
public:
183-
using OneToNOpConversionPattern::OneToNOpConversionPattern;
195+
using OpConversionPattern::OpConversionPattern;
184196
LogicalResult
185-
matchAndRewrite(ExtractIterSpaceOp op, OpAdaptor adaptor,
186-
OneToNPatternRewriter &rewriter) const override {
197+
matchAndRewrite(ExtractIterSpaceOp op, OneToNOpAdaptor adaptor,
198+
ConversionPatternRewriter &rewriter) const override {
187199
Location loc = op.getLoc();
188-
const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
189200

190201
// Construct the iteration space.
191-
SparseIterationSpace space(loc, rewriter, op.getTensor(), 0,
202+
SparseIterationSpace space(loc, rewriter,
203+
getSingleValue(adaptor.getTensor()), 0,
192204
op.getLvlRange(), adaptor.getParentIter());
193205

194206
SmallVector<Value> result = space.toValues();
195-
rewriter.replaceOp(op, result, resultMapping);
207+
rewriter.replaceOpWithMultiple(op, {result});
196208
return success();
197209
}
198210
};
199211

200212
/// Sparse codegen rule for number of entries operator.
201-
class ExtractValOpConverter : public OneToNOpConversionPattern<ExtractValOp> {
213+
class ExtractValOpConverter : public OpConversionPattern<ExtractValOp> {
202214
public:
203-
using OneToNOpConversionPattern::OneToNOpConversionPattern;
215+
using OpConversionPattern::OpConversionPattern;
204216
LogicalResult
205-
matchAndRewrite(ExtractValOp op, OpAdaptor adaptor,
206-
OneToNPatternRewriter &rewriter) const override {
217+
matchAndRewrite(ExtractValOp op, OneToNOpAdaptor adaptor,
218+
ConversionPatternRewriter &rewriter) const override {
207219
Location loc = op.getLoc();
208220
Value pos = adaptor.getIterator().back();
209-
Value valBuf = rewriter.create<ToValuesOp>(loc, op.getTensor());
221+
Value valBuf =
222+
rewriter.create<ToValuesOp>(loc, getSingleValue(adaptor.getTensor()));
210223
rewriter.replaceOpWithNewOp<memref::LoadOp>(op, valBuf, pos);
211224
return success();
212225
}
213226
};
214227

215-
class SparseIterateOpConverter : public OneToNOpConversionPattern<IterateOp> {
228+
class SparseIterateOpConverter : public OpConversionPattern<IterateOp> {
216229
public:
217-
using OneToNOpConversionPattern::OneToNOpConversionPattern;
230+
using OpConversionPattern::OpConversionPattern;
218231
LogicalResult
219-
matchAndRewrite(IterateOp op, OpAdaptor adaptor,
220-
OneToNPatternRewriter &rewriter) const override {
232+
matchAndRewrite(IterateOp op, OneToNOpAdaptor adaptor,
233+
ConversionPatternRewriter &rewriter) const override {
221234
if (!op.getCrdUsedLvls().empty())
222235
return rewriter.notifyMatchFailure(
223236
op, "non-empty coordinates list not implemented.");
@@ -235,14 +248,15 @@ class SparseIterateOpConverter : public OneToNOpConversionPattern<IterateOp> {
235248
llvm::append_range(ivs, inits);
236249

237250
// Type conversion on iterate op block.
238-
OneToNTypeMapping blockTypeMapping(op.getBody()->getArgumentTypes());
251+
unsigned numOrigArgs = op.getBody()->getArgumentTypes().size();
252+
TypeConverter::SignatureConversion signatureConversion(numOrigArgs);
239253
if (failed(typeConverter->convertSignatureArgs(
240-
op.getBody()->getArgumentTypes(), blockTypeMapping)))
254+
op.getBody()->getArgumentTypes(), signatureConversion)))
241255
return rewriter.notifyMatchFailure(
242256
op, "failed to convert iterate region argurment types");
243-
rewriter.applySignatureConversion(op.getBody(), blockTypeMapping);
244257

245-
Block *block = op.getBody();
258+
Block *block = rewriter.applySignatureConversion(
259+
op.getBody(), signatureConversion, getTypeConverter());
246260
ValueRange ret = genLoopWithIterator(
247261
rewriter, loc, it.get(), ivs,
248262
[block](PatternRewriter &rewriter, Location loc, Region &loopBody,
@@ -263,19 +277,17 @@ class SparseIterateOpConverter : public OneToNOpConversionPattern<IterateOp> {
263277
return result;
264278
});
265279

266-
const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
267-
rewriter.replaceOp(op, ret, resultMapping);
280+
rewriter.replaceOp(op, ret);
268281
return success();
269282
}
270283
};
271284

272-
class SparseCoIterateOpConverter
273-
: public OneToNOpConversionPattern<CoIterateOp> {
274-
using OneToNOpConversionPattern::OneToNOpConversionPattern;
285+
class SparseCoIterateOpConverter : public OpConversionPattern<CoIterateOp> {
286+
using OpConversionPattern::OpConversionPattern;
275287

276288
LogicalResult
277-
matchAndRewrite(CoIterateOp op, OpAdaptor adaptor,
278-
OneToNPatternRewriter &rewriter) const override {
289+
matchAndRewrite(CoIterateOp op, OneToNOpAdaptor adaptor,
290+
ConversionPatternRewriter &rewriter) const override {
279291
assert(op.getSpaceDim() == 1 && "Not implemented");
280292
Location loc = op.getLoc();
281293

@@ -299,18 +311,23 @@ class SparseCoIterateOpConverter
299311
assert(!needUniv && "Not implemented");
300312
(void)needUniv;
301313

314+
SmallVector<Block *> newBlocks;
315+
DenseMap<Block *, Block *> newToOldBlockMap;
302316
for (Region &region : op.getCaseRegions()) {
303317
// Do a one-shot type conversion on all region blocks, since the same
304318
// region might be used multiple time.
305319
Block *block = &region.getBlocks().front();
306-
OneToNTypeMapping blockTypeMapping(block->getArgumentTypes());
320+
TypeConverter::SignatureConversion blockTypeMapping(
321+
block->getArgumentTypes().size());
307322
if (failed(typeConverter->convertSignatureArgs(block->getArgumentTypes(),
308323
blockTypeMapping))) {
309324
return rewriter.notifyMatchFailure(
310325
op, "failed to convert coiterate region argurment types");
311326
}
312327

313-
rewriter.applySignatureConversion(block, blockTypeMapping);
328+
newBlocks.push_back(rewriter.applySignatureConversion(
329+
block, blockTypeMapping, getTypeConverter()));
330+
newToOldBlockMap[newBlocks.back()] = block;
314331
}
315332

316333
SmallVector<SparseIterationSpace> spaces;
@@ -343,7 +360,7 @@ class SparseCoIterateOpConverter
343360

344361
// Generates a loop sequence, one loop per case.
345362
for (auto [r, caseBits] :
346-
llvm::zip_equal(op.getCaseRegions(), op.getRegionDefinedSpaces())) {
363+
llvm::zip_equal(newBlocks, op.getRegionDefinedSpaces())) {
347364
assert(caseBits.count() > 0 && "Complement space not implemented");
348365

349366
// Retrives a vector of pointers to the iterators used in the case.
@@ -359,11 +376,17 @@ class SparseCoIterateOpConverter
359376
// The subcases are never empty, it must contains at least the current
360377
// region itself.
361378
// TODO: these cases should be sorted.
362-
SmallVector<Region *> subCases = op.getSubCasesOf(r.getRegionNumber());
379+
SmallVector<Region *> subCases =
380+
op.getSubCasesOf(r->getParent()->getRegionNumber());
381+
SmallVector<Block *> newBlocks, oldBlocks;
382+
for (Region *r : subCases) {
383+
newBlocks.push_back(&r->front());
384+
oldBlocks.push_back(newToOldBlockMap[newBlocks.back()]);
385+
}
363386
assert(!subCases.empty());
364387

365-
ValueRange res = genCoIterateBranchNest(rewriter, loc, op, loopCrd,
366-
iters, subCases, userReduc);
388+
ValueRange res = genCoIterateBranchNest(
389+
rewriter, loc, op, loopCrd, iters, newBlocks, oldBlocks, userReduc);
367390

368391
SmallVector<Value> nextIterYields(res);
369392
// 2nd. foward the loop.
@@ -388,7 +411,7 @@ class SparseCoIterateOpConverter
388411
// This is a simple iteration loop.
389412
assert(caseBits.count() == 1);
390413

391-
Block *block = &r.getBlocks().front();
414+
Block *block = r;
392415
ValueRange curResult = genLoopWithIterator(
393416
rewriter, loc, validIters.front(), userReduc,
394417
/*bodyBuilder=*/

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -172,11 +172,16 @@ struct LowerSparseIterationToSCFPass
172172
ConversionTarget target(*ctx);
173173

174174
// The actual conversion.
175-
target.addIllegalOp<ExtractIterSpaceOp, IterateOp>();
175+
target.addLegalDialect<arith::ArithDialect, linalg::LinalgDialect,
176+
memref::MemRefDialect, scf::SCFDialect,
177+
sparse_tensor::SparseTensorDialect>();
178+
target.addIllegalOp<CoIterateOp, ExtractIterSpaceOp, ExtractValOp,
179+
IterateOp>();
180+
target.addLegalOp<UnrealizedConversionCastOp>();
176181
populateLowerSparseIterationToSCFPatterns(converter, patterns);
177182

178-
if (failed(applyPartialOneToNConversion(getOperation(), converter,
179-
std::move(patterns))))
183+
if (failed(applyPartialConversion(getOperation(), target,
184+
std::move(patterns))))
180185
signalPassFailure();
181186
}
182187
};

0 commit comments

Comments
 (0)