@@ -814,6 +814,300 @@ let arguments = (ins
814
814
}];
815
815
}
816
816
817
+ class OuterProductWideBase<string mnemonic,
818
+ list<Type> allowedInputVectorTypes,
819
+ list<Type> allowedResultVectorTypes,
820
+ int numOuterProducts> :
821
+ ArmSME_Op<mnemonic, [
822
+ ArmSMETileOpInterface,
823
+ AttrSizedOperandSegments,
824
+ AllTypesMatch<["lhs", "rhs"]>,
825
+ HasMatchingMaskTypeConstraint<"lhs", "lhsMask">,
826
+ HasMatchingMaskTypeConstraint<"rhs", "rhsMask">,
827
+ PredOpTrait<
828
+ "both `lhsMask` and `rhsMask` should be provided or neither",
829
+ CPred<"bool(getLhsMask()) == bool(getRhsMask())">
830
+ >,
831
+ OptionalTypesMatchWith<"result and acc have the same type",
832
+ "result", "acc", "::llvm::cast<Type>($_self)">,
833
+ // This trait ensures the input type matches the correct output type for ops
833
+ // that take multiple inputs and outputs (i.e., 4-way).
835
+ PredOpTrait<
836
+ "tile element size equals lhs element size * " # numOuterProducts,
837
+ CPred<"getTileType().getElementTypeBitWidth() == "
838
+ "(getLhsType().getElementTypeBitWidth() * " # numOuterProducts # ")">
839
+ >,
840
+ ]> {
841
+
842
+ let arguments = (ins
843
+ AnyTypeOf<allowedInputVectorTypes>:$lhs, AnyVector:$rhs,
844
+ Optional<AnyVector>:$lhsMask, Optional<AnyVector>:$rhsMask,
845
+ Optional<AnyVector>:$acc);
846
+ let results = (outs AnyTypeOf<allowedResultVectorTypes>:$result);
847
+
848
+ let assemblyFormat = [{
849
+ $lhs `,` $rhs
850
+ oilist(
851
+ `acc` `` `(` $acc `)`
852
+ | `masks` `` `(` $lhsMask `,` $rhsMask `)`
853
+ ) attr-dict `:` type($lhs) `,` type($rhs) `into` type($result)
854
+ }];
855
+
856
+ let extraClassDeclaration = [{
857
+ VectorType getLhsType() { return llvm::cast<VectorType>(getLhs().getType()); }
858
+ VectorType getRhsType() { return llvm::cast<VectorType>(getRhs().getType()); }
859
+ VectorType getResultType() { return llvm::cast<VectorType>(getResult().getType()); }
860
+ std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
861
+ // The outerproduct op allocates a new tile if no accumulator is passed.
862
+ if (!getAcc())
863
+ return arm_sme::getSMETileType(getResultType());
864
+ return std::nullopt;
865
+ }
866
+ VectorType getTileType() {
867
+ return getResultType();
868
+ }
869
+ }];
870
+ }
871
+
872
+ class OuterProductWide2Way<string mnemonic,
873
+ list<Type> allowedInputVectorTypes,
874
+ list<Type> allowedResultVectorTypes>
875
+ : OuterProductWideBase<mnemonic, allowedInputVectorTypes,
876
+ allowedResultVectorTypes, /*numOuterProducts=*/2>;
877
+
878
+ def FMopaWide2WayOp
879
+ : OuterProductWide2Way<"fmopa_wide_2way",
880
+ [ScalableVectorOfRankAndLengthAndType<[1], [8], [F16, BF16]>],
881
+ [nxnxv4f32]> {
882
+ let summary = "Floating-point sum of 2 outer products and accumulate";
883
+
884
+ let description = [{
885
+ This operation represents a sum of 2 widened outer products. It takes 2 1-D
886
+ scalable vectors as input and a 2-D scalable vector (ZA tile) as output.
887
+
888
+ For example (fp16 to fp32):
889
+
890
+ ```mlir
891
+ %result = arm_sme.fmopa_wide_2way %lhs, %rhs :
892
+ vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
893
+ ```
894
+
895
+ The `lhs` encodes a matrix of shape SVLSx2 and the `rhs` a matrix of
896
+ 2xSVLS, where SVLS (spec [1], section B2.1) is the number of 32-bit
897
+ elements in a vector of SVL bits. To illustrate, below is a breakdown of
898
+ this operation for SVL=128 (i.e., vscale=1):
899
+
900
+ ```
901
+ LHS RHS
902
+ [A0 A1 A2 A3 A4 A5 A6 A7] [B0 B1 B2 B3 B4 B5 B6 B7]
903
+
904
+ ----------------------------------------------------------------------------
905
+
906
+ implicit layout
907
+
908
+ [A0 A1] |
909
+ [A2 A3] | [B0 B2 B4 B6]
910
+ [A4 A5] | [B1 B3 B5 B7]
911
+ [A6 A7] |
912
+
913
+ ----------------------------------------------------------------------------
914
+
915
+ 2 outer products
916
+
917
+ Acol0 ⊗ Brow0 | Acol1 ⊗ Brow1
918
+ ------------- | -------------
919
+ |
920
+ [B0 B2 B4 B6] | [B1 B3 B5 B7]
921
+ |
922
+ [A0 [A0B0 A0B2 A0B4 A0B6] | [A1 [A1B1 A1B3 A1B5 A1B7]
923
+ A2 [A2B0 A2B2 A2B4 A2B6] | A3 [A3B1 A3B3 A3B5 A3B7]
924
+ A4 [A4B0 A4B2 A4B4 A4B6] | A5 [A5B1 A5B3 A5B5 A5B7]
925
+ A6] [A6B0 A6B2 A6B4 A6B6] | A7] [A7B1 A7B3 A7B5 A7B7]
926
+ |
927
+
928
+ ----------------------------------------------------------------------------
929
+
930
+ sum of 2 outer products
931
+
932
+ Acol0 ⊗ Brow0 + Acol1 ⊗ Brow1
933
+
934
+ [A0B0 + A1B1 A0B2 + A1B3 A0B4 + A1B5 A0B6 + A1B7]
935
+ [A2B0 + A3B1 A2B2 + A3B3 A2B4 + A3B5 A2B6 + A3B7]
936
+ [A4B0 + A5B1 A4B2 + A5B3 A4B4 + A5B5 A4B6 + A5B7]
937
+ [A6B0 + A7B1 A6B2 + A7B3 A6B4 + A7B5 A6B6 + A7B7]
938
+
939
+ ----------------------------------------------------------------------------
940
+ ```
941
+
942
+ This operation enables the folding of 2 outer products chained via the
943
+ accumulator into a single outer product.
944
+
945
+ For example:
946
+
947
+ ```mlir
948
+ %a0_ext = arith.extf %a0 : vector<[4]xf16> to vector<[4]xf32>
949
+ %b0_ext = arith.extf %b0 : vector<[4]xf16> to vector<[4]xf32>
950
+ %a1_ext = arith.extf %a1 : vector<[4]xf16> to vector<[4]xf32>
951
+ %b1_ext = arith.extf %b1 : vector<[4]xf16> to vector<[4]xf32>
952
+
953
+ %0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[4]xf32>, vector<[4]xf32>
954
+ %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) : vector<[4]xf32>, vector<[4]xf32>
955
+ ```
956
+
957
+ The 2 outer products in the example above can be fused into a single outer
958
+ product as follows:
959
+
960
+ ```mlir
961
+ %undef = llvm.mlir.undef : vector<[8]xf16>
962
+ %a0_ins = vector.scalable.insert %a0, %undef[0] : vector<[4]xf16> into vector<[8]xf16>
963
+ %a1_ins = vector.scalable.insert %a1, %undef[0] : vector<[4]xf16> into vector<[8]xf16>
964
+ %a_packed = "arm_sve.intr.zip1"(%a0_ins, %a1_ins) : (vector<[8]xf16>, vector<[8]xf16>) -> vector<[8]xf16>
965
+ %b0_ins = vector.scalable.insert %b0, %undef[0] : vector<[4]xf16> into vector<[8]xf16>
966
+ %b1_ins = vector.scalable.insert %b1, %undef[0] : vector<[4]xf16> into vector<[8]xf16>
967
+ %b_packed = "arm_sve.intr.zip1"(%b0_ins, %b1_ins) : (vector<[8]xf16>, vector<[8]xf16>) -> vector<[8]xf16>
968
+ %0 = arm_sme.fmopa_wide_2way %a_packed, %b_packed : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
969
+ ```
970
+
971
+ This is implemented in the `-arm-sme-outer-product-widening` pass.
972
+
973
+ Example: FP16 to FP32
974
+ ```mlir
975
+ %result = arm_sme.fmopa_wide_2way %lhs, %rhs : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
976
+ ```
977
+
978
+ Example: BF16 to FP32
979
+ ```mlir
980
+ %result = arm_sme.fmopa_wide_2way %lhs, %rhs : vector<[8]xbf16>, vector<[8]xbf16> into vector<[4]x[4]xf32>
981
+ ```
982
+
983
+ | Spec | Features |
984
+ | ---- | -------- |
985
+ | [FMOPA (widening, 2-way, FP16 to FP32)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/FMOPA--widening--2-way--FP16-to-FP32---Half-precision-floating-point-sum-of-outer-products-and-accumulate-) | +sme |
986
+ | [BFMOPA (widening, 2-way, BF16 to FP32)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/BFMOPA--widening---BFloat16-sum-of-outer-products-and-accumulate-) | +sme |
987
+
988
+ [1] https://developer.arm.com/documentation/ddi0616
989
+ }];
990
+ }
991
+
992
+ // TODO: support:
993
+ // - FMOPA 2-way FP8 to FP16
994
+ // - FMOPA 4-way FP8 to FP32
995
+ // once intrinsic support lands in the backend.
996
+
997
+ def FMopsWide2WayOp
998
+ : OuterProductWide2Way<"fmops_wide_2way",
999
+ [ScalableVectorOfRankAndLengthAndType<[1], [8], [F16, BF16]>],
1000
+ [nxnxv4f32]> {
1001
+ let summary = "Floating-point sum of 2 outer products and subtract";
1002
+ let description = [{
1003
+ Equivalent to `fmopa_wide_2way` but outer products are subtracted from
1004
+ destination `result`.
1005
+
1006
+ Example: FP16 to FP32
1007
+ ```mlir
1008
+ %result = arm_sme.fmops_wide_2way %lhs, %rhs : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
1009
+ ```
1010
+
1011
+ Example: BF16 to FP32
1012
+ ```mlir
1013
+ %result = arm_sme.fmops_wide_2way %lhs, %rhs : vector<[8]xbf16>, vector<[8]xbf16> into vector<[4]x[4]xf32>
1014
+ ```
1015
+ Refer to
1016
+ [fmopa_wide_2way](#arm_smefmopa_wide_2way-arm_smefmopa_wide_2wayop) for a
1017
+ detailed description of 2-way outer products.
1018
+
1019
+ | Spec | Features |
1020
+ | ---- | -------- |
1021
+ | [FMOPS (widening, 2-way, FP16 to FP32)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/FMOPS--widening---Half-precision-floating-point-sum-of-outer-products-and-subtract-) | +sme |
1022
+ | [BFMOPS (widening, 2-way, BF16 to FP32)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/BFMOPS--widening---BFloat16-sum-of-outer-products-and-subtract-) | +sme |
1023
+
1024
+ }];
1025
+ }
1026
+
1027
+ def SMopaWide2WayOp
1028
+ : OuterProductWide2Way<"smopa_wide_2way",
1029
+ [ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
1030
+ [nxnxv4i32]> {
1031
+ let summary = "Signed integer sum of 2 outer products and accumulate";
1032
+ let description = [{
1033
+ Example:
1034
+ ```mlir
1035
+ %result = arm_sme.smopa_wide_2way %lhs, %rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
1036
+ ```
1037
+ Refer to
1038
+ [fmopa_wide_2way](#arm_smefmopa_wide_2way-arm_smefmopa_wide_2wayop) for a
1039
+ detailed description of 2-way outer products.
1040
+
1041
+ | Spec | Features |
1042
+ | ---- | -------- |
1043
+ | [SMOPA (2-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/SMOPA--2-way---Signed-integer-sum-of-outer-products-and-accumulate-) | +sme2 |
1044
+
1045
+ }];
1046
+ }
1047
+
1048
+ def SMopsWide2WayOp
1049
+ : OuterProductWide2Way<"smops_wide_2way",
1050
+ [ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
1051
+ [nxnxv4i32]> {
1052
+ let summary = "Signed integer sum of 2 outer products and subtract";
1053
+ let description = [{
1054
+ Example:
1055
+ ```mlir
1056
+ %result = arm_sme.smops_wide_2way %lhs, %rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
1057
+ ```
1058
+ Refer to
1059
+ [fmopa_wide_2way](#arm_smefmopa_wide_2way-arm_smefmopa_wide_2wayop) for a
1060
+ detailed description of 2-way outer products.
1061
+
1062
+ | Spec | Features |
1063
+ | ---- | -------- |
1064
+ | [SMOPS (2-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/SMOPS--2-way---Signed-integer-sum-of-outer-products-and-subtract-) | +sme2 |
1065
+
1066
+ }];
1067
+ }
1068
+
1069
+ def UMopaWide2WayOp
1070
+ : OuterProductWide2Way<"umopa_wide_2way",
1071
+ [ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
1072
+ [nxnxv4i32]> {
1073
+ let summary = "Unsigned integer sum of 2 outer products and accumulate";
1074
+ let description = [{
1075
+ Example:
1076
+ ```mlir
1077
+ %result = arm_sme.umopa_wide_2way %lhs, %rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
1078
+ ```
1079
+ Refer to
1080
+ [fmopa_wide_2way](#arm_smefmopa_wide_2way-arm_smefmopa_wide_2wayop) for a
1081
+ detailed description of 2-way outer products.
1082
+
1083
+ | Spec | Features |
1084
+ | ---- | -------- |
1085
+ | [UMOPA (2-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/UMOPA--2-way---Unsigned-integer-sum-of-outer-products-and-accumulate-) | +sme2 |
1086
+
1087
+ }];
1088
+ }
1089
+
1090
+ def UMopsWide2WayOp
1091
+ : OuterProductWide2Way<"umops_wide_2way",
1092
+ [ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
1093
+ [nxnxv4i32]> {
1094
+ let summary = "Unsigned integer sum of 2 outer products and subtract";
1095
+ let description = [{
1096
+ Example:
1097
+ ```mlir
1098
+ %result = arm_sme.umops_wide_2way %lhs, %rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
1099
+ ```
1100
+ Refer to
1101
+ [fmopa_wide_2way](#arm_smefmopa_wide_2way-arm_smefmopa_wide_2wayop) for a
1102
+ detailed description of 2-way outer products.
1103
+
1104
+ | Spec | Features |
1105
+ | ---- | -------- |
1106
+ | [UMOPS (2-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/UMOPS--2-way---Unsigned-integer-sum-of-outer-products-and-subtract-) | +sme2 |
1107
+
1108
+ }];
1109
+ }
1110
+
817
1111
def StreamingVLOp : ArmSME_Op<"streaming_vl", [Pure]>
818
1112
{
819
1113
let summary = "Query the streaming vector length";
0 commit comments