Skip to content

Commit 455ae29

Browse files
committed
[mlir][ArmSME] Support 2-way widening outer products
This patch introduces support for 2-way widening outer products. This enables the folding of 2 'arm_sme.outerproduct' operations that are chained via the accumulator into single widened operations. Changes: - Adds 'llvm.aarch64.sme.[us]mop[as].za32' intrinsics for 2-way variants. These map to instruction variants added in SME2 and use different intrinsics. Intrinsics are already implemented for widening variants from SME1. - Adds the following operations: - fmopa_wide_2way, fmops_wide_2way - smopa_wide_2way, smops_wide_2way - umopa_wide_2way, umops_wide_2way - Implements conversions for the above ops to intrinsics in ArmSMEToLLVM. - Adds a pass 'arm-sme-outer-product-widening' that folds 'arm_sme.outerproduct' operations. For a detailed description of these operations see the 'arm_sme.fmopa_wide_2way' description. The reason for introducing many operations rather than one is that the signed/unsigned variants can't be distinguished with types (e.g., ui16, si16) since 'arith.extui' and 'arith.extsi' only support signless integers. A single operation would require this information and an attribute (for example) for the sign doesn't feel right if floating-point types are also supported where this wouldn't apply. Furthermore, the SME FP8 extensions (FEAT_SME_F8F16, FEAT_SME_F8F32) introduce FMOPA 2-way (FP8 to FP16) and 4-way (FP8 to FP32) variants but no subtract variant. Whilst these are not supported in this patch, it felt simpler to have separate ops for add/subtract given this.
1 parent f1fe6a2 commit 455ae29

File tree

14 files changed

+1213
-2
lines changed

14 files changed

+1213
-2
lines changed

mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,10 @@ def LLVM_aarch64_sme_sumopa_wide : ArmSME_IntrMopOverloadedOp<"sumopa.wide">;
105105
def LLVM_aarch64_sme_sumops_wide : ArmSME_IntrMopOverloadedOp<"sumops.wide">;
106106
def LLVM_aarch64_sme_usmopa_wide : ArmSME_IntrMopOverloadedOp<"usmopa.wide">;
107107
def LLVM_aarch64_sme_usmops_wide : ArmSME_IntrMopOverloadedOp<"usmops.wide">;
108+
def LLVM_aarch64_sme_smopa_za32 : ArmSME_IntrMopOverloadedOp<"smopa.za32">;
109+
def LLVM_aarch64_sme_umopa_za32 : ArmSME_IntrMopOverloadedOp<"umopa.za32">;
110+
def LLVM_aarch64_sme_smops_za32 : ArmSME_IntrMopOverloadedOp<"smops.za32">;
111+
def LLVM_aarch64_sme_umops_za32 : ArmSME_IntrMopOverloadedOp<"umops.za32">;
108112

109113
class ArmSME_IntrLoadStoreOp<string mnemonic>
110114
: ArmSME_IntrOp<mnemonic,

mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td

Lines changed: 294 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -814,6 +814,300 @@ let arguments = (ins
814814
}];
815815
}
816816

