
Commit 03391df
Author: MaheshRavishankar
[mlir][Linalg] Add loop.parallel lowering for all Linalg Ops.
The outer parallel loops of a linalg operation are lowered to loop.parallel, with the remaining loops lowered to loop.for. This brings the loop.parallel lowering on par with the loop.for lowering. In the future, the reduction loops could also be lowered to loop.parallel.

Also add a utility function that returns the loops that are created.

Differential Revision: https://reviews.llvm.org/D77678
Parent: dffbeff
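As a concrete illustration of the new utility, here is a minimal sketch (not part of the commit; the helper and the attribute name are hypothetical) of lowering an op through the loop.parallel path and consuming the returned loops:

#include "mlir/Dialect/Linalg/Transforms/LinalgTransforms.h"

using namespace mlir;
using namespace mlir::linalg;

// For iterator_types such as ["parallel", "parallel", "reduction"], this
// produces one loop.parallel over the two outer dimensions with a nested
// loop.for over the reduction dimension.
template <typename ConcreteOp>
LogicalResult lowerToParallelAndMark(PatternRewriter &rewriter,
                                     Operation *op) {
  Optional<LinalgLoops> loops =
      linalgLowerOpToLoops<loop::ParallelOp, ConcreteOp>(rewriter, op);
  if (!loops)
    return failure();
  // The returned Operation* handles are the created loop ops, available for
  // further processing, e.g. tagging them for a later mapping step.
  for (Operation *loop : *loops)
    loop->setAttr("hypothetical_marker", rewriter.getUnitAttr());
  rewriter.eraseOp(op);
  return success();
}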

File tree: 4 files changed, +780 -369 lines

mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h

Lines changed: 7 additions & 0 deletions
@@ -70,6 +70,13 @@ LogicalResult tileAndFuseLinalgOpAndSetMarker(
     PatternRewriter &rewriter, Operation *op, ArrayRef<int64_t> sizes,
     ArrayRef<int64_t> operandIndicesToFuse, StringRef linalgMarker);
 
+using LinalgLoops = SmallVector<Operation *, 4>;
+
+/// Emits a loop nest with the proper body for `op`.
+template <typename LoopTy, typename ConcreteOp>
+Optional<LinalgLoops> linalgLowerOpToLoops(PatternRewriter &rewriter,
+                                           Operation *op);
+
 /// Emits a loop nest of `loop.for` with the proper body for `op`.
 template <typename ConcreteOp>
 LogicalResult linalgOpToLoops(PatternRewriter &rewriter, Operation *op);

mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp

Lines changed: 161 additions & 67 deletions
@@ -533,26 +533,111 @@ class LinalgScopedEmitter<IndexedValueType, IndexedGenericOp> {
 // consequence, (1) it is only allowed to emit new ops if the match is
 // guaranteed to be a success, (2) it is not allowed erase/replace, and (3) an
 // encompassing pattern must take care of the erasure logic.
-template <typename LoopTy, typename IndexedValueTy, typename ConcreteOpTy>
+template <typename LoopTy, typename ConcreteOpTy>
 class LinalgOpToLoopsImpl {
 public:
-  static LogicalResult doit(Operation *op, PatternRewriter &rewriter);
+  static Optional<LinalgLoops> doit(Operation *op, PatternRewriter &rewriter);
 };
 
-template <typename LoopTy>
-bool loweringIsAllowed(int numParallelLoops, int numLoops) {
-  return true;
-}
-template <>
-bool loweringIsAllowed<loop::ParallelOp>(int numParallelLoops, int numLoops) {
-  return numParallelLoops == numLoops;
-}
+namespace {
+/// Helper struct to generate the loop nest for the op. This is factored out
+/// here so it can be partially specialized for different LoopTy.
+template <typename LoopTy, typename ConcreteOpTy>
+class GenerateLoopNest {
+public:
+  using IndexedValueTy =
+      typename std::conditional<std::is_same<LoopTy, AffineForOp>::value,
+                                AffineIndexedValue, StdIndexedValue>::type;
+  static void doit(ConcreteOpTy linalgOp, ArrayRef<Value> loopRanges,
+                   MutableArrayRef<ValueHandle> allIvs) {
+    SmallVector<ValueHandle *, 4> allPIvs =
+        makeHandlePointers(MutableArrayRef<ValueHandle>(allIvs));
+
+    GenericLoopNestRangeBuilder<LoopTy>(allPIvs, loopRanges)([&] {
+      SmallVector<Value, 4> allIvValues(allIvs.begin(), allIvs.end());
+      LinalgScopedEmitter<IndexedValueTy,
+                          ConcreteOpTy>::emitScalarImplementation(allIvValues,
+                                                                  linalgOp);
+    });
+  }
+};
+
+/// Generates a loop nest using loop.parallel. loop.parallel is only used for
+/// the outer parallel loops; all other loops are generated with the loop.for
+/// operation.
+template <typename ConcreteOpTy>
+class GenerateLoopNest<loop::ParallelOp, ConcreteOpTy> {
+public:
+  using IndexedValueTy = StdIndexedValue;
+
+  static void doit(ConcreteOpTy linalgOp, ArrayRef<Value> loopRanges,
+                   MutableArrayRef<ValueHandle> allIvs) {
+    // Only generate loop.parallel for outer consecutive "parallel"
+    // iterator_types.
+    // TODO(ravishankarm): Generate loop.parallel for all "parallel" iterator
+    // types, not just the outermost ones. Also handle "reduction" iterator
+    // types.
+    auto nPar = linalgOp.getNumParallelLoops();
+    auto nRed = linalgOp.getNumReductionLoops();
+    auto nWin = linalgOp.getNumWindowLoops();
+    auto nLoops = nPar + nRed + nWin;
+    auto nOuterPar = linalgOp.iterator_types()
+                         .getValue()
+                         .take_while([](Attribute attr) {
+                           return attr.cast<StringAttr>().getValue() ==
+                                  getParallelIteratorTypeName();
+                         })
+                         .size();
+    // If there are no outer parallel loops, then the number of loop ops is
+    // the same as the number of loops, and they are all loop.for ops.
+    auto nLoopOps = (nOuterPar ? nLoops - nOuterPar + 1 : nLoops);
+    SmallVector<ValueHandle *, 4> allPIvs =
+        makeHandlePointers(MutableArrayRef<ValueHandle>(allIvs));
+
+    SmallVector<OperationHandle, 4> allLoops(nLoopOps, OperationHandle());
+    SmallVector<OperationHandle *, 4> allPLoops;
+    allPLoops.reserve(allLoops.size());
+    for (OperationHandle &loop : allLoops)
+      allPLoops.push_back(&loop);
+
+    ArrayRef<ValueHandle *> allPIvsRef(allPIvs);
+    ArrayRef<OperationHandle *> allPLoopsRef(allPLoops);
+
+    if (nOuterPar) {
+      GenericLoopNestRangeBuilder<loop::ParallelOp>(
+          allPIvsRef.take_front(nOuterPar),
+          loopRanges.take_front(nOuterPar))([&] {
+        GenericLoopNestRangeBuilder<loop::ForOp>(
+            allPIvsRef.drop_front(nOuterPar),
+            loopRanges.drop_front(nOuterPar))([&] {
+          SmallVector<Value, 4> allIvValues(allIvs.begin(), allIvs.end());
+          LinalgScopedEmitter<StdIndexedValue, ConcreteOpTy>::
+              emitScalarImplementation(allIvValues, linalgOp);
+        });
+      });
+    } else {
+      // If there are no parallel loops, then fall back to generating all
+      // loop.for operations.
+      GenericLoopNestRangeBuilder<loop::ForOp>(allPIvsRef, loopRanges)([&] {
+        SmallVector<Value, 4> allIvValues(allIvs.begin(), allIvs.end());
+        LinalgScopedEmitter<StdIndexedValue,
+                            ConcreteOpTy>::emitScalarImplementation(allIvValues,
+                                                                    linalgOp);
+      });
+    }
+  }
+};
+} // namespace
 
-template <typename LoopTy, typename IndexedValueTy, typename ConcreteOpTy>
-LogicalResult LinalgOpToLoopsImpl<LoopTy, IndexedValueTy, ConcreteOpTy>::doit(
-    Operation *op, PatternRewriter &rewriter) {
-  OpBuilder b(op);
-  ScopedContext scope(b, op->getLoc());
+template <typename LoopTy, typename ConcreteOpTy>
+Optional<LinalgLoops>
+LinalgOpToLoopsImpl<LoopTy, ConcreteOpTy>::doit(Operation *op,
+                                                PatternRewriter &rewriter) {
+  using Impl = GenerateLoopNest<LoopTy, ConcreteOpTy>;
+  using IndexedValueTy =
+      typename GenerateLoopNest<LoopTy, ConcreteOpTy>::IndexedValueTy;
+
+  ScopedContext scope(rewriter, op->getLoc());
 
   // The flattened loopToOperandRangesMaps is expected to be an invertible
   // permutation map (which is asserted in the inverse calculation).
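The loop-op count in the loop.parallel specialization above is worth spelling out. Here is a standalone sketch of the same computation in plain C++, with iterator types modeled as strings and the helper name hypothetical:

#include <cstddef>
#include <string>
#include <vector>

// Mirrors the nOuterPar/nLoopOps logic above: outer consecutive "parallel"
// dimensions fuse into a single loop.parallel op; every remaining dimension
// gets its own loop.for op.
static std::size_t countLoopOps(const std::vector<std::string> &iterTypes) {
  std::size_t nOuterPar = 0;
  while (nOuterPar < iterTypes.size() && iterTypes[nOuterPar] == "parallel")
    ++nOuterPar;
  std::size_t nLoops = iterTypes.size();
  return nOuterPar ? nLoops - nOuterPar + 1 : nLoops;
}

// countLoopOps({"parallel", "parallel", "reduction"}) == 2: one loop.parallel
// with two induction variables plus one loop.for for the reduction dimension.
// countLoopOps({"reduction", "parallel"}) == 2: loop.for ops only, since the
// leading dimension is not parallel.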
@@ -563,8 +648,6 @@ LogicalResult LinalgOpToLoopsImpl<LoopTy, IndexedValueTy, ConcreteOpTy>::doit(
   auto nRed = linalgOp.getNumReductionLoops();
   auto nWin = linalgOp.getNumWindowLoops();
   auto nLoops = nPar + nRed + nWin;
-  if (!loweringIsAllowed<LoopTy>(nPar, nLoops))
-    return failure();
   auto mapsRange =
       linalgOp.indexing_maps().template getAsRange<AffineMapAttr>();
   auto maps =
@@ -573,67 +656,72 @@ LogicalResult LinalgOpToLoopsImpl<LoopTy, IndexedValueTy, ConcreteOpTy>::doit(
   if (!invertedMap) {
     LinalgScopedEmitter<IndexedValueTy, ConcreteOpTy>::emitScalarImplementation(
         {}, linalgOp);
-    return success();
+    return LinalgLoops();
   }
 
-  SmallVector<ValueHandle, 4> allIvs(nLoops, ValueHandle(b.getIndexType()));
-  SmallVector<ValueHandle *, 4> allPIvs =
-      makeHandlePointers(MutableArrayRef<ValueHandle>(allIvs));
-  auto loopRanges = emitLoopRanges(scope.getBuilder(), scope.getLocation(),
-                                   invertedMap, getViewSizes(b, linalgOp));
+  SmallVector<ValueHandle, 4> allIvs(nLoops,
+                                     ValueHandle(rewriter.getIndexType()));
+  auto loopRanges =
+      emitLoopRanges(scope.getBuilder(), scope.getLocation(), invertedMap,
+                     getViewSizes(rewriter, linalgOp));
   assert(loopRanges.size() == allIvs.size());
-
-  GenericLoopNestRangeBuilder<LoopTy>(allPIvs, loopRanges)([&] {
-    SmallVector<Value, 4> allIvValues(allIvs.begin(), allIvs.end());
-    LinalgScopedEmitter<IndexedValueTy, ConcreteOpTy>::emitScalarImplementation(
-        allIvValues, linalgOp);
-  });
-  return success();
+  Impl::doit(linalgOp, loopRanges, allIvs);
+  // The number of loop ops might differ from the number of ivs, since some
+  // loops like affine.parallel and loop.parallel have multiple ivs.
+  llvm::SetVector<Operation *> loopSet;
+  for (ValueHandle &iv : allIvs) {
+    if (!iv.hasValue())
+      return {};
+    // The induction variable is a block argument of the entry block of the
+    // loop operation.
+    BlockArgument ivVal = iv.getValue().dyn_cast<BlockArgument>();
+    if (!ivVal)
+      return {};
+    loopSet.insert(ivVal.getOwner()->getParentOp());
+  }
+  LinalgLoops loops(loopSet.begin(), loopSet.end());
+  return loops;
 }
 
-template <typename LoopType, typename IndexedValueType, typename ConcreteOp>
+template <typename LoopType, typename ConcreteOp>
 class LinalgRewritePattern : public RewritePattern {
 public:
   explicit LinalgRewritePattern(MLIRContext *context)
       : RewritePattern(ConcreteOp::getOperationName(), 1, context) {}
 
   LogicalResult matchAndRewrite(Operation *op,
                                 PatternRewriter &rewriter) const override {
-    using Impl = LinalgOpToLoopsImpl<LoopType, IndexedValueType, ConcreteOp>;
-    if (failed(Impl::doit(op, rewriter)))
+    using Impl = LinalgOpToLoopsImpl<LoopType, ConcreteOp>;
+    if (!Impl::doit(op, rewriter))
       return failure();
     rewriter.eraseOp(op);
     return success();
   }
 };
 
 // Helper classes for type list expansion.
-template <typename LoopType, typename IndexedValueType, typename... LinalgOps>
+template <typename LoopType, typename... LinalgOps>
 class RewritePatternList;
 
-template <typename LoopType, typename IndexedValueType>
-class RewritePatternList<LoopType, IndexedValueType> {
+template <typename LoopType>
+class RewritePatternList<LoopType> {
 public:
   static void build(OwningRewritePatternList &patterns, MLIRContext *ctx) {}
 };
 
-template <typename LoopType, typename IndexedValueType, typename ConcreteOp,
-          typename... LinalgOps>
-class RewritePatternList<LoopType, IndexedValueType, ConcreteOp, LinalgOps...> {
+template <typename LoopType, typename ConcreteOp, typename... LinalgOps>
+class RewritePatternList<LoopType, ConcreteOp, LinalgOps...> {
 public:
   static void build(OwningRewritePatternList &patterns, MLIRContext *ctx) {
-    patterns
-        .insert<LinalgRewritePattern<LoopType, IndexedValueType, ConcreteOp>>(
-            ctx);
-    RewritePatternList<LoopType, IndexedValueType, LinalgOps...>::build(
-        patterns, ctx);
+    patterns.insert<LinalgRewritePattern<LoopType, ConcreteOp>>(ctx);
+    RewritePatternList<LoopType, LinalgOps...>::build(patterns, ctx);
   }
 };
 
 /// Populate the given list with patterns that convert from Linalg to loops.
-template <typename LoopType, typename IndexedValueType>
+template <typename LoopType>
 void FillRewritePatterns(OwningRewritePatternList &patterns, MLIRContext *ctx) {
-  RewritePatternList<LoopType, IndexedValueType,
+  RewritePatternList<LoopType,
 #define GET_OP_LIST
 #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
                      >::build(patterns, ctx);
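The recovery of the created loop ops in `doit` above relies on each induction variable being a block argument of its loop's entry block. A minimal sketch of that step in isolation, with a hypothetical helper name and the MLIR API of this period assumed:

#include "mlir/IR/Block.h"
#include "mlir/IR/Value.h"

// Given an induction variable produced by the loop builders, recover the
// loop operation that owns it. Returns nullptr if `iv` is not a block
// argument, the case where the lowering above bails out and returns None.
static mlir::Operation *getOwningLoop(mlir::Value iv) {
  auto blockArg = iv.dyn_cast<mlir::BlockArgument>();
  return blockArg ? blockArg.getOwner()->getParentOp() : nullptr;
}

Deduplicating through an llvm::SetVector is what collapses the several induction variables of one loop.parallel into a single entry in the returned LinalgLoops.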
@@ -677,13 +765,13 @@ struct FoldAffineOp : public RewritePattern {
 };
 } // namespace
 
-template <typename LoopType, typename IndexedValueType>
+template <typename LoopType>
 static void lowerLinalgToLoopsImpl(Operation *op, MLIRContext *context) {
   OwningRewritePatternList patterns;
   // Canonicalization and folding patterns applied greedily allow cleaning up
   // the emitted IR on the fly.
   // TODO(ntv) fold view and subview ops?
-  FillRewritePatterns<LoopType, IndexedValueType>(patterns, context);
+  FillRewritePatterns<LoopType>(patterns, context);
   DimOp::getCanonicalizationPatterns(patterns, context);
   AffineApplyOp::getCanonicalizationPatterns(patterns, context);
   patterns.insert<FoldAffineOp>(context);
@@ -695,21 +783,18 @@ namespace {
 struct LowerToAffineLoops
     : public LinalgLowerToAffineLoopsBase<LowerToAffineLoops> {
   void runOnFunction() override {
-    lowerLinalgToLoopsImpl<AffineForOp, AffineIndexedValue>(getFunction(),
-                                                            &getContext());
+    lowerLinalgToLoopsImpl<AffineForOp>(getFunction(), &getContext());
   }
 };
 struct LowerToLoops : public LinalgLowerToLoopsBase<LowerToLoops> {
   void runOnFunction() override {
-    lowerLinalgToLoopsImpl<loop::ForOp, StdIndexedValue>(getFunction(),
-                                                         &getContext());
+    lowerLinalgToLoopsImpl<loop::ForOp>(getFunction(), &getContext());
   }
 };
 struct LowerToParallelLoops
     : public LinalgLowerToParallelLoopsBase<LowerToParallelLoops> {
   void runOnFunction() override {
-    lowerLinalgToLoopsImpl<loop::ParallelOp, StdIndexedValue>(getFunction(),
-                                                              &getContext());
+    lowerLinalgToLoopsImpl<loop::ParallelOp>(getFunction(), &getContext());
   }
 };
 } // namespace
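For reference, a hedged sketch of wiring the parallel-loop lowering into a pipeline from C++; it assumes the createConvertLinalgToParallelLoopsPass factory (the counterpart of LowerToParallelLoops, defined in this file but not shown in this hunk) and the PassManager API of this period:

#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/IR/Function.h"
#include "mlir/Pass/PassManager.h"

void addLinalgToParallelLoops(mlir::PassManager &pm) {
  // LowerToParallelLoops is a function pass, so nest it on functions.
  pm.addNestedPass<mlir::FuncOp>(
      mlir::createConvertLinalgToParallelLoopsPass());
}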
@@ -728,28 +813,38 @@ mlir::createConvertLinalgToAffineLoopsPass() {
   return std::make_unique<LowerToAffineLoops>();
 }
 
+/// Emits a loop nest with the proper body for `op`.
+template <typename LoopTy, typename ConcreteOp>
+Optional<LinalgLoops>
+mlir::linalg::linalgLowerOpToLoops(PatternRewriter &rewriter, Operation *op) {
+  return LinalgOpToLoopsImpl<LoopTy, ConcreteOp>::doit(op, rewriter);
+}
+
 /// Emits a loop nest of `loop.for` with the proper body for `op`.
 template <typename ConcreteOp>
 LogicalResult mlir::linalg::linalgOpToLoops(PatternRewriter &rewriter,
                                             Operation *op) {
-  return LinalgOpToLoopsImpl<loop::ForOp, StdIndexedValue, ConcreteOp>::doit(
-      op, rewriter);
+  Optional<LinalgLoops> loops =
+      linalgLowerOpToLoops<loop::ForOp, ConcreteOp>(rewriter, op);
+  return loops ? success() : failure();
 }
 
 /// Emits a loop nest of `affine.for` with the proper body for `op`.
 template <typename ConcreteOp>
 LogicalResult mlir::linalg::linalgOpToAffineLoops(PatternRewriter &rewriter,
                                                   Operation *op) {
-  return LinalgOpToLoopsImpl<AffineForOp, AffineIndexedValue, ConcreteOp>::doit(
-      op, rewriter);
+  Optional<LinalgLoops> loops =
+      linalgLowerOpToLoops<AffineForOp, ConcreteOp>(rewriter, op);
+  return loops ? success() : failure();
 }
 
 /// Emits a loop nest of `loop.parallel` with the proper body for `op`.
 template <typename ConcreteOp>
 LogicalResult mlir::linalg::linalgOpToParallelLoops(PatternRewriter &rewriter,
                                                     Operation *op) {
-  return LinalgOpToLoopsImpl<loop::ParallelOp, StdIndexedValue,
-                             ConcreteOp>::doit(op, rewriter);
+  Optional<LinalgLoops> loops =
+      linalgLowerOpToLoops<loop::ParallelOp, ConcreteOp>(rewriter, op);
+  return loops ? success() : failure();
 }
 
 // TODO(ntv) Need to make these instantiations more future-proof to avoid the
@@ -758,7 +853,12 @@ LogicalResult mlir::linalg::linalgOpToParallelLoops(PatternRewriter &rewriter,
   template LogicalResult mlir::linalg::linalgOpToLoops<OP_TYPE>(             \
       PatternRewriter & rewriter, Operation * op);                           \
   template LogicalResult mlir::linalg::linalgOpToAffineLoops<OP_TYPE>(       \
-      PatternRewriter & rewriter, Operation * op);
+      PatternRewriter & rewriter, Operation * op);                           \
+  template LogicalResult mlir::linalg::linalgOpToParallelLoops<OP_TYPE>(     \
+      PatternRewriter & rewriter, Operation * op);                           \
+  template Optional<LinalgLoops>                                             \
+      mlir::linalg::linalgLowerOpToLoops<loop::ParallelOp, OP_TYPE>(         \
+          PatternRewriter & rewriter, Operation * op);
 
 INSTANTIATE_LINALG_OP_TO_LOOPS(CopyOp)
 INSTANTIATE_LINALG_OP_TO_LOOPS(FillOp)
@@ -771,9 +871,3 @@ INSTANTIATE_LINALG_OP_TO_LOOPS(PoolingMinOp)
 INSTANTIATE_LINALG_OP_TO_LOOPS(PoolingSumOp)
 INSTANTIATE_LINALG_OP_TO_LOOPS(GenericOp)
 INSTANTIATE_LINALG_OP_TO_LOOPS(IndexedGenericOp)
-
-// TODO(pifon): Enable lowering to parallel loops for ops other than
-// linalg.generic for now to be on the safe side.
-template LogicalResult
-mlir::linalg::linalgOpToParallelLoops<GenericOp>(PatternRewriter &rewriter,
-                                                 Operation *op);