Skip to content

Commit 7c16f93

Browse files
author
Tobias Gysi
committed
[mlir][linalg] Remove template parameter from loop lowering.
Replace the templated linalgLowerOpToLoops method by three specialized methods linalgOpToLoops, LinalgOpToParallelLoops, and linalgOpToAffineLoops. Differential Revision: https://reviews.llvm.org/D102324
1 parent 900c898 commit 7c16f93

File tree

2 files changed

+63
-91
lines changed

2 files changed

+63
-91
lines changed

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -342,21 +342,17 @@ Optional<LinalgOp> promoteSubViews(OpBuilder &b, LinalgOp op,
342342
LogicalResult vectorizeLinalgOp(OpBuilder &builder, Operation *op,
343343
SmallVectorImpl<Value> &newResults);
344344

345-
/// Emits a loop nest of `LoopTy` with the proper body for `linalgOp`.
346-
template <typename LoopTy>
347-
Optional<LinalgLoops> linalgLowerOpToLoops(PatternRewriter &rewriter,
348-
LinalgOp linalgOp);
349-
350345
/// Emits a loop nest of `scf.for` with the proper body for `linalgOp`.
351-
LogicalResult linalgOpToLoops(PatternRewriter &rewriter, LinalgOp linalgOp);
346+
Optional<LinalgLoops> linalgOpToLoops(PatternRewriter &rewriter,
347+
LinalgOp linalgOp);
352348

353349
/// Emits a loop nest of `scf.parallel` with the proper body for `linalgOp`.
354-
LogicalResult linalgOpToParallelLoops(PatternRewriter &rewriter,
355-
LinalgOp linalgOp);
350+
Optional<LinalgLoops> linalgOpToParallelLoops(PatternRewriter &rewriter,
351+
LinalgOp linalgOp);
356352

357353
/// Emits a loop nest of `affine.for` with the proper body for `linalgOp`.
358-
LogicalResult linalgOpToAffineLoops(PatternRewriter &rewriter,
359-
LinalgOp linalgOp);
354+
Optional<LinalgLoops> linalgOpToAffineLoops(PatternRewriter &rewriter,
355+
LinalgOp linalgOp);
360356

361357
//===----------------------------------------------------------------------===//
362358
// Preconditions that ensure the corresponding transformation succeeds and can
@@ -814,15 +810,15 @@ struct LinalgLoweringPattern : public RewritePattern {
814810
// TODO: Move lowering to library calls here.
815811
return failure();
816812
case LinalgLoweringType::Loops:
817-
if (failed(linalgOpToLoops(rewriter, op)))
813+
if (!linalgOpToLoops(rewriter, op))
818814
return failure();
819815
break;
820816
case LinalgLoweringType::AffineLoops:
821-
if (failed(linalgOpToAffineLoops(rewriter, op)))
817+
if (!linalgOpToAffineLoops(rewriter, op))
822818
return failure();
823819
break;
824820
case LinalgLoweringType::ParallelLoops:
825-
if (failed(linalgOpToParallelLoops(rewriter, op)))
821+
if (!linalgOpToParallelLoops(rewriter, op))
826822
return failure();
827823
break;
828824
}

mlir/lib/Dialect/Linalg/Transforms/Loops.cpp

Lines changed: 54 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -378,18 +378,54 @@ static void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingSumOp op) {
378378
getPoolingInput<IndexedValueType>(op, indices.inputs);
379379
}
380380

