@@ -833,4 +833,265 @@ struct LinalgFoldUnitExtentDimsPass
833
833
(void )applyPatternsAndFoldGreedily (op, std::move (patterns));
834
834
}
835
835
};
836
+
837
+ } // namespace
838
+
839
+ namespace {
840
+
841
+ // / Returns reassociation indices for collapsing/expanding a
842
+ // / tensor of rank `rank` at position `pos`.
843
+ static SmallVector<ReassociationIndices>
844
+ getReassociationForReshapeAtDim (int64_t rank, int64_t pos) {
845
+ SmallVector<ReassociationIndices> reassociation (rank - 1 , {0 , 1 });
846
+ bool lastDim = pos == rank - 1 ;
847
+ if (rank > 2 ) {
848
+ for (int64_t i = 0 ; i < rank - 1 ; i++) {
849
+ if (i == pos || (lastDim && i == pos - 1 ))
850
+ reassociation[i] = ReassociationIndices{i, i + 1 };
851
+ else if (i < pos)
852
+ reassociation[i] = ReassociationIndices{i};
853
+ else
854
+ reassociation[i] = ReassociationIndices{i + 1 };
855
+ }
856
+ }
857
+ return reassociation;
858
+ }
859
+
860
+ // / Returns a collapsed `val` where the collapsing occurs at dim `pos`.
861
+ // / If `pos < 0`, then don't collapse.
862
+ static Value collapseSingletonDimAt (PatternRewriter &rewriter, Value val,
863
+ int64_t pos) {
864
+ if (pos < 0 )
865
+ return val;
866
+ auto valType = cast<ShapedType>(val.getType ());
867
+ SmallVector<int64_t > collapsedShape (valType.getShape ());
868
+ collapsedShape.erase (collapsedShape.begin () + pos);
869
+ return collapseValue (
870
+ rewriter, val.getLoc (), val, collapsedShape,
871
+ getReassociationForReshapeAtDim (valType.getRank (), pos),
872
+ ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape);
873
+ }
874
+
875
/// Base class for all rank reduction patterns for contraction ops
/// with unit dimensions. All patterns should convert one named op
/// to another named op. Intended to reduce only one iteration space dim
/// at a time.
/// Reducing multiple dims will happen with recursive application of
/// pattern rewrites.
template <typename FromOpTy, typename ToOpTy>
struct RankReduceContractionOps : OpRewritePattern<FromOpTy> {
  using OpRewritePattern<FromOpTy>::OpRewritePattern;

  /// Collapse all collapsable operands. `operandCollapseDims` holds, per
  /// operand (lhs, rhs, init), the dim index to collapse, or a negative
  /// value for "leave this operand as-is" (see collapseSingletonDimAt).
  SmallVector<Value>
  collapseOperands(PatternRewriter &rewriter, ArrayRef<Value> operands,
                   ArrayRef<int64_t> operandCollapseDims) const {
    assert(operandCollapseDims.size() == 3 && operands.size() == 3 &&
           "expected 3 operands and dims");
    return llvm::map_to_vector(
        llvm::zip(operands, operandCollapseDims), [&](auto pair) {
          return collapseSingletonDimAt(rewriter, std::get<0>(pair),
                                        std::get<1>(pair));
        });
  }

  /// Expand result tensor back to `expandedType` by re-inserting the unit
  /// dim at `dim` (inverse of the collapse applied to the init operand).
  Value expandResult(PatternRewriter &rewriter, Value result,
                     RankedTensorType expandedType, int64_t dim) const {
    return rewriter.create<tensor::ExpandShapeOp>(
        result.getLoc(), expandedType, result,
        getReassociationForReshapeAtDim(expandedType.getRank(), dim));
  }

  LogicalResult matchAndRewrite(FromOpTy contractionOp,
                                PatternRewriter &rewriter) const override {

    auto loc = contractionOp.getLoc();
    auto inputs = contractionOp.getDpsInputs();
    auto inits = contractionOp.getDpsInits();
    if (inputs.size() != 2 || inits.size() != 1)
      return rewriter.notifyMatchFailure(contractionOp,
                                         "expected 2 inputs and 1 init");
    auto lhs = inputs[0];
    auto rhs = inputs[1];
    auto init = inits[0];
    SmallVector<Value> operands{lhs, rhs, init};

    // Ask the derived pattern which unit dim (if any) to drop per operand.
    SmallVector<int64_t> operandUnitDims;
    if (failed(getOperandUnitDims(contractionOp, operandUnitDims)))
      return rewriter.notifyMatchFailure(contractionOp,
                                         "no reducable dims found");

    SmallVector<Value> collapsedOperands =
        collapseOperands(rewriter, operands, operandUnitDims);
    Value collapsedLhs = collapsedOperands[0];
    Value collapsedRhs = collapsedOperands[1];
    Value collapsedInit = collapsedOperands[2];
    // Only tensor-semantics ops have a result type; memref-semantics ops
    // produce no results, hence the conditional result-type list.
    SmallVector<Type, 1> collapsedResultTy;
    if (isa<RankedTensorType>(collapsedInit.getType()))
      collapsedResultTy.push_back(collapsedInit.getType());
    auto collapsedOp = rewriter.create<ToOpTy>(
        loc, collapsedResultTy, ValueRange{collapsedLhs, collapsedRhs},
        ValueRange{collapsedInit});
    // Carry over attributes, except the memoized indexing maps, which are
    // specific to the original op's iteration space and would be stale.
    for (auto attr : contractionOp->getAttrs()) {
      if (attr.getName() == LinalgDialect::kMemoizedIndexingMapsAttrName)
        continue;
      collapsedOp->setAttr(attr.getName(), attr.getValue());
    }

    auto results = contractionOp.getResults();
    assert(results.size() < 2 && "expected at most one result");
    if (results.empty()) {
      // Memref semantics: replace in place, nothing to expand.
      rewriter.replaceOp(contractionOp, collapsedOp);
    } else {
      // Tensor semantics: expand the collapsed result back to the original
      // result type, re-inserting the unit dim dropped from the init
      // operand (operandUnitDims[2]).
      rewriter.replaceOp(
          contractionOp,
          expandResult(rewriter, collapsedOp.getResultTensors()[0],
                       cast<RankedTensorType>(results[0].getType()),
                       operandUnitDims[2]));
    }

    return success();
  }

  /// Populate `operandUnitDims` with 3 indices indicating the unit dim
  /// for each operand that should be collapsed in this pattern. If an
  /// operand shouldn't be collapsed, the index should be negative.
  virtual LogicalResult
  getOperandUnitDims(LinalgOp op,
                     SmallVectorImpl<int64_t> &operandUnitDims) const = 0;
};
964
+
965
/// Patterns for unbatching batched contraction ops
template <typename FromOpTy, typename ToOpTy>
struct RankReduceToUnBatched : RankReduceContractionOps<FromOpTy, ToOpTy> {
  using RankReduceContractionOps<FromOpTy, ToOpTy>::RankReduceContractionOps;

  /// Look for unit batch dims to collapse.
  ///
  /// Matches only when the op has exactly one batch dimension and that
  /// dimension maps to a static unit (1) dim on all three operands; on
  /// success all three collapse indices are set (none is negative).
  LogicalResult
  getOperandUnitDims(LinalgOp op,
                     SmallVectorImpl<int64_t> &operandUnitDims) const override {
    FailureOr<ContractionDimensions> maybeContractionDims =
        inferContractionDims(op);
    if (failed(maybeContractionDims)) {
      LLVM_DEBUG(llvm::dbgs() << "could not infer contraction dims");
      return failure();
    }
    ContractionDimensions contractionDims = maybeContractionDims.value();

    // Only a single batch dimension is handled per rewrite.
    if (contractionDims.batch.size() != 1)
      return failure();
    auto batchDim = contractionDims.batch[0];
    // Map the batch iteration-space dim to the corresponding operand dims;
    // it must appear on all three operands (lhs, rhs, init) and be
    // statically 1 on each of them.
    SmallVector<std::pair<Value, unsigned>, 3> bOperands;
    op.mapIterationSpaceDimToAllOperandDims(batchDim, bOperands);
    if (bOperands.size() != 3 || llvm::any_of(bOperands, [](auto pair) {
          return cast<ShapedType>(std::get<0>(pair).getType())
                     .getShape()[std::get<1>(pair)] != 1;
        })) {
      LLVM_DEBUG(llvm::dbgs() << "specified unit dims not found");
      return failure();
    }

    operandUnitDims = SmallVector<int64_t>{std::get<1>(bOperands[0]),
                                           std::get<1>(bOperands[1]),
                                           std::get<1>(bOperands[2])};
    return success();
  }
};
1001
+
1002
+ // / Patterns for reducing non-batch dimensions
1003
+ template <typename FromOpTy, typename ToOpTy>
1004
+ struct RankReduceMatmul : RankReduceContractionOps<FromOpTy, ToOpTy> {
1005
+ using RankReduceContractionOps<FromOpTy, ToOpTy>::RankReduceContractionOps;
1006
+
1007
+ // / Helper for determining whether the lhs/init or rhs/init are reduced.
1008
+ static bool constexpr reduceLeft =
1009
+ (std::is_same_v<FromOpTy, BatchMatmulOp> &&
1010
+ std::is_same_v<ToOpTy, BatchVecmatOp>) ||
1011
+ (std::is_same_v<FromOpTy, BatchMatmulTransposeAOp> &&
1012
+ std::is_same_v<ToOpTy, BatchVecmatOp>) ||
1013
+ (std::is_same_v<FromOpTy, MatmulOp> &&
1014
+ std::is_same_v<ToOpTy, VecmatOp>) ||
1015
+ (std::is_same_v<FromOpTy, MatmulTransposeAOp> &&
1016
+ std::is_same_v<ToOpTy, VecmatOp>) ||
1017
+ (std::is_same_v<FromOpTy, MatvecOp> && std::is_same_v<ToOpTy, DotOp>);
1018
+
1019
+ // / Look for non-batch spatial dims to collapse.
1020
+ LogicalResult
1021
+ getOperandUnitDims (LinalgOp op,
1022
+ SmallVectorImpl<int64_t > &operandUnitDims) const override {
1023
+ FailureOr<ContractionDimensions> maybeContractionDims =
1024
+ inferContractionDims (op);
1025
+ if (failed (maybeContractionDims)) {
1026
+ LLVM_DEBUG (llvm::dbgs () << " could not infer contraction dims" );
1027
+ return failure ();
1028
+ }
1029
+ ContractionDimensions contractionDims = maybeContractionDims.value ();
1030
+
1031
+ if constexpr (reduceLeft) {
1032
+ auto m = contractionDims.m [0 ];
1033
+ SmallVector<std::pair<Value, unsigned >, 2 > mOperands ;
1034
+ op.mapIterationSpaceDimToAllOperandDims (m, mOperands );
1035
+ if (mOperands .size () != 2 )
1036
+ return failure ();
1037
+ if (llvm::all_of (mOperands , [](auto pair) {
1038
+ return cast<ShapedType>(std::get<0 >(pair).getType ())
1039
+ .getShape ()[std::get<1 >(pair)] == 1 ;
1040
+ })) {
1041
+ operandUnitDims = SmallVector<int64_t >{std::get<1 >(mOperands [0 ]), -1 ,
1042
+ std::get<1 >(mOperands [1 ])};
1043
+ return success ();
1044
+ }
1045
+ } else {
1046
+ auto n = contractionDims.n [0 ];
1047
+ SmallVector<std::pair<Value, unsigned >, 2 > nOperands;
1048
+ op.mapIterationSpaceDimToAllOperandDims (n, nOperands);
1049
+ if (nOperands.size () != 2 )
1050
+ return failure ();
1051
+ if (llvm::all_of (nOperands, [](auto pair) {
1052
+ return cast<ShapedType>(std::get<0 >(pair).getType ())
1053
+ .getShape ()[std::get<1 >(pair)] == 1 ;
1054
+ })) {
1055
+ operandUnitDims = SmallVector<int64_t >{-1 , std::get<1 >(nOperands[0 ]),
1056
+ std::get<1 >(nOperands[1 ])};
1057
+ return success ();
1058
+ }
1059
+ }
1060
+ LLVM_DEBUG (llvm::dbgs () << " specified unit dims not found" );
1061
+ return failure ();
1062
+ }
1063
+ };
1064
+
836
1065
} // namespace
1066
+
1067
+ void mlir::linalg::populateContractionOpRankReducingPatterns (
1068
+ RewritePatternSet &patterns) {
1069
+ MLIRContext *context = patterns.getContext ();
1070
+ // Unbatching patterns for unit batch size
1071
+ patterns.add <RankReduceToUnBatched<BatchMatmulOp, MatmulOp>>(context);
1072
+ patterns
1073
+ .add <RankReduceToUnBatched<BatchMatmulTransposeAOp, MatmulTransposeAOp>>(
1074
+ context);
1075
+ patterns
1076
+ .add <RankReduceToUnBatched<BatchMatmulTransposeBOp, MatmulTransposeBOp>>(
1077
+ context);
1078
+ patterns.add <RankReduceToUnBatched<BatchMatvecOp, MatvecOp>>(context);
1079
+ patterns.add <RankReduceToUnBatched<BatchVecmatOp, VecmatOp>>(context);
1080
+
1081
+ // Non-batch rank 1 reducing patterns
1082
+ patterns.add <RankReduceMatmul<MatmulOp, VecmatOp>>(context);
1083
+ patterns.add <RankReduceMatmul<MatmulOp, MatvecOp>>(context);
1084
+ patterns.add <RankReduceMatmul<MatmulTransposeAOp, VecmatOp>>(context);
1085
+ patterns.add <RankReduceMatmul<MatmulTransposeBOp, MatvecOp>>(context);
1086
+ // Batch rank 1 reducing patterns
1087
+ patterns.add <RankReduceMatmul<BatchMatmulOp, BatchVecmatOp>>(context);
1088
+ patterns.add <RankReduceMatmul<BatchMatmulOp, BatchMatvecOp>>(context);
1089
+ patterns.add <RankReduceMatmul<BatchMatmulTransposeAOp, BatchVecmatOp>>(
1090
+ context);
1091
+ patterns.add <RankReduceMatmul<BatchMatmulTransposeBOp, BatchMatvecOp>>(
1092
+ context);
1093
+
1094
+ // Non-batch rank 0 reducing patterns
1095
+ patterns.add <RankReduceMatmul<MatvecOp, DotOp>>(context);
1096
+ patterns.add <RankReduceMatmul<VecmatOp, DotOp>>(context);
1097
+ }
0 commit comments