@@ -244,88 +244,41 @@ class SparseIterateOpConverter : public OneToNOpConversionPattern<IterateOp> {
244
244
std::unique_ptr<SparseIterator> it =
245
245
iterSpace.extractIterator (rewriter, loc);
246
246
247
- if (it->iteratableByFor ()) {
248
- auto [lo, hi] = it->genForCond (rewriter, loc);
249
- Value step = constantIndex (rewriter, loc, 1 );
250
- SmallVector<Value> ivs;
251
- for (ValueRange inits : adaptor.getInitArgs ())
252
- llvm::append_range (ivs, inits);
253
- scf::ForOp forOp = rewriter.create <scf::ForOp>(loc, lo, hi, step, ivs);
254
-
255
- Block *loopBody = op.getBody ();
256
- OneToNTypeMapping bodyTypeMapping (loopBody->getArgumentTypes ());
257
- if (failed (typeConverter->convertSignatureArgs (
258
- loopBody->getArgumentTypes (), bodyTypeMapping)))
259
- return failure ();
260
- rewriter.applySignatureConversion (loopBody, bodyTypeMapping);
261
-
262
- rewriter.eraseBlock (forOp.getBody ());
263
- Region &dstRegion = forOp.getRegion ();
264
- rewriter.inlineRegionBefore (op.getRegion (), dstRegion, dstRegion.end ());
265
-
266
- auto yieldOp =
267
- llvm::cast<sparse_tensor::YieldOp>(forOp.getBody ()->getTerminator ());
268
-
269
- rewriter.setInsertionPointToEnd (forOp.getBody ());
270
- // replace sparse_tensor.yield with scf.yield.
271
- rewriter.create <scf::YieldOp>(loc, yieldOp.getResults ());
272
- rewriter.eraseOp (yieldOp);
273
-
274
- const OneToNTypeMapping &resultMapping = adaptor.getResultMapping ();
275
- rewriter.replaceOp (op, forOp.getResults (), resultMapping);
276
- } else {
277
- SmallVector<Value> ivs;
278
- // TODO: put iterator at the end of argument list to be consistent with
279
- // coiterate operation.
280
- llvm::append_range (ivs, it->getCursor ());
281
- for (ValueRange inits : adaptor.getInitArgs ())
282
- llvm::append_range (ivs, inits);
283
-
284
- assert (llvm::all_of (ivs, [](Value v) { return v != nullptr ; }));
285
-
286
- TypeRange types = ValueRange (ivs).getTypes ();
287
- auto whileOp = rewriter.create <scf::WhileOp>(loc, types, ivs);
288
- SmallVector<Location> l (types.size (), op.getIterator ().getLoc ());
289
-
290
- // Generates loop conditions.
291
- Block *before = rewriter.createBlock (&whileOp.getBefore (), {}, types, l);
292
- rewriter.setInsertionPointToStart (before);
293
- ValueRange bArgs = before->getArguments ();
294
- auto [whileCond, remArgs] = it->genWhileCond (rewriter, loc, bArgs);
295
- assert (remArgs.size () == adaptor.getInitArgs ().size ());
296
- rewriter.create <scf::ConditionOp>(loc, whileCond, before->getArguments ());
297
-
298
- // Generates loop body.
299
- Block *loopBody = op.getBody ();
300
- OneToNTypeMapping bodyTypeMapping (loopBody->getArgumentTypes ());
301
- if (failed (typeConverter->convertSignatureArgs (
302
- loopBody->getArgumentTypes (), bodyTypeMapping)))
303
- return failure ();
304
- rewriter.applySignatureConversion (loopBody, bodyTypeMapping);
305
-
306
- Region &dstRegion = whileOp.getAfter ();
307
- // TODO: handle uses of coordinate!
308
- rewriter.inlineRegionBefore (op.getRegion (), dstRegion, dstRegion.end ());
309
- ValueRange aArgs = whileOp.getAfterArguments ();
310
- auto yieldOp = llvm::cast<sparse_tensor::YieldOp>(
311
- whileOp.getAfterBody ()->getTerminator ());
312
-
313
- rewriter.setInsertionPointToEnd (whileOp.getAfterBody ());
247
+ SmallVector<Value> ivs;
248
+ for (ValueRange inits : adaptor.getInitArgs ())
249
+ llvm::append_range (ivs, inits);
250
+
251
+ // Type conversion on iterate op block.
252
+ OneToNTypeMapping blockTypeMapping (op.getBody ()->getArgumentTypes ());
253
+ if (failed (typeConverter->convertSignatureArgs (
254
+ op.getBody ()->getArgumentTypes (), blockTypeMapping)))
255
+ return rewriter.notifyMatchFailure (
256
+ op, " failed to convert iterate region argurment types" );
257
+ rewriter.applySignatureConversion (op.getBody (), blockTypeMapping);
258
+
259
+ Block *block = op.getBody ();
260
+ ValueRange ret = genLoopWithIterator (
261
+ rewriter, loc, it.get (), ivs, /* iterFirst=*/ true ,
262
+ [block](PatternRewriter &rewriter, Location loc, Region &loopBody,
263
+ SparseIterator *it, ValueRange reduc) -> SmallVector<Value> {
264
+ SmallVector<Value> blockArgs (it->getCursor ());
265
+ // TODO: Also appends coordinates if used.
266
+ // blockArgs.push_back(it->deref(rewriter, loc));
267
+ llvm::append_range (blockArgs, reduc);
268
+
269
+ Block *dstBlock = &loopBody.getBlocks ().front ();
270
+ rewriter.inlineBlockBefore (block, dstBlock, dstBlock->end (),
271
+ blockArgs);
272
+ auto yield = llvm::cast<sparse_tensor::YieldOp>(dstBlock->back ());
273
+ // We can not use ValueRange as the operation holding the values will
274
+ // be destoryed.
275
+ SmallVector<Value> result (yield.getResults ());
276
+ rewriter.eraseOp (yield);
277
+ return result;
278
+ });
314
279
315
- aArgs = it->linkNewScope (aArgs);
316
- ValueRange nx = it->forward (rewriter, loc);
317
- SmallVector<Value> yields;
318
- llvm::append_range (yields, nx);
319
- llvm::append_range (yields, yieldOp.getResults ());
320
-
321
- // replace sparse_tensor.yield with scf.yield.
322
- rewriter.eraseOp (yieldOp);
323
- rewriter.create <scf::YieldOp>(loc, yields);
324
- const OneToNTypeMapping &resultMapping = adaptor.getResultMapping ();
325
- rewriter.replaceOp (
326
- op, whileOp.getResults ().drop_front (it->getCursor ().size ()),
327
- resultMapping);
328
- }
280
+ const OneToNTypeMapping &resultMapping = adaptor.getResultMapping ();
281
+ rewriter.replaceOp (op, ret, resultMapping);
329
282
return success ();
330
283
}
331
284
};
@@ -366,9 +319,10 @@ class SparseCoIterateOpConverter
366
319
Block *block = ®ion.getBlocks ().front ();
367
320
OneToNTypeMapping blockTypeMapping (block->getArgumentTypes ());
368
321
if (failed (typeConverter->convertSignatureArgs (block->getArgumentTypes (),
369
- blockTypeMapping)))
322
+ blockTypeMapping))) {
370
323
return rewriter.notifyMatchFailure (
371
324
op, " failed to convert coiterate region argurment types" );
325
+ }
372
326
373
327
rewriter.applySignatureConversion (block, blockTypeMapping);
374
328
}
0 commit comments