381+
/// Replace the index operations in the body of the loop nest by the matching
382+
/// induction variables.
383+
static void replaceIndexOpsByInductionVariables(LinalgOp linalgOp,
384+
PatternRewriter &rewriter,
385+
ArrayRef<Operation *> loopOps) {
386+
// Extract the induction variables of the loop nest from outer to inner.
387+
SmallVector<Value> allIvs;
388+
for (Operation *loopOp : loopOps) {
389+
llvm::TypeSwitch<Operation *>(loopOp)
390+
.Case([&](scf::ParallelOp parallelOp) {
391+
allIvs.append(parallelOp.getInductionVars().begin(),
392+
parallelOp.getInductionVars().end());
393+
})
394+
.Case([&](scf::ForOp forOp) {
395+
allIvs.push_back(forOp.getInductionVar());
396+
})
397+
.Case([&](AffineForOp affineForOp) {
398+
allIvs.push_back(affineForOp.getInductionVar());
399+
})
400+
.Default([&](Operation *op) { assert(false && "unexpected op"); });
401+
}
402+
assert(linalgOp.getNumLoops() == allIvs.size() &&
403+
"expected the number of loops and induction variables to match");
404+
// Replace the index operations in the body of the innermost loop op.
405+
if (!loopOps.empty()) {
406+
LoopLikeOpInterface loopOp = loopOps.back();
407+
for (IndexOp indexOp :
408+
llvm::make_early_inc_range(loopOp.getLoopBody().getOps<IndexOp>()))
409+
rewriter.replaceOp(indexOp, allIvs[indexOp.dim()]);
410+
}
411+
}
412+
381413
template <typename LoopTy>
382-
static Optional<LinalgLoops> linalgOpToLoopsImpl(LinalgOp linalgOp,
383-
OpBuilder &builder) {
414+
static Optional<LinalgLoops> linalgOpToLoopsImpl(PatternRewriter &rewriter,
415+
LinalgOp linalgOp) {
384416
using IndexedValueTy = typename GenerateLoopNest<LoopTy>::IndexedValueTy;
385-
ScopedContext scope(builder, linalgOp.getLoc());
417+
ScopedContext scope(rewriter, linalgOp.getLoc());
418+
419+
// Canonicalize indexed_generic operations before lowering them to loops.
420+
if (isa<IndexedGenericOp>(linalgOp))
421+
return llvm::None;
386422

387423
// The flattened loopToOperandRangesMaps is expected to be an invertible
388424
// permutation map (which is asserted in the inverse calculation).
389425
assert(linalgOp.hasBufferSemantics() &&
390426
"expected linalg op with buffer semantics");
391427

392-
auto loopRanges = linalgOp.createLoopRanges(builder, linalgOp.getLoc());
428+
auto loopRanges = linalgOp.createLoopRanges(rewriter, linalgOp.getLoc());
393429
auto iteratorTypes = llvm::to_vector<4>(linalgOp.iterator_types().getValue());
394430

395431
SmallVector<Value, 4> allIvs;
@@ -420,41 +456,11 @@ static Optional<LinalgLoops> linalgOpToLoopsImpl(LinalgOp linalgOp,
420456
loopSet.insert(ivVal.getOwner()->getParentOp());
421457
}
422458
LinalgLoops loops(loopSet.begin(), loopSet.end());
459+
// Replace all index operations in the loop body.
460+
replaceIndexOpsByInductionVariables(linalgOp, rewriter, loops);
423461
return loops;
424462
}
425463

426-
/// Replace the index operations in the body of the loop nest by the matching
427-
/// induction variables.
428-
static void replaceIndexOpsByInductionVariables(LinalgOp linalgOp,
429-
PatternRewriter &rewriter,
430-
ArrayRef<Operation *> loopOps) {
431-
// Extract the induction variables of the loop nest from outer to inner.
432-
SmallVector<Value> allIvs;
433-
for (Operation *loopOp : loopOps) {
434-
llvm::TypeSwitch<Operation *>(loopOp)
435-
.Case([&](scf::ParallelOp parallelOp) {
436-
allIvs.append(parallelOp.getInductionVars().begin(),
437-
parallelOp.getInductionVars().end());
438-
})
439-
.Case([&](scf::ForOp forOp) {
440-
allIvs.push_back(forOp.getInductionVar());
441-
})
442-
.Case([&](AffineForOp affineForOp) {
443-
allIvs.push_back(affineForOp.getInductionVar());
444-
})
445-
.Default([&](Operation *op) { assert(false && "unexpected op"); });
446-
}
447-
assert(linalgOp.getNumLoops() == allIvs.size() &&
448-
"expected the number of loops and induction variables to match");
449-
// Replace the index operations in the body of the innermost loop op.
450-
if (!loopOps.empty()) {
451-
LoopLikeOpInterface loopOp = loopOps.back();
452-
for (IndexOp indexOp :
453-
llvm::make_early_inc_range(loopOp.getLoopBody().getOps<IndexOp>()))
454-
rewriter.replaceOp(indexOp, allIvs[indexOp.dim()]);
455-
}
456-
}
457-
458464
namespace {
459465
template <typename LoopType>
460466
class LinalgRewritePattern : public RewritePattern {
@@ -467,7 +473,7 @@ class LinalgRewritePattern : public RewritePattern {
467473
auto linalgOp = dyn_cast<LinalgOp>(op);
468474
if (!isa<LinalgOp>(op))
469475
return failure();
470-
if (!linalgLowerOpToLoops<LoopType>(rewriter, linalgOp))
476+
if (!linalgOpToLoopsImpl<LoopType>(rewriter, linalgOp))
471477
return failure();
472478
rewriter.eraseOp(op);
473479
return success();
@@ -614,52 +620,22 @@ mlir::createConvertLinalgToAffineLoopsPass() {
614620
return std::make_unique<LowerToAffineLoops>();
615621
}
616622

617-
/// Emits a loop nest with the proper body for `linalgOp`.
618-
template <typename LoopTy>
619-
Optional<LinalgLoops>
620-
mlir::linalg::linalgLowerOpToLoops(PatternRewriter &rewriter,
621-
LinalgOp linalgOp) {
622-
// Convert indexed_generic ops to generic ops before lowering them to loops.
623-
if (isa<IndexedGenericOp>(linalgOp))
624-
return llvm::None;
625-
626-
Optional<LinalgLoops> loopOps =
627-
linalgOpToLoopsImpl<LoopTy>(linalgOp.getOperation(), rewriter);
628-
if (loopOps.hasValue())
629-
replaceIndexOpsByInductionVariables(linalgOp, rewriter, loopOps.getValue());
630-
return loopOps;
631-
}
632-
633-
template Optional<LinalgLoops>
634-
mlir::linalg::linalgLowerOpToLoops<AffineForOp>(PatternRewriter &rewriter,
635-
LinalgOp linalgOp);
636-
template Optional<LinalgLoops>
637-
mlir::linalg::linalgLowerOpToLoops<scf::ForOp>(PatternRewriter &rewriter,
638-
LinalgOp linalgOp);
639-
template Optional<LinalgLoops>
640-
mlir::linalg::linalgLowerOpToLoops<scf::ParallelOp>(PatternRewriter &rewriter,
641-
LinalgOp linalgOp);
642-
643623
/// Emits a loop nest of `affine.for` with the proper body for `linalgOp`.
644-
LogicalResult mlir::linalg::linalgOpToAffineLoops(PatternRewriter &rewriter,
645-
LinalgOp linalgOp) {
646-
Optional<LinalgLoops> loops =
647-
linalgLowerOpToLoops<AffineForOp>(rewriter, linalgOp);
648-
return loops ? success() : failure();
624+
Optional<LinalgLoops>
625+
mlir::linalg::linalgOpToAffineLoops(PatternRewriter &rewriter,
626+
LinalgOp linalgOp) {
627+
return linalgOpToLoopsImpl<AffineForOp>(rewriter, linalgOp);
649628
}
650629

651630
/// Emits a loop nest of `scf.for` with the proper body for `linalgOp`.
652-
LogicalResult mlir::linalg::linalgOpToLoops(PatternRewriter &rewriter,
653-
LinalgOp linalgOp) {
654-
Optional<LinalgLoops> loops =
655-
linalgLowerOpToLoops<scf::ForOp>(rewriter, linalgOp);
656-
return loops ? success() : failure();
631+
Optional<LinalgLoops> mlir::linalg::linalgOpToLoops(PatternRewriter &rewriter,
632+
LinalgOp linalgOp) {
633+
return linalgOpToLoopsImpl<scf::ForOp>(rewriter, linalgOp);
657634
}
658635

659636
/// Emits a loop nest of `scf.parallel` with the proper body for `linalgOp`.
660-
LogicalResult mlir::linalg::linalgOpToParallelLoops(PatternRewriter &rewriter,
661-
LinalgOp linalgOp) {
662-
Optional<LinalgLoops> loops =
663-
linalgLowerOpToLoops<scf::ParallelOp>(rewriter, linalgOp);
664-
return loops ? success() : failure();
637+
Optional<LinalgLoops>
638+
mlir::linalg::linalgOpToParallelLoops(PatternRewriter &rewriter,
639+
LinalgOp linalgOp) {
640+
return linalgOpToLoopsImpl<scf::ParallelOp>(rewriter, linalgOp);
665641
}

0 commit comments

Comments
 (0)