-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][ArmSME] Support 2-way widening outer products #78975
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
601bbae
26f705d
589c0d2
b9e3a5b
51188d0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -814,6 +814,295 @@ let arguments = (ins | |
}]; | ||
} | ||
|
||
// Common base for the widening outer-product ops.
//
// Template arguments:
//   mnemonic                 - op mnemonic forwarded to ArmSME_Op.
//   allowedInputVectorTypes  - vector types accepted for `lhs`.
//   allowedResultVectorTypes - vector types accepted for `result`.
//   numOuterProducts         - ratio between the tile element bit width and
//                              the input element bit width (checked below).
class OuterProductWideningBase<
    string mnemonic,
    list<Type> allowedInputVectorTypes,
    list<Type> allowedResultVectorTypes,
    int numOuterProducts
> : ArmSME_Op<mnemonic, [
      ArmSMETileOpInterface,
      AttrSizedOperandSegments,
      AllTypesMatch<["lhs", "rhs"]>,
      HasMatchingMaskTypeConstraint<"lhs", "lhsMask">,
      HasMatchingMaskTypeConstraint<"rhs", "rhsMask">,
      // Masks are all-or-nothing: either both operands are masked or neither.
      PredOpTrait<
        "both `lhsMask` and `rhsMask` should be provided or neither",
        CPred<"bool(getLhsMask()) == bool(getRhsMask())">
      >,
      OptionalTypesMatchWith<
        "`result` and `acc` have the same type",
        "result", "acc", "::llvm::cast<Type>($_self)">,
      // Pairs the input element type with the matching tile element type for
      // ops that accept several input and result types (e.g., 4-way).
      PredOpTrait<
        "tile element size equals input element size * " # numOuterProducts,
        CPred<"getTileType().getElementTypeBitWidth() == "
              "(getLhsType().getElementTypeBitWidth() * " # numOuterProducts # ")">
      >,
    ]> {

  // `rhs` is only declared as `AnyVector`; the `AllTypesMatch<["lhs", "rhs"]>`
  // trait above requires it to have the same type as `lhs`.
  let arguments = (ins
    AnyTypeOf<allowedInputVectorTypes>:$lhs,
    AnyVector:$rhs,
    Optional<AnyVector>:$lhsMask,
    Optional<AnyVector>:$rhsMask,
    Optional<AnyVector>:$acc
  );

  let results = (outs AnyTypeOf<allowedResultVectorTypes>:$result);

  let assemblyFormat = [{
    $lhs `,` $rhs
    oilist(
        `acc` `` `(` $acc `)`
      | `masks` `` `(` $lhsMask `,` $rhsMask `)`
    ) attr-dict `:` type($lhs) `,` type($rhs) `into` type($result)
  }];

  let extraClassDeclaration = [{
    /// Vector type of the `lhs` operand.
    VectorType getLhsType() {
      return llvm::cast<VectorType>(getLhs().getType());
    }

    /// Vector type of the `rhs` operand.
    VectorType getRhsType() {
      return llvm::cast<VectorType>(getRhs().getType());
    }

    /// Vector type of the op result.
    VectorType getResultType() {
      return llvm::cast<VectorType>(getResult().getType());
    }

    /// Tile type allocated by this op, if any. A new tile is only allocated
    /// when no accumulator is passed.
    std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
      if (getAcc())
        return std::nullopt;
      return arm_sme::getSMETileType(getResultType());
    }

    /// The tile type is the result type.
    VectorType getTileType() { return getResultType(); }
  }];
}
|
||
// 2-way widening outer-product ops: OuterProductWideningBase with
// `numOuterProducts` fixed to 2.
class OuterProduct2Way<
    string mnemonic,
    list<Type> allowedInputVectorTypes,
    list<Type> allowedResultVectorTypes
> : OuterProductWideningBase<
      mnemonic,
      allowedInputVectorTypes,
      allowedResultVectorTypes,
      /*numOuterProducts=*/2
    >;
