[mlir][ArmSME] Support 2-way widening outer products #78975

Merged
merged 5 commits into from
Jan 31, 2024
4 changes: 4 additions & 0 deletions mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
@@ -105,6 +105,10 @@ def LLVM_aarch64_sme_sumopa_wide : ArmSME_IntrMopOverloadedOp<"sumopa.wide">;
def LLVM_aarch64_sme_sumops_wide : ArmSME_IntrMopOverloadedOp<"sumops.wide">;
def LLVM_aarch64_sme_usmopa_wide : ArmSME_IntrMopOverloadedOp<"usmopa.wide">;
def LLVM_aarch64_sme_usmops_wide : ArmSME_IntrMopOverloadedOp<"usmops.wide">;
def LLVM_aarch64_sme_smopa_za32 : ArmSME_IntrMopOverloadedOp<"smopa.za32">;
def LLVM_aarch64_sme_umopa_za32 : ArmSME_IntrMopOverloadedOp<"umopa.za32">;
def LLVM_aarch64_sme_smops_za32 : ArmSME_IntrMopOverloadedOp<"smops.za32">;
def LLVM_aarch64_sme_umops_za32 : ArmSME_IntrMopOverloadedOp<"umops.za32">;

class ArmSME_IntrLoadStoreOp<string mnemonic>
: ArmSME_IntrOp<mnemonic,
289 changes: 289 additions & 0 deletions mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
@@ -814,6 +814,295 @@ let arguments = (ins
}];
}

class OuterProductWideningBase<string mnemonic,
list<Type> allowedInputVectorTypes,
list<Type> allowedResultVectorTypes,
int numOuterProducts> :
ArmSME_Op<mnemonic, [
ArmSMETileOpInterface,
AttrSizedOperandSegments,
AllTypesMatch<["lhs", "rhs"]>,
HasMatchingMaskTypeConstraint<"lhs", "lhsMask">,
HasMatchingMaskTypeConstraint<"rhs", "rhsMask">,
PredOpTrait<
"both `lhsMask` and `rhsMask` should be provided or neither",
CPred<"bool(getLhsMask()) == bool(getRhsMask())">
>,
OptionalTypesMatchWith<"`result` and `acc` have the same type",
"result", "acc", "::llvm::cast<Type>($_self)">,
// This trait ensures the input types match the correct output type for ops
// that take multiple inputs and outputs (i.e., 4-way).
PredOpTrait<
"tile element size equals input element size * " # numOuterProducts,
CPred<"getTileType().getElementTypeBitWidth() == "
"(getLhsType().getElementTypeBitWidth() * " # numOuterProducts # ")">
>,
]> {

let arguments = (ins
AnyTypeOf<allowedInputVectorTypes>:$lhs, AnyVector:$rhs,
Optional<AnyVector>:$lhsMask, Optional<AnyVector>:$rhsMask,
Optional<AnyVector>:$acc);
Comment on lines +844 to +845
Member

Optional<AnyVector>:$acc -> Optional<SMETile>:$acc ?
Optional<AnyVector>:$lhsMask -> Optional<SVEPredicate> ?
Optional<AnyVector>:$rhsMask -> Optional<SVEPredicate> ?
AnyVector:$rhs -> SVEVector ?

Collaborator Author

Functionally it makes no difference; this could literally be AnyType (and perhaps it should be). AFAIK there's no way of expressing with the type alone that lhs must equal rhs without using constraints, unless the op only accepted one input type, in which case it could be inputType:$lhs, inputType:$rhs.

It's a bit unfortunate, because the auto-generated op documentation will say rhs is a vector of any type values, which we know isn't true, but a vector type that matches the size of an SVE vector isn't true either.

let results = (outs AnyTypeOf<allowedResultVectorTypes>:$result);

let assemblyFormat = [{
$lhs `,` $rhs
oilist(
`acc` `` `(` $acc `)`
| `masks` `` `(` $lhsMask `,` $rhsMask `)`
) attr-dict `:` type($lhs) `,` type($rhs) `into` type($result)
}];

let extraClassDeclaration = [{
VectorType getLhsType() { return llvm::cast<VectorType>(getLhs().getType()); }
VectorType getRhsType() { return llvm::cast<VectorType>(getRhs().getType()); }
VectorType getResultType() { return llvm::cast<VectorType>(getResult().getType()); }
std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
// The outerproduct op allocates a new tile if no accumulator is passed.
if (!getAcc())
return arm_sme::getSMETileType(getResultType());
return std::nullopt;
}
VectorType getTileType() {
return getResultType();
}
}];
}

class OuterProduct2Way<string mnemonic,
list<Type> allowedInputVectorTypes,
list<Type> allowedResultVectorTypes>
: OuterProductWideningBase<mnemonic, allowedInputVectorTypes,
allowedResultVectorTypes, /*numOuterProducts=*/2>;

def FMopa2WayOp
: OuterProduct2Way<"fmopa_2way",
[ScalableVectorOfRankAndLengthAndType<[1], [8], [F16, BF16]>],
[nxnxv4f32]> {
let summary = "Floating-point sum of 2 outer products and accumulate";

let description = [{
This operation represents a sum of 2 widened outer products. It takes two 1-D
scalable vectors as input and produces a 2-D scalable vector (ZA tile) as output.

For example (fp16 to fp32):

```mlir
%result = arm_sme.fmopa_2way %lhs, %rhs :
vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
```

The `lhs` encodes a matrix of shape SVLSx2 and the `rhs` a matrix of
2xSVLS, where SVLS (spec [1], section B2.1) is the number of 32-bit
elements in a vector of SVL bits. To illustrate, below is a breakdown of
this operation for fp16 to fp32, SVL=128 (i.e., vscale=1):

```
LHS RHS
[A0 A1 A2 A3 A4 A5 A6 A7] [B0 B1 B2 B3 B4 B5 B6 B7]

----------------------------------------------------------------------------

implicit layout

[A0 A1] |
[A2 A3] | [B0 B2 B4 B6]
[A4 A5] | [B1 B3 B5 B7]
[A6 A7] |

----------------------------------------------------------------------------

2 outer products

Acol0 ⊗ Brow0 | Acol1 ⊗ Brow1
------------- | -------------
|
[B0 B2 B4 B6] | [B1 B3 B5 B7]
|
[A0 [A0B0 A0B2 A0B4 A0B6] | [A1 [A1B1 A1B3 A1B5 A1B7]
A2 [A2B0 A2B2 A2B4 A2B6] | A3 [A3B1 A3B3 A3B5 A3B7]
A4 [A4B0 A4B2 A4B4 A4B6] | A5 [A5B1 A5B3 A5B5 A5B7]
A6] [A6B0 A6B2 A6B4 A6B6] | A7] [A7B1 A7B3 A7B5 A7B7]
|

----------------------------------------------------------------------------

sum of 2 outer products

Acol0 ⊗ Brow0 + Acol1 ⊗ Brow1

[A0B0 + A1B1 A0B2 + A1B3 A0B4 + A1B5 A0B6 + A1B7]
[A2B0 + A3B1 A2B2 + A3B3 A2B4 + A3B5 A2B6 + A3B7]
[A4B0 + A5B1 A4B2 + A5B3 A4B4 + A5B5 A4B6 + A5B7]
[A6B0 + A7B1 A6B2 + A7B3 A6B4 + A7B5 A6B6 + A7B7]

----------------------------------------------------------------------------
```

This operation enables the folding of 2 outer products chained via the
accumulator into a single outer product.

For example:

```mlir
%a0_ext = arith.extf %a0 : vector<[4]xf16> to vector<[4]xf32>
%b0_ext = arith.extf %b0 : vector<[4]xf16> to vector<[4]xf32>
%a1_ext = arith.extf %a1 : vector<[4]xf16> to vector<[4]xf32>
%b1_ext = arith.extf %b1 : vector<[4]xf16> to vector<[4]xf32>

%0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[4]xf32>, vector<[4]xf32>
%1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) : vector<[4]xf32>, vector<[4]xf32>
```

The 2 outer products in the example above can be fused into a single outer
product as follows:

```mlir
%a_packed = "llvm.intr.experimental.vector.interleave2"(%a0, %a1) : (vector<[4]xf16>, vector<[4]xf16>) -> vector<[8]xf16>
%b_packed = "llvm.intr.experimental.vector.interleave2"(%b0, %b1) : (vector<[4]xf16>, vector<[4]xf16>) -> vector<[8]xf16>
%0 = arm_sme.fmopa_2way %a_packed, %b_packed : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
```

This is implemented in the `-arm-sme-outer-product-fusion` pass.

Example: FP16 to FP32
```mlir
%result = arm_sme.fmopa_2way %lhs, %rhs : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
```

Example: BF16 to FP32
```mlir
%result = arm_sme.fmopa_2way %lhs, %rhs : vector<[8]xbf16>, vector<[8]xbf16> into vector<[4]x[4]xf32>
```

| Spec | Features |
| ---- | -------- |
| [FMOPA (widening, 2-way, FP16 to FP32)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/FMOPA--widening--2-way--FP16-to-FP32---Half-precision-floating-point-sum-of-outer-products-and-accumulate-) | +sme |
| [BFMOPA (widening, 2-way, BF16 to FP32)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/BFMOPA--widening---BFloat16-sum-of-outer-products-and-accumulate-) | +sme |

[1] https://developer.arm.com/documentation/ddi0616
}];
}

// TODO: support:
// - FMOPA 2-way FP8 to FP16
// - FMOPA 4-way FP16 to FP32
// once intrinsic support lands in the backend.

def FMops2WayOp
: OuterProduct2Way<"fmops_2way",
[ScalableVectorOfRankAndLengthAndType<[1], [8], [F16, BF16]>],
[nxnxv4f32]> {
let summary = "Floating-point sum of 2 outer products and subtract";
let description = [{
Equivalent to `fmopa_2way` but outer products are subtracted from
destination `result`.

Example: FP16 to FP32
```mlir
%result = arm_sme.fmops_2way %lhs, %rhs : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
```

Example: BF16 to FP32
```mlir
%result = arm_sme.fmops_2way %lhs, %rhs : vector<[8]xbf16>, vector<[8]xbf16> into vector<[4]x[4]xf32>
```

Refer to
[fmopa_2way](#arm_smefmopa_2way-arm_smefmopa_2wayop) for a detailed
description of 2-way outer products.

| Spec | Features |
| ---- | -------- |
| [FMOPS (widening, 2-way, FP16 to FP32)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/FMOPS--widening---Half-precision-floating-point-sum-of-outer-products-and-subtract-) | +sme |
| [BFMOPS (widening, 2-way, BF16 to FP32)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/BMOPS--Bitwise-exclusive-NOR-population-count-outer-product-and-subtract-) | +sme |
}];
}

def SMopa2WayOp
: OuterProduct2Way<"smopa_2way",
[ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
[nxnxv4i32]> {
let summary = "Signed integer sum of 2 outer products and accumulate";
let description = [{
Example:
```mlir
%result = arm_sme.smopa_2way %lhs, %rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
```

Refer to
[fmopa_2way](#arm_smefmopa_2way-arm_smefmopa_2wayop) for a detailed
description of 2-way outer products.

| Spec | Features |
| ---- | -------- |
| [SMOPA (2-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/SMOPA--2-way---Signed-integer-sum-of-outer-products-and-accumulate-) | +sme2 |
}];
}

def SMops2WayOp
: OuterProduct2Way<"smops_2way",
[ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
[nxnxv4i32]> {
let summary = "Signed integer sum of 2 outer products and subtract";
let description = [{
Example:
```mlir
%result = arm_sme.smops_2way %lhs, %rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
```

Refer to
[fmopa_2way](#arm_smefmopa_2way-arm_smefmopa_2wayop) for a detailed
description of 2-way outer products.

| Spec | Features |
| ---- | -------- |
| [SMOPS (2-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/SMOPS--2-way---Signed-integer-sum-of-outer-products-and-subtract-) | +sme2 |
}];
}

def UMopa2WayOp
: OuterProduct2Way<"umopa_2way",
[ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
[nxnxv4i32]> {
let summary = "Unsigned integer sum of 2 outer products and accumulate";
let description = [{
Example:
```mlir
%result = arm_sme.umopa_2way %lhs, %rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
```

Refer to
[fmopa_2way](#arm_smefmopa_2way-arm_smefmopa_2wayop) for a detailed
description of 2-way outer products.

| Spec | Features |
| ---- | -------- |
| [UMOPA (2-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/UMOPA--2-way---Unsigned-integer-sum-of-outer-products-and-accumulate-) | +sme2 |
}];
}

def UMops2WayOp
: OuterProduct2Way<"umops_2way",
[ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
[nxnxv4i32]> {
let summary = "Unsigned integer sum of 2 outer products and subtract";
let description = [{
Example:
```mlir
%result = arm_sme.umops_2way %lhs, %rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
```

Refer to
[fmopa_2way](#arm_smefmopa_2way-arm_smefmopa_2wayop) for a detailed
description of 2-way outer products.

| Spec | Features |
| ---- | -------- |
| [UMOPS (2-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/UMOPS--2-way---Unsigned-integer-sum-of-outer-products-and-subtract-) | +sme2 |
}];
}

def StreamingVLOp : ArmSME_Op<"streaming_vl", [Pure]>
{
let summary = "Query the streaming vector length";
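
Note on the 2-way semantics documented in `fmopa_2way` above: the interleaved layout can be easier to follow as a scalar reference model. The following is a minimal sketch in plain C++, not code from this PR. The function name `referenceMopa2Way` and the `subtract` flag are hypothetical, inputs are modelled as already-widened `float` values, and the model ignores the rounding behaviour of the actual instructions.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Scalar reference model of a 2-way widening outer product (illustrative
// only; the name and signature are assumptions, not part of the dialect).
//
// `lhs` and `rhs` each hold 2 * n elements. Pairs are interleaved, so
// lhs[2*i + k] is row i, column k of the implicit n x 2 matrix, and
// rhs[2*j + k] is row k, column j of the implicit 2 x n matrix.
std::vector<std::vector<float>>
referenceMopa2Way(const std::vector<float> &lhs, const std::vector<float> &rhs,
                  std::vector<std::vector<float>> acc, bool subtract = false) {
  assert(lhs.size() == rhs.size() && lhs.size() % 2 == 0);
  const std::size_t n = lhs.size() / 2;
  assert(acc.size() == n);
  for (std::size_t i = 0; i < n; ++i) {
    assert(acc[i].size() == n);
    for (std::size_t j = 0; j < n; ++j) {
      // Sum of the two outer products: Acol0 (x) Brow0 + Acol1 (x) Brow1.
      const float sum =
          lhs[2 * i] * rhs[2 * j] + lhs[2 * i + 1] * rhs[2 * j + 1];
      // The "mopa" variants accumulate, the "mops" variants subtract.
      acc[i][j] += subtract ? -sum : sum;
    }
  }
  return acc;
}
```

With `lhs = [A0 .. A7]` and `rhs = [B0 .. B7]`, element `[0][1]` of the result is `A0*B2 + A1*B3`, matching the "sum of 2 outer products" table in the op description.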
4 changes: 4 additions & 0 deletions mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
@@ -32,6 +32,10 @@ std::unique_ptr<Pass> createEnableArmStreamingPass(
/// Pass that allocates tile IDs to ArmSME operations.
std::unique_ptr<Pass> createTileAllocationPass();

/// Pass that fuses 'arm_sme.outerproduct' ops into 2-way or 4-way widening
/// variants.
std::unique_ptr<Pass> createOuterProductFusionPass();

//===----------------------------------------------------------------------===//
// Registration
//===----------------------------------------------------------------------===//
34 changes: 34 additions & 0 deletions mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
@@ -122,4 +122,38 @@ def TileAllocation
let dependentDialects = ["func::FuncDialect"];
}

def OuterProductFusion
: Pass<"arm-sme-outer-product-fusion", "mlir::func::FuncOp"> {
let summary = "Fuse 'arm_sme.outerproduct' operations into 2-way or 4-way widening variants";
let description = [{
This pass fuses 'arm_sme.outerproduct' operations that are chained via the
accumulator into 2-way or 4-way ArmSME outer product operations.

For example:
```mlir
%a0_ext = arith.extf %a0 : vector<[4]xf16> to vector<[4]xf32>
%b0_ext = arith.extf %b0 : vector<[4]xf16> to vector<[4]xf32>
%a1_ext = arith.extf %a1 : vector<[4]xf16> to vector<[4]xf32>
%b1_ext = arith.extf %b1 : vector<[4]xf16> to vector<[4]xf32>

%0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[4]xf32>, vector<[4]xf32>
%1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) : vector<[4]xf32>, vector<[4]xf32>
```

Becomes:

```mlir
%a_packed = "llvm.intr.experimental.vector.interleave2"(%a0, %a1) : (vector<[4]xf16>, vector<[4]xf16>) -> vector<[8]xf16>
%b_packed = "llvm.intr.experimental.vector.interleave2"(%b0, %b1) : (vector<[4]xf16>, vector<[4]xf16>) -> vector<[8]xf16>
%0 = arm_sme.fmopa_2way %a_packed, %b_packed : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
```

For further information on the 2-way or 4-way widening ops see:
https://mlir.llvm.org/docs/Dialects/ArmSME/#arm_smefmopa_2way-arm_smefmopa_2wayop
https://mlir.llvm.org/docs/Dialects/ArmSME/#arm_smesmopa_4way-arm_smesmopa_4wayop
}];
let constructor = "mlir::arm_sme::createOuterProductFusionPass()";
let dependentDialects = ["func::FuncDialect", "arm_sme::ArmSMEDialect", "LLVM::LLVMDialect"];
}

#endif // MLIR_DIALECT_ARMSME_TRANSFORMS_PASSES_TD
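
Note on the fusion example in the pass description above: the rewrite relies on the two operand pairs being packed element by element before they are handed to the 2-way op. The sketch below models that in plain C++ and is not code from this PR. The helper names `interleave2` and `chainedOuterProducts` are hypothetical, `float` stands in for the extended element type, and hardware rounding is not modelled.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Models the element order of a 2-way vector interleave:
// result = [a[0], b[0], a[1], b[1], ...].
std::vector<float> interleave2(const std::vector<float> &a,
                               const std::vector<float> &b) {
  assert(a.size() == b.size());
  std::vector<float> out;
  out.reserve(2 * a.size());
  for (std::size_t i = 0; i < a.size(); ++i) {
    out.push_back(a[i]);
    out.push_back(b[i]);
  }
  return out;
}

// Models the unfused form: two outer products chained via the accumulator.
std::vector<std::vector<float>>
chainedOuterProducts(const std::vector<float> &a0, const std::vector<float> &b0,
                     const std::vector<float> &a1,
                     const std::vector<float> &b1) {
  assert(a0.size() == b0.size() && a1.size() == b1.size() &&
         a0.size() == a1.size());
  const std::size_t n = a0.size();
  std::vector<std::vector<float>> tile(n, std::vector<float>(n, 0.0f));
  // First outer product, no accumulator.
  for (std::size_t i = 0; i < n; ++i)
    for (std::size_t j = 0; j < n; ++j)
      tile[i][j] = a0[i] * b0[j];
  // Second outer product, accumulating into the result of the first.
  for (std::size_t i = 0; i < n; ++i)
    for (std::size_t j = 0; j < n; ++j)
      tile[i][j] += a1[i] * b1[j];
  return tile;
}
```

After packing, `a_packed[2*i]` holds `a0[i]` and `a_packed[2*i + 1]` holds `a1[i]` (likewise for `b_packed`), so each tile element `a0[i]*b0[j] + a1[i]*b1[j]` is the sum of products of adjacent pairs in the packed vectors, which is the shape the 2-way op consumes.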
4 changes: 4 additions & 0 deletions mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h
@@ -15,6 +15,10 @@ class LLVMConversionTarget;
class LLVMTypeConverter;
class RewritePatternSet;

namespace arm_sme {
void populateOuterProductFusionPatterns(RewritePatternSet &patterns);
} // namespace arm_sme

} // namespace mlir

#endif // MLIR_DIALECT_ARMSME_TRANSFORMS_H