 #include "Utils/CodegenUtils.h"
+#include "Utils/LoopEmitter.h"
 #include "Utils/SparseTensorIterator.h"

 #include "mlir/Dialect/MemRef/IR/MemRef.h"
@@ -49,6 +50,144 @@ convertIteratorType(IteratorType itTp, SmallVectorImpl<Type> &fields) {
   return success();
 }

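+/// Emits a nest of scf.if branches, one per sub-case region. Each branch is
+/// guarded by a predicate checking that every iterator used by the case
+/// points at the current loop coordinate; the case body is cloned into the
+/// taken branch, and the remaining cases are emitted into the else branch.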
+static ValueRange
+genCoIterateBranchNest(PatternRewriter &rewriter, Location loc, CoIterateOp op,
+                       Value loopCrd,
+                       ArrayRef<std::unique_ptr<SparseIterator>> iters,
+                       ArrayRef<Region *> subCases, ArrayRef<Value> userReduc) {
+  if (subCases.empty())
+    return userReduc;
+
+  // The current branch that we are handling.
+  Region *b = subCases.front();
+  Value casePred = constantI1(rewriter, loc, true);
+  I64BitSet caseBits = op.getRegionDefinedSpace(b->getRegionNumber());
+  for (unsigned i : caseBits.bits()) {
+    SparseIterator *it = iters[i].get();
+    Value pred = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
+                                                it->getCrd(), loopCrd);
+    casePred = rewriter.create<arith::AndIOp>(loc, casePred, pred);
+  }
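+  // Branch on the conjoined predicate; the remaining cases are generated
+  // recursively into the else branch below.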
+  scf::IfOp ifOp = rewriter.create<scf::IfOp>(
+      loc, ValueRange(userReduc).getTypes(), casePred, /*else=*/true);
+  rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
+
+  // Erase the empty block.
+  rewriter.eraseBlock(&ifOp.getThenRegion().front());
+  // Set up block arguments: user-provided values -> loop coord -> iterators.
+  SmallVector<Value> blockArgs(userReduc);
+  blockArgs.push_back(loopCrd);
+  for (unsigned idx : caseBits.bits())
+    llvm::append_range(blockArgs, iters[idx]->getCursor());
+
+  IRMapping mapping;
+  for (auto [from, to] :
+       llvm::zip_equal(b->front().getArguments(), blockArgs)) {
+    mapping.map(from, to);
+  }
+
+  // Clone the region; we cannot erase it yet because the same region might
+  // be a subcase for multiple lattice points.
+  rewriter.cloneRegionBefore(*b, ifOp.getThenRegion(),
+                             ifOp.getThenRegion().begin(), mapping);
+
+  // Replace the sparse_tensor::YieldOp with an scf::YieldOp.
+  auto spY =
+      cast<sparse_tensor::YieldOp>(&ifOp.getThenRegion().front().back());
+  ValueRange yields = spY.getResults();
+  rewriter.eraseOp(spY);
+  rewriter.setInsertionPointToEnd(&ifOp.getThenRegion().front());
+  rewriter.create<scf::YieldOp>(loc, yields);
+
+  // Generate the remaining cases recursively.
+  rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
+  ValueRange res = genCoIterateBranchNest(rewriter, loc, op, loopCrd, iters,
+                                          subCases.drop_front(), userReduc);
+  if (!res.empty())
+    rewriter.create<scf::YieldOp>(loc, res);
+
+  rewriter.setInsertionPointAfter(ifOp);
+  return ifOp.getResults();
+}
+
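+/// Generates a loop that traverses the given iterator (an scf.for when the
+/// iterator supports traversal by a simple range, an scf.while otherwise),
+/// delegating body emission to `bodyBuilder`. `iterFirst` controls whether
+/// the iterator cursor values are placed before or after the user-provided
+/// reduction values in the loop argument list.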
+static ValueRange genLoopWithIterator(
+    PatternRewriter &rewriter, Location loc, SparseIterator *it,
+    ValueRange reduc, bool iterFirst,
+    function_ref<SmallVector<Value>(PatternRewriter &rewriter, Location loc,
+                                    Region &loopBody, SparseIterator *it,
+                                    ValueRange reduc)>
+        bodyBuilder) {
+  if (it->iteratableByFor()) {
+    auto [lo, hi] = it->genForCond(rewriter, loc);
+    Value step = constantIndex(rewriter, loc, 1);
+    scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, lo, hi, step, reduc);
+    {
+      OpBuilder::InsertionGuard guard(rewriter);
+      // Erase the implicit yield operation created by ForOp when there are no
+      // yielded values.
+      if (!forOp.getBody()->empty())
+        rewriter.eraseOp(&forOp.getBody()->front());
+      assert(forOp.getBody()->empty());
+
+      it->linkNewScope(forOp.getInductionVar());
+      rewriter.setInsertionPointToStart(forOp.getBody());
+      SmallVector<Value> ret = bodyBuilder(rewriter, loc, forOp.getBodyRegion(),
+                                           it, forOp.getRegionIterArgs());
+
+      rewriter.setInsertionPointToEnd(forOp.getBody());
+      rewriter.create<scf::YieldOp>(loc, ret);
+    }
+    return forOp.getResults();
+  }
+  SmallVector<Value> ivs;
+  // TODO: always put iterator SSA values at the end of the argument list, to
+  // be consistent with the coiterate operation.
+  if (iterFirst)
+    llvm::append_range(ivs, it->getCursor());
+  // Appends the user-provided values.
+  llvm::append_range(ivs, reduc);
+  if (!iterFirst)
+    llvm::append_range(ivs, it->getCursor());
+
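+  // Emits an scf.while loop: the before-region tests the iterator's
+  // continuation condition; the after-region hosts the user body and then
+  // forwards the iterator.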
+  TypeRange types = ValueRange(ivs).getTypes();
+  auto whileOp = rewriter.create<scf::WhileOp>(loc, types, ivs);
+  {
+    OpBuilder::InsertionGuard guard(rewriter);
+    // Generates the loop condition.
+    SmallVector<Location> l(types.size(), loc);
+    Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, types, l);
+    rewriter.setInsertionPointToStart(before);
+    ValueRange bArgs = before->getArguments();
+    auto [whileCond, remArgs] = it->genWhileCond(
+        rewriter, loc, iterFirst ? bArgs : bArgs.drop_front(reduc.size()));
+    rewriter.create<scf::ConditionOp>(loc, whileCond, before->getArguments());
+
+    // Delegates loop body generation.
+    Region &dstRegion = whileOp.getAfter();
+    Block *after = rewriter.createBlock(&dstRegion, {}, types, l);
+    ValueRange aArgs = whileOp.getAfterArguments();
+    if (iterFirst) {
+      aArgs = it->linkNewScope(aArgs);
+    } else {
+      it->linkNewScope(aArgs.drop_front(reduc.size()));
+      aArgs = aArgs.take_front(reduc.size());
+    }
+
+    rewriter.setInsertionPointToStart(after);
+    SmallVector<Value> ret = bodyBuilder(rewriter, loc, dstRegion, it, aArgs);
+    rewriter.setInsertionPointToEnd(after);
+
+    // Forward the loop.
+    SmallVector<Value> yields;
+    ValueRange nx = it->forward(rewriter, loc);
+    if (iterFirst)
+      llvm::append_range(yields, nx);
+    llvm::append_range(yields, ret);
+    if (!iterFirst)
+      llvm::append_range(yields, nx);
+    rewriter.create<scf::YieldOp>(loc, yields);
+  }
+  return iterFirst ? whileOp.getResults().drop_front(it->getCursor().size())
+                   : whileOp.getResults().take_front(reduc.size());
+}
+
 namespace {

 /// Sparse codegen rule for number of entries operator.
@@ -136,6 +275,8 @@ class SparseIterateOpConverter : public OneToNOpConversionPattern<IterateOp> {
       rewriter.replaceOp(op, forOp.getResults(), resultMapping);
     } else {
       SmallVector<Value> ivs;
+      // TODO: put the iterator at the end of the argument list to be
+      // consistent with the coiterate operation.
       llvm::append_range(ivs, it->getCursor());
       for (ValueRange inits : adaptor.getInitArgs())
         llvm::append_range(ivs, inits);
@@ -189,6 +330,153 @@ class SparseIterateOpConverter : public OneToNOpConversionPattern<IterateOp> {
   }
 };

+class SparseCoIterateOpConverter
+    : public OneToNOpConversionPattern<CoIterateOp> {
+  using OneToNOpConversionPattern::OneToNOpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(CoIterateOp op, OpAdaptor adaptor,
+                  OneToNPatternRewriter &rewriter) const override {
+    assert(op.getSpaceDim() == 1 && "Not implemented");
+    Location loc = op.getLoc();
+
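+    // Compute the set of iteration spaces whose levels are all dense.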
+    I64BitSet denseBits(0);
+    for (auto [idx, spaceTp] : llvm::enumerate(op.getIterSpaces().getTypes()))
+      if (all_of(cast<IterSpaceType>(spaceTp).getLvlTypes(), isDenseLT))
+        denseBits.set(idx);
+
+    // If there exists a case that only contains dense spaces (i.e., the case
+    // bits are a subset of the dense bits), or if there is a fully empty case
+    // (due to complements), we need a universal pointer to forward the
+    // coiteration loop.
+    bool needUniv =
+        any_of(op.getRegionDefinedSpaces(), [denseBits](I64BitSet caseBits) {
+          // A case for complement.
+          if (caseBits.count() == 0)
+            return true;
+          // An all-dense case.
+          return caseBits.isSubSetOf(denseBits);
+        });
+    assert(!needUniv && "Not implemented");
+    (void)needUniv;
+
+    for (Region &region : op.getCaseRegions()) {
+      // Do a one-shot type conversion on all region blocks, since the same
+      // region might be used multiple times.
+      Block *block = &region.getBlocks().front();
+      OneToNTypeMapping blockTypeMapping(block->getArgumentTypes());
+      if (failed(typeConverter->convertSignatureArgs(block->getArgumentTypes(),
+                                                     blockTypeMapping)))
+        return rewriter.notifyMatchFailure(
+            op, "failed to convert coiterate region argument types");
+
+      rewriter.applySignatureConversion(block, blockTypeMapping);
+    }
+
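+    // Materialize each iteration space and extract an iterator from it.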
+    SmallVector<SparseIterationSpace> spaces;
+    SmallVector<std::unique_ptr<SparseIterator>> iters;
+    for (auto [spaceTp, spaceVals] : llvm::zip_equal(
+             op.getIterSpaces().getTypes(), adaptor.getIterSpaces())) {
+      // TODO: do we really need tid?
+      spaces.push_back(SparseIterationSpace::fromValues(
+          cast<IterSpaceType>(spaceTp), spaceVals, /*tid=*/0));
+      // Extract the iterator.
+      iters.push_back(spaces.back().extractIterator(rewriter, loc));
+    }
+
+    auto getFilteredIters = [&iters](I64BitSet caseBits) {
+      // Retrieves a vector of pointers to the iterators used in the case.
+      SmallVector<SparseIterator *> validIters;
+      for (auto idx : caseBits.bits())
+        validIters.push_back(iters[idx].get());
+      return validIters;
+    };
+
+    // Get the flattened user-provided loop reduction values.
+    SmallVector<Value> userReduc;
+    for (ValueRange r : adaptor.getInitArgs())
+      llvm::append_range(userReduc, r);
+
+    // TODO: we need to sort the cases such that they appear in
+    // lexicographic order. Although sparsification always generates cases in
+    // that order, it might not be the case for human-written code.
+
+    // Generates a loop sequence, one loop per case.
+    for (auto [r, caseBits] :
+         llvm::zip_equal(op.getCaseRegions(), op.getRegionDefinedSpaces())) {
+      assert(caseBits.count() > 0 && "Complement space not implemented");
+
+      // Retrieves a vector of pointers to the iterators used in the case.
+      SmallVector<SparseIterator *> validIters = getFilteredIters(caseBits);
+
+      if (validIters.size() > 1) {
+        auto [loop, loopCrd] =
+            genCoIteration(rewriter, loc, validIters, userReduc,
+                           /*uniIdx=*/nullptr, /*userReducFirst=*/true);
+
+        // 1st. Find all the cases that are strict subsets of the current case
+        // condition, for which we generate one branch per case inside the
+        // loop. The subcases are never empty; they contain at least the
+        // current region itself.
+        // TODO: these cases should be sorted.
+        SmallVector<Region *> subCases = op.getSubCasesOf(r.getRegionNumber());
+        assert(!subCases.empty());
+
+        ValueRange res = genCoIterateBranchNest(rewriter, loc, op, loopCrd,
+                                                iters, subCases, userReduc);
+
+        SmallVector<Value> nextIterYields(res);
+        // 2nd. Forward the loop.
+        for (SparseIterator *it : validIters) {
+          Value cmp = rewriter.create<arith::CmpIOp>(
+              loc, arith::CmpIPredicate::eq, it->getCrd(), loopCrd);
+          it->forwardIf(rewriter, loc, cmp);
+          llvm::append_range(nextIterYields, it->getCursor());
+        }
+        rewriter.create<scf::YieldOp>(loc, nextIterYields);
+
+        // Exit the loop and relink the iterator SSA values.
+        rewriter.setInsertionPointAfter(loop);
+        ValueRange iterVals = loop->getResults().drop_front(userReduc.size());
+        for (SparseIterator *it : validIters)
+          iterVals = it->linkNewScope(iterVals);
+        assert(iterVals.empty());
+
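+        // Carry the reduction results into the next loop of the sequence.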
+        ValueRange curResult = loop->getResults().take_front(userReduc.size());
+        userReduc.assign(curResult.begin(), curResult.end());
+      } else {
+        // This is a simple iteration loop.
+        assert(caseBits.count() == 1);
+
+        Block *block = &r.getBlocks().front();
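+        // Inline the case region into the generated loop body; the block
+        // arguments follow the coiterate convention: user-provided values,
+        // then the loop coordinate, then the iterator cursor.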
+        ValueRange curResult = genLoopWithIterator(
+            rewriter, loc, validIters.front(), userReduc, /*iterFirst=*/false,
+            /*bodyBuilder=*/
+            [block](PatternRewriter &rewriter, Location loc, Region &dstRegion,
+                    SparseIterator *it,
+                    ValueRange reduc) -> SmallVector<Value> {
+              SmallVector<Value> blockArgs(reduc);
+              blockArgs.push_back(it->deref(rewriter, loc));
+              llvm::append_range(blockArgs, it->getCursor());
+
+              Block *dstBlock = &dstRegion.getBlocks().front();
+              rewriter.inlineBlockBefore(
+                  block, dstBlock, rewriter.getInsertionPoint(), blockArgs);
+              auto yield = llvm::cast<sparse_tensor::YieldOp>(dstBlock->back());
+              SmallVector<Value> result(yield.getResults());
+              rewriter.eraseOp(yield);
+              return result;
+            });
+
+        userReduc.assign(curResult.begin(), curResult.end());
+      }
+    }
+
+    rewriter.replaceOp(op, userReduc);
+    return success();
+  }
+};
+
 } // namespace

 mlir::SparseIterationTypeConverter::SparseIterationTypeConverter() {
@@ -210,5 +498,6 @@ void mlir::populateLowerSparseIterationToSCFPatterns(
   IterateOp::getCanonicalizationPatterns(patterns, patterns.getContext());
   patterns.add<ExtractIterSpaceConverter, ExtractValOpConverter,
-               SparseIterateOpConverter>(converter, patterns.getContext());
+               SparseIterateOpConverter, SparseCoIterateOpConverter>(
+      converter, patterns.getContext());
 }