Skip to content

[mlir][ArmSME] Support 4-way widening outer products #79288

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Feb 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
325 changes: 325 additions & 0 deletions mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1103,6 +1103,331 @@ def UMops2WayOp
}];
}

class OuterProduct4Way<string mnemonic,
list<Type> allowedInputVectorTypes,
list<Type> allowedResultVectorTypes>
: OuterProductWideningBase<mnemonic, allowedInputVectorTypes,
allowedResultVectorTypes, /*numOuterProducts=*/4>;

def SMopa4WayOp
: OuterProduct4Way<"smopa_4way",
[ScalableVectorOfRankAndLengthAndType<[1], [16], [I8]>,
ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
[nxnxv4i32, nxnxv2i64]> {
let summary = "Signed integer sum of 4 outer products and accumulate";
let description = [{
This operation represents a sum of 4 widened outer products. It takes 2 1-D
scalable vectors as input and a 2-D scalable vector (ZA tile) as output.

For example (i8 to i32):

```mlir
%result = arm_sme.smopa_4way $lhs, $rhs :
vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
```

The `lhs` encodes a matrix of shape SVLSx4 and the `rhs` a matrix of
4xSVLS, where SVLS (spec [1], section B2.1) is the number of 32-bit
elements in a vector of SVL bits. To illustrate, below is a breakdown of
this operation for i8 to i32, SVL=128 (i.e., vscale=1):

```
LHS
[A0 A1 A2 A3 A4 A5 A6 A7 A8 A9 A10 A11 A12 A15 A14 A15]

RHS
[B0 B1 B2 B3 B4 B5 B6 B7 B8 B9 B10 B11 B12 B13 B14 B15]

----------------------------------------------------------------------------

implicit layout

[A0 A1 A2 A3] | [B0 B4 B8 B12]
[A4 A5 A6 A7] | [B1 B5 B9 B13]
[A8 A9 A10 A11] | [B2 B6 B10 B14]
[A12 A13 A14 A15] | [B3 B7 B11 B15]

----------------------------------------------------------------------------

4 outer products

Acol0 ⊗ Brow0 | Acol1 ⊗ Brow1
------------- | -------------
|
[B0 B4 B8 B12] | [B1 B5 B9 B13]
|
[A0 [ A0B0 A0B4 A0B8 A0B12] | [A1 [ A1B1 A1B5 A1B9 A1B13]
A4 [ A4B0 A4B4 A4B8 A4B12] | A5 [ A5B1 A5B5 A5B9 A5B13]
A8 [ A8B0 A8B4 A8B8 A8B12] | A9 [ A9B1 A9B5 A9B9 A9B13]
A12] [A12B0 A12B4 A12B8 A12B12] | A13] [A13B1 A13B5 A13B9 A13B13]
|
Acol2 ⊗ Brow2 | Acol3 ⊗ Brow3
------------- | -------------
|
[B2, B6, B10, B14] | [B3 B7 B11 B15]
|
[A2 [ A2B2 A2B6 A2B10 A2B14] | [A3 [ A3B3 A3B7 A3B11 A3B15]
A6 [ A6B2 A6B6 A6B10 A6B14] | A7 [ A7B3 A7B7 A7B11 A7B15]
A10 [A10B2 A10B6 A10B10 A10B14] | A11 [A11B3 A11B7 A11B11 A11B15]
A14] [A14B2 A14B6 A14B10 A14B14] | A15] [A15B3 A15B7 A15B11 A15B15]
|

----------------------------------------------------------------------------

sum of 4 outer products

Acol0 ⊗ Brow0 + Acol1 ⊗ Brow1 + Acol2 ⊗ Brow2 + Acol3 ⊗ Brow3

[ A0B0 + A1B1 + A2B2 + A3B3 ... ... A0B12 + A1B13 + A2B14 + A3B15]
[ A4B0 + A5B1 + A6B2 + A7B3 ... ... A4B12 + A5B13 + A6B14 + A7B15]
[ A8B0 + A9B1 + A10B2 + A11B3 ... ... A8B12 + A9B13 + A10B14 + A11B15]
[A12B0 + A13B1 + A14B2 + A15B3 ... ... A12B12 + A13B13 + A14B14 + A15B15]

----------------------------------------------------------------------------
```

This operation enables the folding of 4 outer products chained via the
accumulator into a single outer product.

For example:

```mlir
%a0_ext = arith.extsi %a0 : vector<[4]xi8> to vector<[4]xi32>
%b0_ext = arith.extsi %b0 : vector<[4]xi8> to vector<[4]xi32>

%a1_ext = arith.extsi %a1 : vector<[4]xi8> to vector<[4]xi32>
%b1_ext = arith.extsi %b1 : vector<[4]xi8> to vector<[4]xi32>

%a2_ext = arith.extsi %a2 : vector<[4]xi8> to vector<[4]xi32>
%b2_ext = arith.extsi %b2 : vector<[4]xi8> to vector<[4]xi32>

%a3_ext = arith.extsi %a3 : vector<[4]xi8> to vector<[4]xi32>
%b3_ext = arith.extsi %b3 : vector<[4]xi8> to vector<[4]xi32>

%0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[4]xi32>, vector<[4]xi32>
%1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) : vector<[4]xi32>, vector<[4]xi32>
%2 = arm_sme.outerproduct %a2_ext, %b2_ext acc(%1) : vector<[4]xi32>, vector<[4]xi32>
%3 = arm_sme.outerproduct %a3_ext, %b3_ext acc(%2) : vector<[4]xi32>, vector<[4]xi32>
```

The 4 outer products in the example above can be fused into a single outer
product as follows:

```mlir
%lhs0 = "llvm.intr.experimental.vector.interleave2"(%a0, %a2) : (vector<[4]xi8>, vector<[4]xi8>) -> vector<[8]xi8>
%lhs1 = "llvm.intr.experimental.vector.interleave2"(%a1, %a3) : (vector<[4]xi8>, vector<[4]xi8>) -> vector<[8]xi8>
%lhs = "llvm.intr.experimental.vector.interleave2"(%lhs0, %lhs1) : (vector<[8]xi8>, vector<[8]xi8>) -> vector<[16]xi8>

%rhs0 = "llvm.intr.experimental.vector.interleave2"(%b0, %b2) : (vector<[4]xi8>, vector<[4]xi8>) -> vector<[8]xi8>
%rhs1 = "llvm.intr.experimental.vector.interleave2"(%b1, %b3) : (vector<[4]xi8>, vector<[4]xi8>) -> vector<[8]xi8>
%rhs = "llvm.intr.experimental.vector.interleave2"(%rhs0, %rhs1) : (vector<[8]xi8>, vector<[8]xi8>) -> vector<[16]xi8>

%0 = arm_sme.smopa_4way %lhs, %rhs : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
```

This is implemented in the `-arm-sme-outer-product-fusion` pass.

Example: I8 to I32
```mlir
%result = arm_sme.smopa_4way $lhs, $rhs : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
```

Example: I16 to I64
```mlir
%result = arm_sme.smopa_4way $lhs, $rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
```

| Spec | Features |
| ---- | -------- |
| [SMOPA (4-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/SMOPA--4-way---Signed-integer-sum-of-outer-products-and-accumulate-) | +sme (32-bit), +sme-i16i64 (64-bit)|
}];
}

def SMops4WayOp
: OuterProduct4Way<"smops_4way",
[ScalableVectorOfRankAndLengthAndType<[1], [16], [I8]>,
ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
[nxnxv4i32, nxnxv2i64]> {
let summary = "Signed integer sum of 4 outer products and subtract";
let description = [{
Equivalent to `smopa_4way` but outer products are subtracted from
destination `result`.

Example: I8 to I32
```mlir
%result = arm_sme.smops_4way $lhs, $rhs : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
```

Example: I16 to I64
```mlir
%result = arm_sme.smops_4way $lhs, $rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
```

Refer to [smopa_4way](#arm_smesmopa_4way-arm_smesmopa4wayop) for a
detailed description of 4-way outer products.

| Spec | Features |
| ---- | -------- |
| [SMOPS (4-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/SMOPS--4-way---Signed-integer-sum-of-outer-products-and-subtract-) | +sme (32-bit), +sme-i16i64 (64-bit)|
}];
}

def UMopa4WayOp
: OuterProduct4Way<"umopa_4way",
[ScalableVectorOfRankAndLengthAndType<[1], [16], [I8]>,
ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
[nxnxv4i32, nxnxv2i64]> {
let summary = "Unsigned integer sum of 4 outer products and accumulate";
let description = [{
Example: I8 to I32
```mlir
%result = arm_sme.umopa_4way $lhs, $rhs : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
```

Example: I16 to I64
```mlir
%result = arm_sme.umopa_4way $lhs, $rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
```

Refer to [smopa_4way](#arm_smesmopa_4way-arm_smesmopa4wayop) for a
detailed description of 4-way outer products.

| Spec | Features |
| ---- | -------- |
| [UMOPA (4-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/UMOPA--4-way---Unsigned-integer-sum-of-outer-products-and-accumulate-) | +sme (32-bit), +sme-i16i64 (64-bit)|
}];
}

def UMops4WayOp
: OuterProduct4Way<"umops_4way",
[ScalableVectorOfRankAndLengthAndType<[1], [16], [I8]>,
ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
[nxnxv4i32, nxnxv2i64]> {
let summary = "Unsigned integer sum of 4 outer products and subtract";
let description = [{
Example: I8 to I32
```mlir
%result = arm_sme.umops_4way $lhs, $rhs : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
```

Example: I16 to I64
```mlir
%result = arm_sme.umops_4way $lhs, $rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
```

Refer to [smopa_4way](#arm_smesmopa_4way-arm_smesmopa4wayop) for a
detailed description of 4-way outer products.

| Spec | Features |
| ---- | -------- |
| [UMOPS (4-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/UMOPS--4-way---Unsigned-integer-sum-of-outer-products-and-subtract-) | +sme (32-bit), +sme-i16i64 (64-bit)|
}];
}

def SuMopa4WayOp
: OuterProduct4Way<"sumopa_4way",
[ScalableVectorOfRankAndLengthAndType<[1], [16], [I8]>,
ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
[nxnxv4i32, nxnxv2i64]> {
let summary = "Signed by unsigned integer sum of 4 outer products and accumulate";
let description = [{
Example: I8 to I32
```mlir
%result = arm_sme.sumopa_4way $lhs, $rhs : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
```

Example: I16 to I64
```mlir
%result = arm_sme.sumopa_4way $lhs, $rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
```

Refer to [smopa_4way](#arm_smesmopa_4way-arm_smesmopa4wayop) for a
detailed description of 4-way outer products.

| Spec | Features |
| ---- | -------- |
| [SUMOPA (4-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/SUMOPA--Signed-by-unsigned-integer-sum-of-outer-products-and-accumulate-) | +sme (32-bit), +sme-i16i64 (64-bit)|
}];
}

def SuMops4WayOp
: OuterProduct4Way<"sumops_4way",
[ScalableVectorOfRankAndLengthAndType<[1], [16], [I8]>,
ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
[nxnxv4i32, nxnxv2i64]> {
let summary = "Signed by unsigned integer sum of 4 outer products and subtract";
let description = [{
Example: I8 to I32
```mlir
%result = arm_sme.sumops_4way $lhs, $rhs : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
```

Example: I16 to I64
```mlir
%result = arm_sme.sumops_4way $lhs, $rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
```

Refer to [smopa_4way](#arm_smesmopa_4way-arm_smesmopa4wayop) for a
detailed description of 4-way outer products.

| Spec | Features |
| ---- | -------- |
| [SUMOPS (4-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/SUMOPS--Signed-by-unsigned-integer-sum-of-outer-products-and-subtract-) | +sme (32-bit), +sme-i16i64 (64-bit)|
}];
}

def UsMopa4WayOp
: OuterProduct4Way<"usmopa_4way",
[ScalableVectorOfRankAndLengthAndType<[1], [16], [I8]>,
ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
[nxnxv4i32, nxnxv2i64]> {
let summary = "Unsigned by signed integer sum of 4 outer products and accumulate";
let description = [{
Example: I8 to I32
```mlir
%result = arm_sme.usmopa_4way $lhs, $rhs : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
```

Example: I16 to I64
```mlir
%result = arm_sme.usmopa_4way $lhs, $rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
```

Refer to [smopa_4way](#arm_smesmopa_4way-arm_smesmopa4wayop) for a
detailed description of 4-way outer products.

| Spec | Features |
| ---- | -------- |
| [USMOPA (4-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/USMOPA--Unsigned-by-signed-integer-sum-of-outer-products-and-accumulate-) | +sme (32-bit), +sme-i16i64 (64-bit)|
}];
}

def UsMops4WayOp
: OuterProduct4Way<"usmops_4way",
[ScalableVectorOfRankAndLengthAndType<[1], [16], [I8]>,
ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
[nxnxv4i32, nxnxv2i64]> {
let summary = "Unsigned by signed integer sum of 4 outer products and subtract";
let description = [{
Example: I8 to I32
```mlir
%result = arm_sme.usmops_4way $lhs, $rhs : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
```

Example: I16 to I64
```mlir
%result = arm_sme.usmops_4way $lhs, $rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[2]x[2]xi64>
```

Refer to [smopa_4way](#arm_smesmopa_4way-arm_smesmopa4wayop) for a
detailed description of 4-way outer products.

| Spec | Features |
| ---- | -------- |
| [USMOPS (4-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/USMOPS--Unsigned-by-signed-integer-sum-of-outer-products-and-subtract-) | +sme (32-bit), +sme-i16i64 (64-bit)|
}];
}

def StreamingVLOp : ArmSME_Op<"streaming_vl", [Pure]>
{
let summary = "Query the streaming vector length";
Expand Down
16 changes: 16 additions & 0 deletions mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -939,6 +939,22 @@ void mlir::populateArmSMEToLLVMConversionPatterns(LLVMTypeConverter &converter,
arm_sme::aarch64_sme_umopa_za32>,
OuterProductWideningOpConversion<arm_sme::UMops2WayOp,
arm_sme::aarch64_sme_umops_za32>,
OuterProductWideningOpConversion<arm_sme::SMopa4WayOp,
arm_sme::aarch64_sme_smopa_wide>,
OuterProductWideningOpConversion<arm_sme::SMops4WayOp,
arm_sme::aarch64_sme_smops_wide>,
OuterProductWideningOpConversion<arm_sme::UMopa4WayOp,
arm_sme::aarch64_sme_umopa_wide>,
OuterProductWideningOpConversion<arm_sme::UMops4WayOp,
arm_sme::aarch64_sme_umops_wide>,
OuterProductWideningOpConversion<arm_sme::SuMopa4WayOp,
arm_sme::aarch64_sme_sumopa_wide>,
OuterProductWideningOpConversion<arm_sme::SuMops4WayOp,
arm_sme::aarch64_sme_sumops_wide>,
OuterProductWideningOpConversion<arm_sme::UsMopa4WayOp,
arm_sme::aarch64_sme_usmopa_wide>,
OuterProductWideningOpConversion<arm_sme::UsMops4WayOp,
arm_sme::aarch64_sme_usmops_wide>,
ZeroOpConversion, GetTileConversion>(patterns, converter);
}

Expand Down
Loading