7
7
#include " mlir/Dialect/SCF/IR/SCF.h"
8
8
#include " mlir/Dialect/SparseTensor/IR/SparseTensor.h"
9
9
#include " mlir/Dialect/SparseTensor/Transforms/Passes.h"
10
- #include " mlir/Transforms/OneToNTypeConversion .h"
10
+ #include " mlir/Transforms/DialectConversion .h"
11
11
12
12
using namespace mlir ;
13
13
using namespace mlir ::sparse_tensor;
14
14
15
+ // / Assert that the given value range contains a single value and return it.
// Callers in this file use it to unwrap an operand that the OneToNOpAdaptor
// exposes as a value range (e.g. `adaptor.getTensor ()`) when exactly one
// converted value is expected. The size check is an `assert` only; there is no
// other error path.
16
+ static Value getSingleValue (ValueRange values) {
17
+ assert (values.size () == 1 && " expected single value" );
18
+ return values.front ();
19
+ }
20
+
15
21
static void convertLevelType (SparseTensorEncodingAttr enc, Level lvl,
16
22
SmallVectorImpl<Type> &fields) {
17
23
// Position and coordinate buffer in the sparse structure.
@@ -54,14 +60,17 @@ static ValueRange
54
60
genCoIterateBranchNest (PatternRewriter &rewriter, Location loc, CoIterateOp op,
55
61
Value loopCrd,
56
62
ArrayRef<std::unique_ptr<SparseIterator>> iters,
57
- ArrayRef<Region *> subCases, ArrayRef<Value> userReduc) {
58
- if (subCases.empty ())
63
+ ArrayRef<Block *> newBlocks, ArrayRef<Block *> oldBlocks,
64
+ ArrayRef<Value> userReduc) {
65
+ if (newBlocks.empty ())
59
66
return userReduc;
60
67
61
68
// The current branch that we are handling.
62
- Region *b = subCases.front ();
69
+ Block *newBlock = newBlocks.front ();
70
+ Block *oldBlock = oldBlocks.front ();
63
71
Value casePred = constantI1 (rewriter, loc, true );
64
- I64BitSet caseBits = op.getRegionDefinedSpace (b->getRegionNumber ());
72
+ I64BitSet caseBits =
73
+ op.getRegionDefinedSpace (newBlock->getParent ()->getRegionNumber ());
65
74
for (unsigned i : caseBits.bits ()) {
66
75
SparseIterator *it = iters[i].get ();
67
76
Value pred = rewriter.create <arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
@@ -80,16 +89,20 @@ genCoIterateBranchNest(PatternRewriter &rewriter, Location loc, CoIterateOp op,
80
89
for (unsigned idx : caseBits.bits ())
81
90
llvm::append_range (blockArgs, iters[idx]->getCursor ());
82
91
92
+ // Map the old block arguments, because the dialect conversion driver does
93
+ // not immediately perform SSA value replacements. This function is still
94
+ // seeing the old uses.
83
95
IRMapping mapping;
84
- for (auto [from, to] :
85
- llvm::zip_equal (b->front ().getArguments (), blockArgs)) {
96
+ for (auto [from, to] : llvm::zip_equal (oldBlock->getArguments (), blockArgs)) {
86
97
mapping.map (from, to);
87
98
}
88
99
89
100
// Clone the region, we can not erase the region now because the same region
90
101
// might be a subcase for multiple lattice points.
91
- rewriter.cloneRegionBefore (*b , ifOp.getThenRegion (),
102
+ rewriter.cloneRegionBefore (*newBlock-> getParent () , ifOp.getThenRegion (),
92
103
ifOp.getThenRegion ().begin (), mapping);
104
+ // Remove the block arguments, they were already replaced via `mapping`.
105
+ ifOp.getThenRegion ().front ().eraseArguments (0 , blockArgs.size ());
93
106
94
107
// replace sparse_tensor::YieldOp -> scf::YieldOp
95
108
auto spY = cast<sparse_tensor::YieldOp>(&ifOp.getThenRegion ().front ().back ());
@@ -101,7 +114,8 @@ genCoIterateBranchNest(PatternRewriter &rewriter, Location loc, CoIterateOp op,
101
114
// Generates remaining case recursively.
102
115
rewriter.setInsertionPointToStart (&ifOp.getElseRegion ().front ());
103
116
ValueRange res = genCoIterateBranchNest (rewriter, loc, op, loopCrd, iters,
104
- subCases.drop_front (), userReduc);
117
+ newBlocks.drop_front (),
118
+ oldBlocks.drop_front (), userReduc);
105
119
if (!res.empty ())
106
120
rewriter.create <scf::YieldOp>(loc, res);
107
121
@@ -119,15 +133,13 @@ static ValueRange genLoopWithIterator(
119
133
if (it->iteratableByFor ()) {
120
134
auto [lo, hi] = it->genForCond (rewriter, loc);
121
135
Value step = constantIndex (rewriter, loc, 1 );
122
- scf::ForOp forOp = rewriter.create <scf::ForOp>(loc, lo, hi, step, reduc);
136
+ scf::ForOp forOp = rewriter.create <scf::ForOp>(
137
+ loc, lo, hi, step, reduc,
138
+ [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {
139
+ // Empty builder function to ensure that no terminator is created.
140
+ });
123
141
{
124
142
OpBuilder::InsertionGuard guard (rewriter);
125
- // Erase the implicit yield operation created by ForOp when there is no
126
- // yielding values.
127
- if (!forOp.getBody ()->empty ())
128
- rewriter.eraseOp (&forOp.getBody ()->front ());
129
- assert (forOp.getBody ()->empty ());
130
-
131
143
it->linkNewScope (forOp.getInductionVar ());
132
144
rewriter.setInsertionPointToStart (forOp.getBody ());
133
145
SmallVector<Value> ret = bodyBuilder (rewriter, loc, forOp.getBodyRegion (),
@@ -178,46 +190,47 @@ namespace {
178
190
179
191
/// Sparse codegen rule for the extract iteration space operator.
// The pattern builds a SparseIterationSpace from the converted tensor operand,
// the op's level range and the converted parent iterator, then replaces the op
// with the values returned by `space.toValues ()`.
180
192
class ExtractIterSpaceConverter
181
- : public OneToNOpConversionPattern <ExtractIterSpaceOp> {
193
+ : public OpConversionPattern <ExtractIterSpaceOp> {
182
194
public:
183
- using OneToNOpConversionPattern::OneToNOpConversionPattern ;
195
+ using OpConversionPattern::OpConversionPattern ;
184
196
LogicalResult
185
- matchAndRewrite (ExtractIterSpaceOp op, OpAdaptor adaptor,
186
- OneToNPatternRewriter &rewriter) const override {
197
+ matchAndRewrite (ExtractIterSpaceOp op, OneToNOpAdaptor adaptor,
198
+ ConversionPatternRewriter &rewriter) const override {
187
199
Location loc = op.getLoc ();
188
- const OneToNTypeMapping &resultMapping = adaptor.getResultMapping ();
189
200
190
201
// Construct the iteration space.
// The tensor operand is read through the adaptor and must be exactly one
// value (enforced by getSingleValue); the parent iterator is passed on as the
// adaptor provides it.
191
- SparseIterationSpace space (loc, rewriter, op.getTensor (), 0 ,
202
+ SparseIterationSpace space (loc, rewriter,
203
+ getSingleValue (adaptor.getTensor ()), 0 ,
192
204
op.getLvlRange (), adaptor.getParentIter ());
193
205
194
206
SmallVector<Value> result = space.toValues ();
// Replace the op with the full list of values produced for the space.
195
- rewriter.replaceOp (op, result, resultMapping );
207
+ rewriter.replaceOpWithMultiple (op, { result} );
196
208
return success ();
197
209
}
198
210
};
199
211
200
212
/// Sparse codegen rule for the extract value operator.
// The pattern replaces the op with a memref::LoadOp that reads the tensor's
// values buffer (obtained via ToValuesOp) at the position given by the last
// value of the converted iterator.
201
- class ExtractValOpConverter : public OneToNOpConversionPattern <ExtractValOp> {
213
+ class ExtractValOpConverter : public OpConversionPattern <ExtractValOp> {
202
214
public:
203
- using OneToNOpConversionPattern::OneToNOpConversionPattern ;
215
+ using OpConversionPattern::OpConversionPattern ;
204
216
LogicalResult
205
- matchAndRewrite (ExtractValOp op, OpAdaptor adaptor,
206
- OneToNPatternRewriter &rewriter) const override {
217
+ matchAndRewrite (ExtractValOp op, OneToNOpAdaptor adaptor,
218
+ ConversionPatternRewriter &rewriter) const override {
207
219
Location loc = op.getLoc ();
// The load index is the last of the values the iterator was converted to.
208
220
Value pos = adaptor.getIterator ().back ();
// The tensor operand must be exactly one value (enforced by getSingleValue).
209
- Value valBuf = rewriter.create <ToValuesOp>(loc, op.getTensor ());
221
+ Value valBuf =
222
+ rewriter.create <ToValuesOp>(loc, getSingleValue (adaptor.getTensor ()));
210
223
rewriter.replaceOpWithNewOp <memref::LoadOp>(op, valBuf, pos);
211
224
return success ();
212
225
}
213
226
};
214
227
215
- class SparseIterateOpConverter : public OneToNOpConversionPattern <IterateOp> {
228
+ class SparseIterateOpConverter : public OpConversionPattern <IterateOp> {
216
229
public:
217
- using OneToNOpConversionPattern::OneToNOpConversionPattern ;
230
+ using OpConversionPattern::OpConversionPattern ;
218
231
LogicalResult
219
- matchAndRewrite (IterateOp op, OpAdaptor adaptor,
220
- OneToNPatternRewriter &rewriter) const override {
232
+ matchAndRewrite (IterateOp op, OneToNOpAdaptor adaptor,
233
+ ConversionPatternRewriter &rewriter) const override {
221
234
if (!op.getCrdUsedLvls ().empty ())
222
235
return rewriter.notifyMatchFailure (
223
236
op, " non-empty coordinates list not implemented." );
@@ -235,14 +248,15 @@ class SparseIterateOpConverter : public OneToNOpConversionPattern<IterateOp> {
235
248
llvm::append_range (ivs, inits);
236
249
237
250
// Type conversion on iterate op block.
238
- OneToNTypeMapping blockTypeMapping (op.getBody ()->getArgumentTypes ());
251
+ unsigned numOrigArgs = op.getBody ()->getArgumentTypes ().size ();
252
+ TypeConverter::SignatureConversion signatureConversion (numOrigArgs);
239
253
if (failed (typeConverter->convertSignatureArgs (
240
- op.getBody ()->getArgumentTypes (), blockTypeMapping )))
254
+ op.getBody ()->getArgumentTypes (), signatureConversion )))
241
255
return rewriter.notifyMatchFailure (
242
256
op, " failed to convert iterate region argurment types" );
243
- rewriter.applySignatureConversion (op.getBody (), blockTypeMapping);
244
257
245
- Block *block = op.getBody ();
258
+ Block *block = rewriter.applySignatureConversion (
259
+ op.getBody (), signatureConversion, getTypeConverter ());
246
260
ValueRange ret = genLoopWithIterator (
247
261
rewriter, loc, it.get (), ivs,
248
262
[block](PatternRewriter &rewriter, Location loc, Region &loopBody,
@@ -263,19 +277,17 @@ class SparseIterateOpConverter : public OneToNOpConversionPattern<IterateOp> {
263
277
return result;
264
278
});
265
279
266
- const OneToNTypeMapping &resultMapping = adaptor.getResultMapping ();
267
- rewriter.replaceOp (op, ret, resultMapping);
280
+ rewriter.replaceOp (op, ret);
268
281
return success ();
269
282
}
270
283
};
271
284
272
- class SparseCoIterateOpConverter
273
- : public OneToNOpConversionPattern<CoIterateOp> {
274
- using OneToNOpConversionPattern::OneToNOpConversionPattern;
285
+ class SparseCoIterateOpConverter : public OpConversionPattern <CoIterateOp> {
286
+ using OpConversionPattern::OpConversionPattern;
275
287
276
288
LogicalResult
277
- matchAndRewrite (CoIterateOp op, OpAdaptor adaptor,
278
- OneToNPatternRewriter &rewriter) const override {
289
+ matchAndRewrite (CoIterateOp op, OneToNOpAdaptor adaptor,
290
+ ConversionPatternRewriter &rewriter) const override {
279
291
assert (op.getSpaceDim () == 1 && " Not implemented" );
280
292
Location loc = op.getLoc ();
281
293
@@ -299,18 +311,23 @@ class SparseCoIterateOpConverter
299
311
assert (!needUniv && " Not implemented" );
300
312
(void )needUniv;
301
313
314
+ SmallVector<Block *> newBlocks;
315
+ DenseMap<Block *, Block *> newToOldBlockMap;
302
316
for (Region &region : op.getCaseRegions ()) {
303
317
// Do a one-shot type conversion on all region blocks, since the same
304
318
// region might be used multiple times.
305
319
Block *block = &region.getBlocks ().front ();
306
- OneToNTypeMapping blockTypeMapping (block->getArgumentTypes ());
320
+ TypeConverter::SignatureConversion blockTypeMapping (
321
+ block->getArgumentTypes ().size ());
307
322
if (failed (typeConverter->convertSignatureArgs (block->getArgumentTypes (),
308
323
blockTypeMapping))) {
309
324
return rewriter.notifyMatchFailure (
310
325
op, " failed to convert coiterate region argurment types" );
311
326
}
312
327
313
- rewriter.applySignatureConversion (block, blockTypeMapping);
328
+ newBlocks.push_back (rewriter.applySignatureConversion (
329
+ block, blockTypeMapping, getTypeConverter ()));
330
+ newToOldBlockMap[newBlocks.back ()] = block;
314
331
}
315
332
316
333
SmallVector<SparseIterationSpace> spaces;
@@ -343,7 +360,7 @@ class SparseCoIterateOpConverter
343
360
344
361
// Generates a loop sequence, one loop per case.
345
362
for (auto [r, caseBits] :
346
- llvm::zip_equal (op. getCaseRegions () , op.getRegionDefinedSpaces ())) {
363
+ llvm::zip_equal (newBlocks , op.getRegionDefinedSpaces ())) {
347
364
assert (caseBits.count () > 0 && " Complement space not implemented" );
348
365
349
366
// Retrieves a vector of pointers to the iterators used in the case.
@@ -359,11 +376,17 @@ class SparseCoIterateOpConverter
359
376
// The subcases are never empty; they must contain at least the current
360
377
// region itself.
361
378
// TODO: these cases should be sorted.
362
- SmallVector<Region *> subCases = op.getSubCasesOf (r.getRegionNumber ());
379
+ SmallVector<Region *> subCases =
380
+ op.getSubCasesOf (r->getParent ()->getRegionNumber ());
381
+ SmallVector<Block *> newBlocks, oldBlocks;
382
+ for (Region *r : subCases) {
383
+ newBlocks.push_back (&r->front ());
384
+ oldBlocks.push_back (newToOldBlockMap[newBlocks.back ()]);
385
+ }
363
386
assert (!subCases.empty ());
364
387
365
- ValueRange res = genCoIterateBranchNest (rewriter, loc, op, loopCrd,
366
- iters, subCases , userReduc);
388
+ ValueRange res = genCoIterateBranchNest (
389
+ rewriter, loc, op, loopCrd, iters, newBlocks, oldBlocks , userReduc);
367
390
368
391
SmallVector<Value> nextIterYields (res);
369
392
// 2nd. forward the loop.
@@ -388,7 +411,7 @@ class SparseCoIterateOpConverter
388
411
// This is a simple iteration loop.
389
412
assert (caseBits.count () == 1 );
390
413
391
- Block *block = &r. getBlocks (). front () ;
414
+ Block *block = r ;
392
415
ValueRange curResult = genLoopWithIterator (
393
416
rewriter, loc, validIters.front (), userReduc,
394
417
/* bodyBuilder=*/
0 commit comments