817+
class OuterProductWideBase<string mnemonic,
818+
list<Type> allowedInputVectorTypes,
819+
list<Type> allowedResultVectorTypes,
820+
int numOuterProducts> :
821+
ArmSME_Op<mnemonic, [
822+
ArmSMETileOpInterface,
823+
AttrSizedOperandSegments,
824+
AllTypesMatch<["lhs", "rhs"]>,
825+
HasMatchingMaskTypeConstraint<"lhs", "lhsMask">,
826+
HasMatchingMaskTypeConstraint<"rhs", "rhsMask">,
827+
PredOpTrait<
828+
"both `lhsMask` and `rhsMask` should be provided or neither",
829+
CPred<"bool(getLhsMask()) == bool(getRhsMask())">
830+
>,
831+
OptionalTypesMatchWith<"result and acc have the same type",
832+
"result", "acc", "::llvm::cast<Type>($_self)">,
833+
// This trait ensures the input type matches the correct output type for ops
834+
// that take multiple inputs and outputs (i.e., 4-way).
835+
PredOpTrait<
836+
"tile element size equals lhs element size * " # numOuterProducts,
837+
CPred<"getTileType().getElementTypeBitWidth() == "
838+
"(getLhsType().getElementTypeBitWidth() * " # numOuterProducts # ")">
839+
>,
840+
]> {
841+
842+
let arguments = (ins
843+
AnyTypeOf<allowedInputVectorTypes>:$lhs, AnyVector:$rhs,
844+
Optional<AnyVector>:$lhsMask, Optional<AnyVector>:$rhsMask,
845+
Optional<AnyVector>:$acc);
846+
let results = (outs AnyTypeOf<allowedResultVectorTypes>:$result);
847+
848+
let assemblyFormat = [{
849+
$lhs `,` $rhs
850+
oilist(
851+
`acc` `` `(` $acc `)`
852+
| `masks` `` `(` $lhsMask `,` $rhsMask `)`
853+
) attr-dict `:` type($lhs) `,` type($rhs) `into` type($result)
854+
}];
855+
856+
let extraClassDeclaration = [{
857+
VectorType getLhsType() { return llvm::cast<VectorType>(getLhs().getType()); }
858+
VectorType getRhsType() { return llvm::cast<VectorType>(getRhs().getType()); }
859+
VectorType getResultType() { return llvm::cast<VectorType>(getResult().getType()); }
860+
std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
861+
// The outerproduct op allocates a new tile if no accumulator is passed.
862+
if (!getAcc())
863+
return arm_sme::getSMETileType(getResultType());
864+
return std::nullopt;
865+
}
866+
VectorType getTileType() {
867+
return getResultType();
868+
}
869+
}];
870+
}
871+
872+
class OuterProductWide2Way<string mnemonic,
873+
list<Type> allowedInputVectorTypes,
874+
list<Type> allowedResultVectorTypes>
875+
: OuterProductWideBase<mnemonic, allowedInputVectorTypes,
876+
allowedResultVectorTypes, /*numOuterProducts=*/2>;
877+
878+
def FMopaWide2WayOp
879+
: OuterProductWide2Way<"fmopa_wide_2way",
880+
[ScalableVectorOfRankAndLengthAndType<[1], [8], [F16, BF16]>],
881+
[nxnxv4f32]> {
882+
let summary = "Floating-point sum of 2 outer products and accumulate";
883+
884+
let description = [{
885+
This operation represents a sum of 2 widened outer products. It takes 2 1-D
886+
scalable vectors as input and a 2-D scalable vector (ZA tile) as output.
887+
888+
For example (fp16 to fp32):
889+
890+
```mlir
891+
%result = arm_sme.fmopa_wide_2way %lhs, %rhs :
892+
vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
893+
```
894+
895+
The `lhs` encodes a matrix of shape SVLSx2 and the `rhs` a matrix of
896+
2xSVLS, where SVLS (spec [1], section B2.1) is the number of 32-bit
897+
elements in a vector of SVL bits. To illustrate, below is a breakdown of
898+
this operation for SVL=128 (i.e., vscale=1):
899+
900+
```
901+
LHS RHS
902+
[A0 A1 A2 A3 A4 A5 A6 A7] [B0 B1 B2 B3 B4 B5 B6 B7]
903+
904+
----------------------------------------------------------------------------
905+
906+
implicit layout
907+
908+
[A0 A1] |
909+
[A2 A3] | [B0 B2 B4 B6]
910+
[A4 A5] | [B1 B3 B5 B7]
911+
[A6 A7] |
912+
913+
----------------------------------------------------------------------------
914+
915+
2 outer products
916+
917+
Acol0 ⊗ Brow0 | Acol1 ⊗ Brow1
918+
------------- | -------------
919+
|
920+
[B0 B2 B4 B6] | [B1 B3 B5 B7]
921+
|
922+
[A0 [A0B0 A0B2 A0B4 A0B6] | [A1 [A1B1 A1B3 A1B5 A1B7]
923+
A2 [A2B0 A2B2 A2B4 A2B6] | A3 [A3B1 A3B3 A3B5 A3B7]
924+
A4 [A4B0 A4B2 A4B4 A4B6] | A5 [A5B1 A5B3 A5B5 A5B7]
925+
A6] [A6B0 A6B2 A6B4 A6B6] | A7] [A7B1 A7B3 A7B5 A7B7]
926+
|
927+
928+
----------------------------------------------------------------------------
929+
930+
sum of 2 outer products
931+
932+
Acol0 ⊗ Brow0 + Acol1 ⊗ Brow1
933+
934+
[A0B0 + A1B1 A0B2 + A1B3 A0B4 + A1B5 A0B6 + A1B7]
935+
[A2B0 + A3B1 A2B2 + A3B3 A2B4 + A3B5 A2B6 + A3B7]
936+
[A4B0 + A5B1 A4B2 + A5B3 A4B4 + A5B5 A4B6 + A5B7]
937+
[A6B0 + A7B1 A6B2 + A7B3 A6B4 + A7B5 A6B6 + A7B7]
938+
939+
----------------------------------------------------------------------------
940+
```
941+
942+
This operation enables the folding of 2 outer products chained via the
943+
accumulator into a single outer product.
944+
945+
For example:
946+
947+
```mlir
948+
%a0_ext = arith.extf %a0 : vector<[4]xf16> to vector<[4]xf32>
949+
%b0_ext = arith.extf %b0 : vector<[4]xf16> to vector<[4]xf32>
950+
%a1_ext = arith.extf %a1 : vector<[4]xf16> to vector<[4]xf32>
951+
%b1_ext = arith.extf %b1 : vector<[4]xf16> to vector<[4]xf32>
952+
953+
%0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[4]xf32>, vector<[4]xf32>
954+
%1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) : vector<[4]xf32>, vector<[4]xf32>
955+
```
956+
957+
The 2 outer products in the example above can be fused into a single outer
958+
product as follows:
959+
960+
```mlir
961+
%undef = llvm.mlir.undef : vector<[8]xf16>
962+
%a0_ins = vector.scalable.insert %a0, %undef[0] : vector<[4]xf16> into vector<[8]xf16>
963+
%a1_ins = vector.scalable.insert %a1, %undef[0] : vector<[4]xf16> into vector<[8]xf16>
964+
%a_packed = "arm_sve.intr.zip1"(%a0_ins, %a1_ins) : (vector<[8]xf16>, vector<[8]xf16>) -> vector<[8]xf16>
965+
%b0_ins = vector.scalable.insert %b0, %undef[0] : vector<[4]xf16> into vector<[8]xf16>
966+
%b1_ins = vector.scalable.insert %b1, %undef[0] : vector<[4]xf16> into vector<[8]xf16>
967+
%b_packed = "arm_sve.intr.zip1"(%b0_ins, %b1_ins) : (vector<[8]xf16>, vector<[8]xf16>) -> vector<[8]xf16>
968+
%0 = arm_sme.fmopa_wide_2way %a_packed, %b_packed : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
969+
```
970+
971+
This is implemented in the `-arm-sme-outer-product-widening` pass.
972+
973+
Example: FP16 to FP32
974+
```mlir
975+
%result = arm_sme.fmopa_wide_2way %lhs, %rhs : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
976+
```
977+
978+
Example: BF16 to FP32
979+
```mlir
980+
%result = arm_sme.fmopa_wide_2way %lhs, %rhs : vector<[8]xbf16>, vector<[8]xbf16> into vector<[4]x[4]xf32>
981+
```
982+
983+
| Spec | Features |
984+
| ---- | -------- |
985+
| [FMOPA (widening, 2-way, FP16 to FP32)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/FMOPA--widening--2-way--FP16-to-FP32---Half-precision-floating-point-sum-of-outer-products-and-accumulate-) | +sme |
986+
| [BFMOPA (widening, 2-way, BF16 to FP32)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/BFMOPA--widening---BFloat16-sum-of-outer-products-and-accumulate-) | +sme |
987+
988+
[1] https://developer.arm.com/documentation/ddi0616
989+
}];
990+
}
991+
992+
// TODO: support:
993+
// - FMOPA 2-way FP8 to FP16
994+
// - FMOPA 4-way FP8 to FP32
995+
// once intrinsic support lands in the backend.
996+
997+
def FMopsWide2WayOp
998+
: OuterProductWide2Way<"fmops_wide_2way",
999+
[ScalableVectorOfRankAndLengthAndType<[1], [8], [F16, BF16]>],
1000+
[nxnxv4f32]> {
1001+
let summary = "Floating-point sum of 2 outer products and subtract";
1002+
let description = [{
1003+
Equivalent to `fmopa_wide_2way` but outer products are subtracted from
1004+
destination `result`.
1005+
1006+
Example: FP16 to FP32
1007+
```mlir
1008+
%result = arm_sme.fmops_wide_2way %lhs, %rhs : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
1009+
```
1010+
1011+
Example: BF16 to FP32
1012+
```mlir
1013+
%result = arm_sme.fmops_wide_2way %lhs, %rhs : vector<[8]xbf16>, vector<[8]xbf16> into vector<[4]x[4]xf32>
1014+
```
1015+
Refer to
1016+
[fmopa_wide_2way](#arm_smefmopa_wide_2way-arm_smefmopa_wide_2wayop) for a
1017+
detailed description of 2-way outer products.
1018+
1019+
| Spec | Features |
1020+
| ---- | -------- |
1021+
| [FMOPS (widening, 2-way, FP16 to FP32)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/FMOPS--widening---Half-precision-floating-point-sum-of-outer-products-and-subtract-) | +sme |
1022+
| [BFMOPS (widening, 2-way, BF16 to FP32)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/BFMOPS--widening---BFloat16-sum-of-outer-products-and-subtract-) | +sme |
1023+
1024+
}];
1025+
}
1026+
1027+
def SMopaWide2WayOp
1028+
: OuterProductWide2Way<"smopa_wide_2way",
1029+
[ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
1030+
[nxnxv4i32]> {
1031+
let summary = "Signed integer sum of 2 outer products and accumulate";
1032+
let description = [{
1033+
Example:
1034+
```mlir
1035+
%result = arm_sme.smopa_wide_2way %lhs, %rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
1036+
```
1037+
Refer to
1038+
[fmopa_wide_2way](#arm_smefmopa_wide_2way-arm_smefmopa_wide_2wayop) for a
1039+
detailed description of 2-way outer products.
1040+
1041+
| Spec | Features |
1042+
| ---- | -------- |
1043+
| [SMOPA (2-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/SMOPA--2-way---Signed-integer-sum-of-outer-products-and-accumulate-) | +sme2 |
1044+
1045+
}];
1046+
}
1047+
1048+
def SMopsWide2WayOp
1049+
: OuterProductWide2Way<"smops_wide_2way",
1050+
[ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
1051+
[nxnxv4i32]> {
1052+
let summary = "Signed integer sum of 2 outer products and subtract";
1053+
let description = [{
1054+
Example:
1055+
```mlir
1056+
%result = arm_sme.smops_wide_2way %lhs, %rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
1057+
```
1058+
Refer to
1059+
[fmopa_wide_2way](#arm_smefmopa_wide_2way-arm_smefmopa_wide_2wayop) for a
1060+
detailed description of 2-way outer products.
1061+
1062+
| Spec | Features |
1063+
| ---- | -------- |
1064+
| [SMOPS (2-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/SMOPS--2-way---Signed-integer-sum-of-outer-products-and-subtract-) | +sme2 |
1065+
1066+
}];
1067+
}
1068+
1069+
def UMopaWide2WayOp
1070+
: OuterProductWide2Way<"umopa_wide_2way",
1071+
[ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
1072+
[nxnxv4i32]> {
1073+
let summary = "Unsigned integer sum of 2 outer products and accumulate";
1074+
let description = [{
1075+
Example:
1076+
```mlir
1077+
%result = arm_sme.umopa_wide_2way %lhs, %rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
1078+
```
1079+
Refer to
1080+
[fmopa_wide_2way](#arm_smefmopa_wide_2way-arm_smefmopa_wide_2wayop) for a
1081+
detailed description of 2-way outer products.
1082+
1083+
| Spec | Features |
1084+
| ---- | -------- |
1085+
| [UMOPA (2-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/UMOPA--2-way---Unsigned-integer-sum-of-outer-products-and-accumulate-) | +sme2 |
1086+
1087+
}];
1088+
}
1089+
1090+
def UMopsWide2WayOp
1091+
: OuterProductWide2Way<"umops_wide_2way",
1092+
[ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
1093+
[nxnxv4i32]> {
1094+
let summary = "Unsigned integer sum of 2 outer products and subtract";
1095+
let description = [{
1096+
Example:
1097+
```mlir
1098+
%result = arm_sme.umops_wide_2way %lhs, %rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
1099+
```
1100+
Refer to
1101+
[fmopa_wide_2way](#arm_smefmopa_wide_2way-arm_smefmopa_wide_2wayop) for a
1102+
detailed description of 2-way outer products.
1103+
1104+
| Spec | Features |
1105+
| ---- | -------- |
1106+
| [UMOPS (2-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/UMOPS--2-way---Unsigned-integer-sum-of-outer-products-and-subtract-) | +sme2 |
1107+
1108+
}];
1109+
}
1110+
8171111
def StreamingVLOp : ArmSME_Op<"streaming_vl", [Pure]>
8181112
{
8191113
let summary = "Query the streaming vector length";

mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@ std::unique_ptr<Pass> createEnableArmStreamingPass(
3232
/// Pass that allocates tile IDs to ArmSME operations.
3333
std::unique_ptr<Pass> createTileAllocationPass();
3434

35+
/// Pass that folds 'arm_sme.outerproduct' ops into widening variants.
36+
std::unique_ptr<Pass> createOuterProductWideningPass();
37+
3538
//===----------------------------------------------------------------------===//
3639
// Registration
3740
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,4 +122,43 @@ def TileAllocation
122122
let dependentDialects = ["func::FuncDialect"];
123123
}
124124

125+
def OuterProductWidening
126+
: Pass<"arm-sme-outer-product-widening", "mlir::func::FuncOp"> {
127+
let summary = "Fold 'arm_sme.outerproduct' operations into widening variants";
128+
let description = [{
129+
This pass folds 'arm_sme.outerproduct' operations that are chained via the
130+
accumulator into 2-way or 4-way ArmSME outer product operations.
131+
132+
For example:
133+
```mlir
134+
%a0_ext = arith.extf %a0 : vector<[4]xf16> to vector<[4]xf32>
135+
%b0_ext = arith.extf %b0 : vector<[4]xf16> to vector<[4]xf32>
136+
%a1_ext = arith.extf %a1 : vector<[4]xf16> to vector<[4]xf32>
137+
%b1_ext = arith.extf %b1 : vector<[4]xf16> to vector<[4]xf32>
138+
139+
%0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[4]xf32>, vector<[4]xf32>
140+
%1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) : vector<[4]xf32>, vector<[4]xf32>
141+
```
142+
143+
Becomes:
144+
145+
```mlir
146+
%undef = llvm.mlir.undef : vector<[8]xf16>
147+
%a0_ins = vector.scalable.insert %a0, %undef[0] : vector<[4]xf16> into vector<[8]xf16>
148+
%a1_ins = vector.scalable.insert %a1, %undef[0] : vector<[4]xf16> into vector<[8]xf16>
149+
%a_packed = "arm_sve.intr.zip1"(%a0_ins, %a1_ins) : (vector<[8]xf16>, vector<[8]xf16>) -> vector<[8]xf16>
150+
%b0_ins = vector.scalable.insert %b0, %undef[0] : vector<[4]xf16> into vector<[8]xf16>
151+
%b1_ins = vector.scalable.insert %b1, %undef[0] : vector<[4]xf16> into vector<[8]xf16>
152+
%b_packed = "arm_sve.intr.zip1"(%b0_ins, %b1_ins) : (vector<[8]xf16>, vector<[8]xf16>) -> vector<[8]xf16>
153+
%0 = arm_sme.fmopa_wide_2way %a_packed, %b_packed : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
154+
```
155+
156+
For further information on the widening ops see:
157+
https://mlir.llvm.org/docs/Dialects/ArmSME/#arm_smefmopa_wide_2way-arm_smefmopa_wide_2wayop
158+
https://mlir.llvm.org/docs/Dialects/ArmSME/#arm_smesmopa_wide_4way-arm_smesmopa_wide_4wayop
159+
}];
160+
let constructor = "mlir::arm_sme::createOuterProductWideningPass()";
161+
let dependentDialects = ["func::FuncDialect", "arm_sme::ArmSMEDialect", "arm_sve::ArmSVEDialect", "LLVM::LLVMDialect"];
162+
}
163+
125164
#endif // MLIR_DIALECT_ARMSME_TRANSFORMS_PASSES_TD

mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@ class LLVMConversionTarget;
1515
class LLVMTypeConverter;
1616
class RewritePatternSet;
1717

18+
namespace arm_sme {
19+
void populateOuterProductWideningPatterns(RewritePatternSet &patterns);
20+
} // namespace arm_sme
21+
1822
} // namespace mlir
1923

2024
#endif // MLIR_DIALECT_ARMSME_TRANSFORMS_H

0 commit comments

Comments
 (0)