@@ -378,18 +378,54 @@ static void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingSumOp op) {
378
378
getPoolingInput<IndexedValueType>(op, indices.inputs );
379
379
}
380
380
381
+ // / Replace the index operations in the body of the loop nest by the matching
382
+ // / induction variables.
383
+ static void replaceIndexOpsByInductionVariables (LinalgOp linalgOp,
384
+ PatternRewriter &rewriter,
385
+ ArrayRef<Operation *> loopOps) {
386
+ // Extract the induction variables of the loop nest from outer to inner.
387
+ SmallVector<Value> allIvs;
388
+ for (Operation *loopOp : loopOps) {
389
+ llvm::TypeSwitch<Operation *>(loopOp)
390
+ .Case ([&](scf::ParallelOp parallelOp) {
391
+ allIvs.append (parallelOp.getInductionVars ().begin (),
392
+ parallelOp.getInductionVars ().end ());
393
+ })
394
+ .Case ([&](scf::ForOp forOp) {
395
+ allIvs.push_back (forOp.getInductionVar ());
396
+ })
397
+ .Case ([&](AffineForOp affineForOp) {
398
+ allIvs.push_back (affineForOp.getInductionVar ());
399
+ })
400
+ .Default ([&](Operation *op) { assert (false && " unexpected op" ); });
401
+ }
402
+ assert (linalgOp.getNumLoops () == allIvs.size () &&
403
+ " expected the number of loops and induction variables to match" );
404
+ // Replace the index operations in the body of the innermost loop op.
405
+ if (!loopOps.empty ()) {
406
+ LoopLikeOpInterface loopOp = loopOps.back ();
407
+ for (IndexOp indexOp :
408
+ llvm::make_early_inc_range (loopOp.getLoopBody ().getOps <IndexOp>()))
409
+ rewriter.replaceOp (indexOp, allIvs[indexOp.dim ()]);
410
+ }
411
+ }
412
+
381
413
template <typename LoopTy>
382
- static Optional<LinalgLoops> linalgOpToLoopsImpl (LinalgOp linalgOp ,
383
- OpBuilder &builder ) {
414
+ static Optional<LinalgLoops> linalgOpToLoopsImpl (PatternRewriter &rewriter ,
415
+ LinalgOp linalgOp ) {
384
416
using IndexedValueTy = typename GenerateLoopNest<LoopTy>::IndexedValueTy;
385
- ScopedContext scope (builder, linalgOp.getLoc ());
417
+ ScopedContext scope (rewriter, linalgOp.getLoc ());
418
+
419
+ // Canonicalize indexed_generic operations before lowering them to loops.
420
+ if (isa<IndexedGenericOp>(linalgOp))
421
+ return llvm::None;
386
422
387
423
// The flattened loopToOperandRangesMaps is expected to be an invertible
388
424
// permutation map (which is asserted in the inverse calculation).
389
425
assert (linalgOp.hasBufferSemantics () &&
390
426
" expected linalg op with buffer semantics" );
391
427
392
- auto loopRanges = linalgOp.createLoopRanges (builder , linalgOp.getLoc ());
428
+ auto loopRanges = linalgOp.createLoopRanges (rewriter , linalgOp.getLoc ());
393
429
auto iteratorTypes = llvm::to_vector<4 >(linalgOp.iterator_types ().getValue ());
394
430
395
431
SmallVector<Value, 4 > allIvs;
@@ -420,41 +456,11 @@ static Optional<LinalgLoops> linalgOpToLoopsImpl(LinalgOp linalgOp,
420
456
loopSet.insert (ivVal.getOwner ()->getParentOp ());
421
457
}
422
458
LinalgLoops loops (loopSet.begin (), loopSet.end ());
459
+ // Replace all index operations in the loop body.
460
+ replaceIndexOpsByInductionVariables (linalgOp, rewriter, loops);
423
461
return loops;
424
462
}
425
463
426
- // / Replace the index operations in the body of the loop nest by the matching
427
- // / induction variables.
428
- static void replaceIndexOpsByInductionVariables (LinalgOp linalgOp,
429
- PatternRewriter &rewriter,
430
- ArrayRef<Operation *> loopOps) {
431
- // Extract the induction variables of the loop nest from outer to inner.
432
- SmallVector<Value> allIvs;
433
- for (Operation *loopOp : loopOps) {
434
- llvm::TypeSwitch<Operation *>(loopOp)
435
- .Case ([&](scf::ParallelOp parallelOp) {
436
- allIvs.append (parallelOp.getInductionVars ().begin (),
437
- parallelOp.getInductionVars ().end ());
438
- })
439
- .Case ([&](scf::ForOp forOp) {
440
- allIvs.push_back (forOp.getInductionVar ());
441
- })
442
- .Case ([&](AffineForOp affineForOp) {
443
- allIvs.push_back (affineForOp.getInductionVar ());
444
- })
445
- .Default ([&](Operation *op) { assert (false && " unexpected op" ); });
446
- }
447
- assert (linalgOp.getNumLoops () == allIvs.size () &&
448
- " expected the number of loops and induction variables to match" );
449
- // Replace the index operations in the body of the innermost loop op.
450
- if (!loopOps.empty ()) {
451
- LoopLikeOpInterface loopOp = loopOps.back ();
452
- for (IndexOp indexOp :
453
- llvm::make_early_inc_range (loopOp.getLoopBody ().getOps <IndexOp>()))
454
- rewriter.replaceOp (indexOp, allIvs[indexOp.dim ()]);
455
- }
456
- }
457
-
458
464
namespace {
459
465
template <typename LoopType>
460
466
class LinalgRewritePattern : public RewritePattern {
@@ -467,7 +473,7 @@ class LinalgRewritePattern : public RewritePattern {
467
473
auto linalgOp = dyn_cast<LinalgOp>(op);
468
474
if (!isa<LinalgOp>(op))
469
475
return failure ();
470
- if (!linalgLowerOpToLoops <LoopType>(rewriter, linalgOp))
476
+ if (!linalgOpToLoopsImpl <LoopType>(rewriter, linalgOp))
471
477
return failure ();
472
478
rewriter.eraseOp (op);
473
479
return success ();
@@ -614,52 +620,22 @@ mlir::createConvertLinalgToAffineLoopsPass() {
614
620
return std::make_unique<LowerToAffineLoops>();
615
621
}
616
622
617
- // / Emits a loop nest with the proper body for `linalgOp`.
618
- template <typename LoopTy>
619
- Optional<LinalgLoops>
620
- mlir::linalg::linalgLowerOpToLoops (PatternRewriter &rewriter,
621
- LinalgOp linalgOp) {
622
- // Convert indexed_generic ops to generic ops before lowering them to loops.
623
- if (isa<IndexedGenericOp>(linalgOp))
624
- return llvm::None;
625
-
626
- Optional<LinalgLoops> loopOps =
627
- linalgOpToLoopsImpl<LoopTy>(linalgOp.getOperation (), rewriter);
628
- if (loopOps.hasValue ())
629
- replaceIndexOpsByInductionVariables (linalgOp, rewriter, loopOps.getValue ());
630
- return loopOps;
631
- }
632
-
633
- template Optional<LinalgLoops>
634
- mlir::linalg::linalgLowerOpToLoops<AffineForOp>(PatternRewriter &rewriter,
635
- LinalgOp linalgOp);
636
- template Optional<LinalgLoops>
637
- mlir::linalg::linalgLowerOpToLoops<scf::ForOp>(PatternRewriter &rewriter,
638
- LinalgOp linalgOp);
639
- template Optional<LinalgLoops>
640
- mlir::linalg::linalgLowerOpToLoops<scf::ParallelOp>(PatternRewriter &rewriter,
641
- LinalgOp linalgOp);
642
-
643
623
// / Emits a loop nest of `affine.for` with the proper body for `linalgOp`.
644
- LogicalResult mlir::linalg::linalgOpToAffineLoops (PatternRewriter &rewriter,
645
- LinalgOp linalgOp) {
646
- Optional<LinalgLoops> loops =
647
- linalgLowerOpToLoops<AffineForOp>(rewriter, linalgOp);
648
- return loops ? success () : failure ();
624
+ Optional<LinalgLoops>
625
+ mlir::linalg::linalgOpToAffineLoops (PatternRewriter &rewriter,
626
+ LinalgOp linalgOp) {
627
+ return linalgOpToLoopsImpl<AffineForOp>(rewriter, linalgOp);
649
628
}
650
629
651
630
// / Emits a loop nest of `scf.for` with the proper body for `linalgOp`.
652
- LogicalResult mlir::linalg::linalgOpToLoops (PatternRewriter &rewriter,
653
- LinalgOp linalgOp) {
654
- Optional<LinalgLoops> loops =
655
- linalgLowerOpToLoops<scf::ForOp>(rewriter, linalgOp);
656
- return loops ? success () : failure ();
631
+ Optional<LinalgLoops> mlir::linalg::linalgOpToLoops (PatternRewriter &rewriter,
632
+ LinalgOp linalgOp) {
633
+ return linalgOpToLoopsImpl<scf::ForOp>(rewriter, linalgOp);
657
634
}
658
635
659
636
// / Emits a loop nest of `scf.parallel` with the proper body for `linalgOp`.
660
- LogicalResult mlir::linalg::linalgOpToParallelLoops (PatternRewriter &rewriter,
661
- LinalgOp linalgOp) {
662
- Optional<LinalgLoops> loops =
663
- linalgLowerOpToLoops<scf::ParallelOp>(rewriter, linalgOp);
664
- return loops ? success () : failure ();
637
+ Optional<LinalgLoops>
638
+ mlir::linalg::linalgOpToParallelLoops (PatternRewriter &rewriter,
639
+ LinalgOp linalgOp) {
640
+ return linalgOpToLoopsImpl<scf::ParallelOp>(rewriter, linalgOp);
665
641
}
0 commit comments