|
||
// 2-way widening floating-point outer product (accumulating form):
// f16/bf16 vector<[8]> inputs, f32 vector<[4]x[4]> tile.
def FMopa2WayOp
  : OuterProduct2Way<"fmopa_2way",
      [ScalableVectorOfRankAndLengthAndType<[1], [8], [F16, BF16]>],
      [nxnxv4f32]> {
  let summary = "Floating-point sum of 2 outer products and accumulate";

  let description = [{
    This operation represents a sum of 2 widened outer products. It takes 2 1-D
    scalable vectors as input and a 2-D scalable vector (ZA tile) as output.

    For example (fp16 to fp32):

    ```mlir
    %result = arm_sme.fmopa_2way %lhs, %rhs :
      vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
    ```

    The `lhs` encodes a matrix of shape SVLSx2 and the `rhs` a matrix of
    2xSVLS, where SVLS (spec [1], section B2.1) is the number of 32-bit
    elements in a vector of SVL bits. To illustrate, below is a breakdown of
    this operation for fp16 to fp32, SVL=128 (i.e., vscale=1):

    ```
                     LHS                                    RHS
          [A0 A1 A2 A3 A4 A5 A6 A7]              [B0 B1 B2 B3 B4 B5 B6 B7]

    ----------------------------------------------------------------------------

                                  implicit layout

                   [A0 A1]                |
                   [A2 A3]                |    [B0 B2 B4 B6]
                   [A4 A5]                |    [B1 B3 B5 B7]
                   [A6 A7]                |

    ----------------------------------------------------------------------------

                                  2 outer products

                Acol0 ⊗ Brow0             |           Acol1 ⊗ Brow1
                -------------             |           -------------
                                          |
                [B0 B2 B4 B6]             |           [B1 B3 B5 B7]
                                          |
       [A0  [A0B0 A0B2 A0B4 A0B6]         |  [A1  [A1B1 A1B3 A1B5 A1B7]
        A2  [A2B0 A2B2 A2B4 A2B6]         |   A3  [A3B1 A3B3 A3B5 A3B7]
        A4  [A4B0 A4B2 A4B4 A4B6]         |   A5  [A5B1 A5B3 A5B5 A5B7]
        A6] [A6B0 A6B2 A6B4 A6B6]         |   A7] [A7B1 A7B3 A7B5 A7B7]
                                          |

    ----------------------------------------------------------------------------

                              sum of 2 outer products

                           Acol0 ⊗ Brow0 + Acol1 ⊗ Brow1

                 [A0B0 + A1B1 A0B2 + A1B3 A0B4 + A1B5 A0B6 + A1B7]
                 [A2B0 + A3B1 A2B2 + A3B3 A2B4 + A3B5 A2B6 + A3B7]
                 [A4B0 + A5B1 A4B2 + A5B3 A4B4 + A5B5 A4B6 + A5B7]
                 [A6B0 + A7B1 A6B2 + A7B3 A6B4 + A7B5 A6B6 + A7B7]

    ----------------------------------------------------------------------------
    ```

    This operation enables the folding of 2 outer products chained via the
    accumulator into a single outer product.

    For example:

    ```mlir
    %a0_ext = arith.extf %a0 : vector<[4]xf16> to vector<[4]xf32>
    %b0_ext = arith.extf %b0 : vector<[4]xf16> to vector<[4]xf32>
    %a1_ext = arith.extf %a1 : vector<[4]xf16> to vector<[4]xf32>
    %b1_ext = arith.extf %b1 : vector<[4]xf16> to vector<[4]xf32>

    %0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[4]xf32>, vector<[4]xf32>
    %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) : vector<[4]xf32>, vector<[4]xf32>
    ```

    The 2 outer products in the example above can be fused into a single outer
    product as follows:

    ```mlir
    %a_packed = "llvm.intr.experimental.vector.interleave2"(%a0, %a1) : (vector<[4]xf16>, vector<[4]xf16>) -> vector<[8]xf16>
    %b_packed = "llvm.intr.experimental.vector.interleave2"(%b0, %b1) : (vector<[4]xf16>, vector<[4]xf16>) -> vector<[8]xf16>
    %0 = arm_sme.fmopa_2way %a_packed, %b_packed : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
    ```

    This is implemented in the `-arm-sme-outer-product-fusion` pass.

    Example: FP16 to FP32
    ```mlir
    %result = arm_sme.fmopa_2way %lhs, %rhs : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
    ```

    Example: BF16 to FP32
    ```mlir
    %result = arm_sme.fmopa_2way %lhs, %rhs : vector<[8]xbf16>, vector<[8]xbf16> into vector<[4]x[4]xf32>
    ```

    | Spec | Features |
    | ---- | -------- |
    | [FMOPA (widening, 2-way, FP16 to FP32)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/FMOPA--widening--2-way--FP16-to-FP32---Half-precision-floating-point-sum-of-outer-products-and-accumulate-) | +sme |
    | [BFMOPA (widening, 2-way, BF16 to FP32)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/BFMOPA--widening---BFloat16-sum-of-outer-products-and-accumulate-) | +sme |

    [1] https://developer.arm.com/documentation/ddi0616
  }];
}
|
||
// TODO: support: | ||
// - FMOPA 2-way FP8 to FP16 | ||
// - FMOPA 4-way FP16 to FP32 | ||
// once intrinsic support lands in the backend. | ||
|
||
// 2-way widening floating-point outer product (subtracting form):
// f16/bf16 vector<[8]> inputs, f32 vector<[4]x[4]> tile.
def FMops2WayOp
  : OuterProduct2Way<"fmops_2way",
      [ScalableVectorOfRankAndLengthAndType<[1], [8], [F16, BF16]>],
      [nxnxv4f32]> {
  let summary = "Floating-point sum of 2 outer products and subtract";
  let description = [{
    Equivalent to `fmopa_2way` but outer products are subtracted from
    destination `result`.

    Example: FP16 to FP32
    ```mlir
    %result = arm_sme.fmops_2way %lhs, %rhs : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
    ```

    Example: BF16 to FP32
    ```mlir
    %result = arm_sme.fmops_2way %lhs, %rhs : vector<[8]xbf16>, vector<[8]xbf16> into vector<[4]x[4]xf32>
    ```

    Refer to
    [fmopa_2way](#arm_smefmopa_2way-arm_smefmopa_2wayop) for a detailed
    description of 2-way outer products.

    | Spec | Features |
    | ---- | -------- |
    | [FMOPS (widening, 2-way, FP16 to FP32)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/FMOPS--widening---Half-precision-floating-point-sum-of-outer-products-and-subtract-) | +sme |
    | [BFMOPS (widening, 2-way, BF16 to FP32)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/BFMOPS--widening---BFloat16-sum-of-outer-products-and-subtract-) | +sme |
  }];
}
|
||
// 2-way widening signed integer outer product (accumulating form):
// i16 vector<[8]> inputs, i32 vector<[4]x[4]> tile.
def SMopa2WayOp
  : OuterProduct2Way<"smopa_2way",
      [ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
      [nxnxv4i32]> {
  let summary = "Signed integer sum of 2 outer products and accumulate";
  let description = [{
    Example:
    ```mlir
    %result = arm_sme.smopa_2way %lhs, %rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
    ```

    Refer to
    [fmopa_2way](#arm_smefmopa_2way-arm_smefmopa_2wayop) for a detailed
    description of 2-way outer products.

    | Spec | Features |
    | ---- | -------- |
    | [SMOPA (2-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/SMOPA--2-way---Signed-integer-sum-of-outer-products-and-accumulate-) | +sme2 |
  }];
}
|
||
// 2-way widening signed integer outer product (subtracting form):
// i16 vector<[8]> inputs, i32 vector<[4]x[4]> tile.
def SMops2WayOp
  : OuterProduct2Way<"smops_2way",
      [ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
      [nxnxv4i32]> {
  let summary = "Signed integer sum of 2 outer products and subtract";
  let description = [{
    Example:
    ```mlir
    %result = arm_sme.smops_2way %lhs, %rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
    ```

    Refer to
    [fmopa_2way](#arm_smefmopa_2way-arm_smefmopa_2wayop) for a detailed
    description of 2-way outer products.

    | Spec | Features |
    | ---- | -------- |
    | [SMOPS (2-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/SMOPS--2-way---Signed-integer-sum-of-outer-products-and-subtract-) | +sme2 |
  }];
}
|
||
// 2-way widening unsigned integer outer product (accumulating form):
// i16 vector<[8]> inputs, i32 vector<[4]x[4]> tile.
def UMopa2WayOp
  : OuterProduct2Way<"umopa_2way",
      [ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
      [nxnxv4i32]> {
  let summary = "Unsigned integer sum of 2 outer products and accumulate";
  let description = [{
    Example:
    ```mlir
    %result = arm_sme.umopa_2way %lhs, %rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
    ```

    Refer to
    [fmopa_2way](#arm_smefmopa_2way-arm_smefmopa_2wayop) for a detailed
    description of 2-way outer products.

    | Spec | Features |
    | ---- | -------- |
    | [UMOPA (2-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/UMOPA--2-way---Unsigned-integer-sum-of-outer-products-and-accumulate-) | +sme2 |
  }];
}
|
||
// 2-way widening unsigned integer outer product (subtracting form):
// i16 vector<[8]> inputs, i32 vector<[4]x[4]> tile.
def UMops2WayOp
  : OuterProduct2Way<"umops_2way",
      [ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
      [nxnxv4i32]> {
  let summary = "Unsigned integer sum of 2 outer products and subtract";
  let description = [{
    Example:
    ```mlir
    %result = arm_sme.umops_2way %lhs, %rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
    ```

    Refer to
    [fmopa_2way](#arm_smefmopa_2way-arm_smefmopa_2wayop) for a detailed
    description of 2-way outer products.

    | Spec | Features |
    | ---- | -------- |
    | [UMOPS (2-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/UMOPS--2-way---Unsigned-integer-sum-of-outer-products-and-subtract-) | +sme2 |
  }];
}
|
||
def StreamingVLOp : ArmSME_Op<"streaming_vl", [Pure]> | ||
{ | ||
let summary = "Query the streaming vector length"; | ||
|
Uh oh!
There was an error while loading. Please reload this page.