@@ -533,26 +533,111 @@ class LinalgScopedEmitter<IndexedValueType, IndexedGenericOp> {
// consequence, (1) it is only allowed to emit new ops if the match is
// guaranteed to be a success, (2) it is not allowed to erase/replace, and (3) an
// encompassing pattern must take care of the erasure logic.
- template <typename LoopTy, typename IndexedValueTy, typename ConcreteOpTy>
+ template <typename LoopTy, typename ConcreteOpTy>
class LinalgOpToLoopsImpl {
public:
-   static LogicalResult doit(Operation *op, PatternRewriter &rewriter);
+   static Optional<LinalgLoops> doit(Operation *op, PatternRewriter &rewriter);
};

- template <typename LoopTy>
- bool loweringIsAllowed(int numParallelLoops, int numLoops) {
-   return true;
- }
- template <>
- bool loweringIsAllowed<loop::ParallelOp>(int numParallelLoops, int numLoops) {
-   return numParallelLoops == numLoops;
- }
+ namespace {
+ /// Helper struct to generate the loop nest for the op. This is factored out here
+ /// to be able to partially specialize this for different LoopTy.
+ template <typename LoopTy, typename ConcreteOpTy>
+ class GenerateLoopNest {
+ public:
+   using IndexedValueTy =
+       typename std::conditional<std::is_same<LoopTy, AffineForOp>::value,
+                                 AffineIndexedValue, StdIndexedValue>::type;
+   static void doit(ConcreteOpTy linalgOp, ArrayRef<Value> loopRanges,
+                    MutableArrayRef<ValueHandle> allIvs) {
+     SmallVector<ValueHandle *, 4> allPIvs =
+         makeHandlePointers(MutableArrayRef<ValueHandle>(allIvs));
+
+     GenericLoopNestRangeBuilder<LoopTy>(allPIvs, loopRanges)([&] {
+       SmallVector<Value, 4> allIvValues(allIvs.begin(), allIvs.end());
+       LinalgScopedEmitter<IndexedValueTy,
+                           ConcreteOpTy>::emitScalarImplementation(allIvValues,
+                                                                   linalgOp);
+     });
+   }
+ };
+
+ /// Generates a loop nest using loop.parallel. loop.parallel is only used for the
+ /// outer parallel loops. All other loops are generated using the loop.for
+ /// operation.
+ template <typename ConcreteOpTy>
+ class GenerateLoopNest<loop::ParallelOp, ConcreteOpTy> {
+ public:
+   using IndexedValueTy = StdIndexedValue;
+
+   static void doit(ConcreteOpTy linalgOp, ArrayRef<Value> loopRanges,
+                    MutableArrayRef<ValueHandle> allIvs) {
+     // Only generate loop.parallel for outer consecutive "parallel"
+     // iterator_types.
+     // TODO(ravishankarm): Generate loop.parallel for all "parallel" iterator
+     // types, not just the outermost ones. Also handle "reduction" iterator
+     // types.
+     auto nPar = linalgOp.getNumParallelLoops();
+     auto nRed = linalgOp.getNumReductionLoops();
+     auto nWin = linalgOp.getNumWindowLoops();
+     auto nLoops = nPar + nRed + nWin;
+     auto nOuterPar = linalgOp.iterator_types()
+                          .getValue()
+                          .take_while([](Attribute attr) {
+                            return attr.cast<StringAttr>().getValue() ==
+                                   getParallelIteratorTypeName();
+                          })
+                          .size();
+     // If there are no outer parallel loops, then the number of loop ops is the
+     // same as the number of loops, and they are all loop.for ops.
+     auto nLoopOps = (nOuterPar ? nLoops - nOuterPar + 1 : nLoops);
+     SmallVector<ValueHandle *, 4> allPIvs =
+         makeHandlePointers(MutableArrayRef<ValueHandle>(allIvs));
+
+     SmallVector<OperationHandle, 4> allLoops(nLoopOps, OperationHandle());
+     SmallVector<OperationHandle *, 4> allPLoops;
+     allPLoops.reserve(allLoops.size());
+     for (OperationHandle &loop : allLoops)
+       allPLoops.push_back(&loop);
+
+     ArrayRef<ValueHandle *> allPIvsRef(allPIvs);
+     ArrayRef<OperationHandle *> allPLoopsRef(allPLoops);
+
+     if (nOuterPar) {
+       GenericLoopNestRangeBuilder<loop::ParallelOp>(
+           allPIvsRef.take_front(nOuterPar),
+           loopRanges.take_front(nOuterPar))([&] {
+         GenericLoopNestRangeBuilder<loop::ForOp>(
+             allPIvsRef.drop_front(nOuterPar),
+             loopRanges.drop_front(nOuterPar))([&] {
+           SmallVector<Value, 4> allIvValues(allIvs.begin(), allIvs.end());
+           LinalgScopedEmitter<StdIndexedValue, ConcreteOpTy>::
+               emitScalarImplementation(allIvValues, linalgOp);
+         });
+       });
+     } else {
+       // If there are no parallel loops, then fall back to generating all loop.for
+       // operations.
+       GenericLoopNestRangeBuilder<loop::ForOp>(allPIvsRef, loopRanges)([&] {
+         SmallVector<Value, 4> allIvValues(allIvs.begin(), allIvs.end());
+         LinalgScopedEmitter<StdIndexedValue,
+                             ConcreteOpTy>::emitScalarImplementation(allIvValues,
+                                                                     linalgOp);
+       });
+     }
+   }
+ };
+ } // namespace

- template <typename LoopTy, typename IndexedValueTy, typename ConcreteOpTy>
- LogicalResult LinalgOpToLoopsImpl<LoopTy, IndexedValueTy, ConcreteOpTy>::doit(
-     Operation *op, PatternRewriter &rewriter) {
-   OpBuilder b(op);
-   ScopedContext scope(b, op->getLoc());
+ template <typename LoopTy, typename ConcreteOpTy>
+ Optional<LinalgLoops>
+ LinalgOpToLoopsImpl<LoopTy, ConcreteOpTy>::doit(Operation *op,
+                                                 PatternRewriter &rewriter) {
+   using Impl = GenerateLoopNest<LoopTy, ConcreteOpTy>;
+   using IndexedValueTy =
+       typename GenerateLoopNest<LoopTy, ConcreteOpTy>::IndexedValueTy;
+
+   ScopedContext scope(rewriter, op->getLoc());

  // The flattened loopToOperandRangesMaps is expected to be an invertible
  // permutation map (which is asserted in the inverse calculation).
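Aside (not part of the patch): the key accounting in the loop.parallel specialization above is nLoopOps — the leading consecutive "parallel" iterator types collapse into a single loop.parallel op, and every remaining loop becomes its own loop.for. A minimal standalone C++ sketch of that counting rule, with a plain iteratorTypes vector standing in for the op's iterator_types attribute:

#include <cassert>
#include <string>
#include <vector>

// Mirrors the nLoopOps computation above: the leading "parallel" iterators are
// fused into one loop.parallel; each remaining loop becomes a loop.for.
static int countLoopOps(const std::vector<std::string> &iteratorTypes) {
  int nLoops = static_cast<int>(iteratorTypes.size());
  int nOuterPar = 0;
  while (nOuterPar < nLoops && iteratorTypes[nOuterPar] == "parallel")
    ++nOuterPar;
  return nOuterPar ? nLoops - nOuterPar + 1 : nLoops;
}

int main() {
  // 2 outer parallel dims + 1 reduction -> one loop.parallel plus one loop.for.
  assert(countLoopOps({"parallel", "parallel", "reduction"}) == 2);
  // No leading parallel dims -> every loop is a loop.for.
  assert(countLoopOps({"reduction", "parallel"}) == 2);
  // All parallel -> a single loop.parallel covers the whole nest.
  assert(countLoopOps({"parallel", "parallel"}) == 1);
  return 0;
}

For iterator_types = ["parallel", "parallel", "reduction"], for example, the generated nest is one loop.parallel carrying two induction variables wrapped around one loop.for.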
@@ -563,8 +648,6 @@ LogicalResult LinalgOpToLoopsImpl<LoopTy, IndexedValueTy, ConcreteOpTy>::doit(
  auto nRed = linalgOp.getNumReductionLoops();
  auto nWin = linalgOp.getNumWindowLoops();
  auto nLoops = nPar + nRed + nWin;
-   if (!loweringIsAllowed<LoopTy>(nPar, nLoops))
-     return failure();
  auto mapsRange =
      linalgOp.indexing_maps().template getAsRange<AffineMapAttr>();
  auto maps =
@@ -573,67 +656,72 @@ LogicalResult LinalgOpToLoopsImpl<LoopTy, IndexedValueTy, ConcreteOpTy>::doit(
  if (!invertedMap) {
    LinalgScopedEmitter<IndexedValueTy, ConcreteOpTy>::emitScalarImplementation(
        {}, linalgOp);
-     return success();
+     return LinalgLoops();
  }

-   SmallVector<ValueHandle, 4> allIvs(nLoops, ValueHandle(b.getIndexType()));
-   SmallVector<ValueHandle *, 4> allPIvs =
-       makeHandlePointers(MutableArrayRef<ValueHandle>(allIvs));
-   auto loopRanges = emitLoopRanges(scope.getBuilder(), scope.getLocation(),
-                                    invertedMap, getViewSizes(b, linalgOp));
+   SmallVector<ValueHandle, 4> allIvs(nLoops,
+                                      ValueHandle(rewriter.getIndexType()));
+   auto loopRanges =
+       emitLoopRanges(scope.getBuilder(), scope.getLocation(), invertedMap,
+                      getViewSizes(rewriter, linalgOp));
  assert(loopRanges.size() == allIvs.size());
-
-   GenericLoopNestRangeBuilder<LoopTy>(allPIvs, loopRanges)([&] {
-     SmallVector<Value, 4> allIvValues(allIvs.begin(), allIvs.end());
-     LinalgScopedEmitter<IndexedValueTy, ConcreteOpTy>::emitScalarImplementation(
-         allIvValues, linalgOp);
-   });
-   return success();
+   Impl::doit(linalgOp, loopRanges, allIvs);
+   // The number of loop ops might be different from the number of ivs since some
+   // loops like affine.parallel and loop.parallel have multiple ivs.
+   llvm::SetVector<Operation *> loopSet;
+   for (ValueHandle &iv : allIvs) {
+     if (!iv.hasValue())
+       return {};
+     // The induction variable is a block argument of the entry block of the
+     // loop operation.
+     BlockArgument ivVal = iv.getValue().dyn_cast<BlockArgument>();
+     if (!ivVal)
+       return {};
+     loopSet.insert(ivVal.getOwner()->getParentOp());
+   }
+   LinalgLoops loops(loopSet.begin(), loopSet.end());
+   return loops;
}

- template <typename LoopType, typename IndexedValueType, typename ConcreteOp>
+ template <typename LoopType, typename ConcreteOp>
class LinalgRewritePattern : public RewritePattern {
public:
  explicit LinalgRewritePattern(MLIRContext *context)
      : RewritePattern(ConcreteOp::getOperationName(), 1, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
-     using Impl = LinalgOpToLoopsImpl<LoopType, IndexedValueType, ConcreteOp>;
-     if (failed(Impl::doit(op, rewriter)))
+     using Impl = LinalgOpToLoopsImpl<LoopType, ConcreteOp>;
+     if (!Impl::doit(op, rewriter))
      return failure();
    rewriter.eraseOp(op);
    return success();
  }
};

// Helper classes for type list expansion.
- template <typename LoopType, typename IndexedValueType, typename... LinalgOps>
+ template <typename LoopType, typename... LinalgOps>
class RewritePatternList;

- template <typename LoopType, typename IndexedValueType>
- class RewritePatternList<LoopType, IndexedValueType> {
+ template <typename LoopType>
+ class RewritePatternList<LoopType> {
public:
  static void build(OwningRewritePatternList &patterns, MLIRContext *ctx) {}
};

- template <typename LoopType, typename IndexedValueType, typename ConcreteOp,
-           typename... LinalgOps>
- class RewritePatternList<LoopType, IndexedValueType, ConcreteOp, LinalgOps...> {
+ template <typename LoopType, typename ConcreteOp, typename... LinalgOps>
+ class RewritePatternList<LoopType, ConcreteOp, LinalgOps...> {
public:
  static void build(OwningRewritePatternList &patterns, MLIRContext *ctx) {
-     patterns
-         .insert<LinalgRewritePattern<LoopType, IndexedValueType, ConcreteOp>>(
-             ctx);
-     RewritePatternList<LoopType, IndexedValueType, LinalgOps...>::build(
-         patterns, ctx);
+     patterns.insert<LinalgRewritePattern<LoopType, ConcreteOp>>(ctx);
+     RewritePatternList<LoopType, LinalgOps...>::build(patterns, ctx);
  }
};

/// Populate the given list with patterns that convert from Linalg to loops.
- template <typename LoopType, typename IndexedValueType>
+ template <typename LoopType>
void FillRewritePatterns(OwningRewritePatternList &patterns, MLIRContext *ctx) {
-   RewritePatternList<LoopType, IndexedValueType,
+   RewritePatternList<LoopType,
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
                     >::build(patterns, ctx);
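Aside (not part of the patch): RewritePatternList is ordinary type-list recursion — the two-or-more-argument specialization registers a pattern for the head op and recurses on the tail, the single-argument specialization terminates, and GET_OP_LIST expands to the comma-separated Linalg op list that seeds it. A self-contained toy version of the same peel-one-type scheme, where PatternList and the op/loop types are illustrative dummies:

#include <iostream>
#include <typeinfo>

struct PatternList {};  // Stand-in for OwningRewritePatternList.

// Primary template is only declared, like RewritePatternList above.
template <typename LoopType, typename... Ops> struct BuildPatterns;

// Base case: empty op list, nothing left to insert.
template <typename LoopType> struct BuildPatterns<LoopType> {
  static void build(PatternList &) {}
};

// Recursive case: register a pattern for the head op, then recurse on the tail.
template <typename LoopType, typename Op, typename... Rest>
struct BuildPatterns<LoopType, Op, Rest...> {
  static void build(PatternList &patterns) {
    std::cout << "insert pattern for " << typeid(Op).name() << "\n";
    BuildPatterns<LoopType, Rest...>::build(patterns);
  }
};

struct ForLoop {};
struct MatmulLikeOp {};
struct FillLikeOp {};

int main() {
  PatternList patterns;
  // Expands to one "insert pattern" step per op in the list.
  BuildPatterns<ForLoop, MatmulLikeOp, FillLikeOp>::build(patterns);
  return 0;
}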
@@ -677,13 +765,13 @@ struct FoldAffineOp : public RewritePattern {
};
} // namespace

- template <typename LoopType, typename IndexedValueType>
+ template <typename LoopType>
static void lowerLinalgToLoopsImpl(Operation *op, MLIRContext *context) {
  OwningRewritePatternList patterns;
  // Canonicalization and folding patterns applied greedily allow cleaning up
  // the emitted IR on the fly.
  // TODO(ntv) fold view and subview ops?
-   FillRewritePatterns<LoopType, IndexedValueType>(patterns, context);
+   FillRewritePatterns<LoopType>(patterns, context);
  DimOp::getCanonicalizationPatterns(patterns, context);
  AffineApplyOp::getCanonicalizationPatterns(patterns, context);
  patterns.insert<FoldAffineOp>(context);
@@ -695,21 +783,18 @@ namespace {
struct LowerToAffineLoops
    : public LinalgLowerToAffineLoopsBase<LowerToAffineLoops> {
  void runOnFunction() override {
-     lowerLinalgToLoopsImpl<AffineForOp, AffineIndexedValue>(getFunction(),
-                                                             &getContext());
+     lowerLinalgToLoopsImpl<AffineForOp>(getFunction(), &getContext());
  }
};
struct LowerToLoops : public LinalgLowerToLoopsBase<LowerToLoops> {
  void runOnFunction() override {
-     lowerLinalgToLoopsImpl<loop::ForOp, StdIndexedValue>(getFunction(),
-                                                          &getContext());
+     lowerLinalgToLoopsImpl<loop::ForOp>(getFunction(), &getContext());
  }
};
struct LowerToParallelLoops
    : public LinalgLowerToParallelLoopsBase<LowerToParallelLoops> {
  void runOnFunction() override {
-     lowerLinalgToLoopsImpl<loop::ParallelOp, StdIndexedValue>(getFunction(),
-                                                               &getContext());
+     lowerLinalgToLoopsImpl<loop::ParallelOp>(getFunction(), &getContext());
  }
};
} // namespace
@@ -728,28 +813,38 @@ mlir::createConvertLinalgToAffineLoopsPass() {
  return std::make_unique<LowerToAffineLoops>();
}

+ /// Emits a loop nest with the proper body for `op`.
+ template <typename LoopTy, typename ConcreteOp>
+ Optional<LinalgLoops>
+ mlir::linalg::linalgLowerOpToLoops(PatternRewriter &rewriter, Operation *op) {
+   return LinalgOpToLoopsImpl<LoopTy, ConcreteOp>::doit(op, rewriter);
+ }
+
/// Emits a loop nest of `loop.for` with the proper body for `op`.
template <typename ConcreteOp>
LogicalResult mlir::linalg::linalgOpToLoops(PatternRewriter &rewriter,
                                            Operation *op) {
-   return LinalgOpToLoopsImpl<loop::ForOp, StdIndexedValue, ConcreteOp>::doit(
-       op, rewriter);
+   Optional<LinalgLoops> loops =
+       linalgLowerOpToLoops<loop::ForOp, ConcreteOp>(rewriter, op);
+   return loops ? success() : failure();
}

/// Emits a loop nest of `affine.for` with the proper body for `op`.
template <typename ConcreteOp>
LogicalResult mlir::linalg::linalgOpToAffineLoops(PatternRewriter &rewriter,
                                                  Operation *op) {
-   return LinalgOpToLoopsImpl<AffineForOp, AffineIndexedValue, ConcreteOp>::doit(
-       op, rewriter);
+   Optional<LinalgLoops> loops =
+       linalgLowerOpToLoops<AffineForOp, ConcreteOp>(rewriter, op);
+   return loops ? success() : failure();
}

/// Emits a loop nest of `loop.parallel` with the proper body for `op`.
template <typename ConcreteOp>
LogicalResult mlir::linalg::linalgOpToParallelLoops(PatternRewriter &rewriter,
                                                    Operation *op) {
-   return LinalgOpToLoopsImpl<loop::ParallelOp, StdIndexedValue,
-                              ConcreteOp>::doit(op, rewriter);
+   Optional<LinalgLoops> loops =
+       linalgLowerOpToLoops<loop::ParallelOp, ConcreteOp>(rewriter, op);
+   return loops ? success() : failure();
}

// TODO(ntv) Need to make these instantiations more future-proof to avoid the
@@ -758,7 +853,12 @@ LogicalResult mlir::linalg::linalgOpToParallelLoops(PatternRewriter &rewriter,
  template LogicalResult mlir::linalg::linalgOpToLoops<OP_TYPE>(             \
      PatternRewriter & rewriter, Operation * op);                           \
  template LogicalResult mlir::linalg::linalgOpToAffineLoops<OP_TYPE>(       \
-       PatternRewriter & rewriter, Operation * op);
+       PatternRewriter & rewriter, Operation * op);                         \
+   template LogicalResult mlir::linalg::linalgOpToParallelLoops<OP_TYPE>(   \
+       PatternRewriter & rewriter, Operation * op);                         \
+   template Optional<LinalgLoops>                                           \
+   mlir::linalg::linalgLowerOpToLoops<loop::ParallelOp, OP_TYPE>(           \
+       PatternRewriter & rewriter, Operation * op);

INSTANTIATE_LINALG_OP_TO_LOOPS(CopyOp)
INSTANTIATE_LINALG_OP_TO_LOOPS(FillOp)
@@ -771,9 +871,3 @@ INSTANTIATE_LINALG_OP_TO_LOOPS(PoolingMinOp)
INSTANTIATE_LINALG_OP_TO_LOOPS(PoolingSumOp)
INSTANTIATE_LINALG_OP_TO_LOOPS(GenericOp)
INSTANTIATE_LINALG_OP_TO_LOOPS(IndexedGenericOp)
-
- // TODO(pifon): Enable lowering to parallel loops for ops other than
- // linalg.generic for now to be on the safe side.
- template LogicalResult
- mlir::linalg::linalgOpToParallelLoops<GenericOp>(PatternRewriter &rewriter,
-                                                  Operation *op);
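Aside (not part of the patch): the point of returning Optional<LinalgLoops> instead of a bare LogicalResult is that callers get handles to the emitted loop ops. A rough sketch of how a downstream pattern might use the new linalgLowerOpToLoops entry point — the pattern, the attribute name, and the tagging step are hypothetical, and it assumes the same headers and using-directives as this file:

// Hypothetical consumer of the new API; only linalgLowerOpToLoops and the
// MLIR/Linalg types come from the patch, everything else is illustrative.
struct LowerGenericToParallelAndTag : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp op,
                                PatternRewriter &rewriter) const override {
    Optional<LinalgLoops> loops =
        linalgLowerOpToLoops<loop::ParallelOp, GenericOp>(rewriter,
                                                          op.getOperation());
    if (!loops)
      return failure();
    // `*loops` holds the loop ops of the emitted nest; tag them so a later
    // pass can find the nests that were produced from linalg.generic.
    for (Operation *loop : *loops)
      loop->setAttr("lowered_from", rewriter.getStringAttr("linalg.generic"));
    rewriter.eraseOp(op.getOperation());
    return success();
  }
};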