[mlir][ArmSME] Support 4-way widening outer products #79288


Merged
merged 8 commits into llvm:main from mlir-sme-4way-outerproducts on Feb 7, 2024

Conversation

c-rhodes
Collaborator

This patch introduces support for 4-way widening outer products. This enables
folding 4 'arm_sme.outerproduct' operations that are chained via the
accumulator into a single widened operation.

Changes:

  • Adds the following operations:
    • smopa_wide_4way, smops_wide_4way
    • umopa_wide_4way, umops_wide_4way
    • sumopa_wide_4way, sumops_wide_4way
    • usmopa_wide_4way, usmops_wide_4way
  • Implements conversions for the above ops to intrinsics in ArmSMEToLLVM.
  • Extends the 'arm-sme-outer-product-widening' pass.

For a detailed description of these operations see the
'arm_sme.smopa_wide_4way' description.
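
As a condensed illustration (a sketch of the i8-to-i32 case, mirroring the examples in that description; the packing of `%lhs`/`%rhs` via SVE zips is elided here), the pass rewrites a chain such as:

```mlir
%a0_ext = arith.extsi %a0 : vector<[4]xi8> to vector<[4]xi32>
%b0_ext = arith.extsi %b0 : vector<[4]xi8> to vector<[4]xi32>
// ... same extends for %a1/%b1, %a2/%b2 and %a3/%b3 ...

%0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[4]xi32>, vector<[4]xi32>
%1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) : vector<[4]xi32>, vector<[4]xi32>
%2 = arm_sme.outerproduct %a2_ext, %b2_ext acc(%1) : vector<[4]xi32>, vector<[4]xi32>
%3 = arm_sme.outerproduct %a3_ext, %b3_ext acc(%2) : vector<[4]xi32>, vector<[4]xi32>

// ... into a single widened op on the interleaved i8 inputs:
%result = arm_sme.smopa_wide_4way %lhs, %rhs : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
```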

@llvmbot
Member

llvmbot commented Jan 24, 2024

@llvm/pr-subscribers-mlir-sve
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-llvm

@llvm/pr-subscribers-mlir-sme

Author: Cullen Rhodes (c-rhodes)



Patch is 152.33 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/79288.diff

17 Files Affected:

  • (modified) mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td (+4)
  • (modified) mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td (+643)
  • (modified) mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h (+3)
  • (modified) mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td (+39)
  • (modified) mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h (+4)
  • (modified) mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td (+4)
  • (modified) mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp (+80-2)
  • (modified) mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt (+2)
  • (added) mlir/lib/Dialect/ArmSME/Transforms/OuterProductWidening.cpp (+501)
  • (modified) mlir/test/Conversion/ArmSMEToLLVM/arm-sme-to-llvm.mlir (+272)
  • (modified) mlir/test/Dialect/ArmSME/invalid.mlir (+66)
  • (added) mlir/test/Dialect/ArmSME/outer-product-widening.mlir (+785)
  • (modified) mlir/test/Dialect/ArmSME/roundtrip.mlir (+272)
  • (added) mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-f16f16f32.mlir (+100)
  • (added) mlir/test/Integration/Dialect/Vector/CPU/ArmSME/test-outerproduct-i8i8i32.mlir (+142)
  • (modified) mlir/test/Target/LLVMIR/arm-sme.mlir (+12)
  • (modified) mlir/test/Target/LLVMIR/arm-sve.mlir (+7)
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
index d85ef963ae5dc46..f051e03efbcda64 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEIntrinsicOps.td
@@ -105,6 +105,10 @@ def LLVM_aarch64_sme_sumopa_wide : ArmSME_IntrMopOverloadedOp<"sumopa.wide">;
 def LLVM_aarch64_sme_sumops_wide : ArmSME_IntrMopOverloadedOp<"sumops.wide">;
 def LLVM_aarch64_sme_usmopa_wide : ArmSME_IntrMopOverloadedOp<"usmopa.wide">;
 def LLVM_aarch64_sme_usmops_wide : ArmSME_IntrMopOverloadedOp<"usmops.wide">;
+def LLVM_aarch64_sme_smopa_za32 : ArmSME_IntrMopOverloadedOp<"smopa.za32">;
+def LLVM_aarch64_sme_umopa_za32 : ArmSME_IntrMopOverloadedOp<"umopa.za32">;
+def LLVM_aarch64_sme_smops_za32 : ArmSME_IntrMopOverloadedOp<"smops.za32">;
+def LLVM_aarch64_sme_umops_za32 : ArmSME_IntrMopOverloadedOp<"umops.za32">;
 
 class ArmSME_IntrLoadStoreOp<string mnemonic>
     : ArmSME_IntrOp<mnemonic,
diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
index 8a34ad7e52012fe..ed8b100eadf3ab2 100644
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSMEOps.td
@@ -814,6 +814,649 @@ let arguments = (ins
   }];
 }
 
+class OuterProductWideBase<string mnemonic,
+                           list<Type> allowedInputVectorTypes,
+                           list<Type> allowedResultVectorTypes,
+                           int numOuterProducts> :
+  ArmSME_Op<mnemonic, [
+    ArmSMETileOpInterface,
+    AttrSizedOperandSegments,
+    AllTypesMatch<["lhs", "rhs"]>,
+    HasMatchingMaskTypeConstraint<"lhs", "lhsMask">,
+    HasMatchingMaskTypeConstraint<"rhs", "rhsMask">,
+    PredOpTrait<
+      "both `lhsMask` and `rhsMask` should be provided or neither",
+      CPred<"bool(getLhsMask()) == bool(getRhsMask())">
+    >,
+    OptionalTypesMatchWith<"result and acc have the same type",
+                           "result", "acc", "::llvm::cast<Type>($_self)">,
+    // This trait ensures the input type matches the correct output type for ops
+    // that take multiple inputs and outputs (i.e., 4-way).
+    PredOpTrait<
+      "tile element size equals lhs element size * " # numOuterProducts,
+      CPred<"getTileType().getElementTypeBitWidth() == "
+            "(getLhsType().getElementTypeBitWidth() * " # numOuterProducts # ")">
+    >,
+  ]> {
+
+  let arguments = (ins
+    AnyTypeOf<allowedInputVectorTypes>:$lhs, AnyVector:$rhs,
+    Optional<AnyVector>:$lhsMask, Optional<AnyVector>:$rhsMask,
+    Optional<AnyVector>:$acc);
+  let results = (outs AnyTypeOf<allowedResultVectorTypes>:$result);
+
+  let assemblyFormat = [{
+    $lhs `,` $rhs
+    oilist(
+        `acc` `` `(` $acc `)`
+      | `masks` `` `(` $lhsMask `,` $rhsMask `)`
+    ) attr-dict `:` type($lhs) `,` type($rhs) `into` type($result)
+  }];
+
+  let extraClassDeclaration = [{
+    VectorType getLhsType() { return llvm::cast<VectorType>(getLhs().getType()); }
+    VectorType getRhsType() { return llvm::cast<VectorType>(getRhs().getType()); }
+    VectorType getResultType() { return llvm::cast<VectorType>(getResult().getType()); }
+    std::optional<arm_sme::ArmSMETileType> getAllocatedTileType() {
+      // The outerproduct op allocates a new tile if no accumulator is passed.
+      if (!getAcc())
+        return arm_sme::getSMETileType(getResultType());
+      return std::nullopt;
+    }
+    VectorType getTileType() {
+      return getResultType();
+    }
+  }];
+}
+
+class OuterProductWide2Way<string mnemonic,
+                           list<Type> allowedInputVectorTypes,
+                           list<Type> allowedResultVectorTypes>
+  : OuterProductWideBase<mnemonic, allowedInputVectorTypes,
+                         allowedResultVectorTypes, /*numOuterProducts=*/2>;
+
+class OuterProductWide4Way<string mnemonic,
+                           list<Type> allowedInputVectorTypes,
+                           list<Type> allowedResultVectorTypes>
+  : OuterProductWideBase<mnemonic, allowedInputVectorTypes,
+                         allowedResultVectorTypes, /*numOuterProducts=*/4>;
+
+def FMopaWide2WayOp
+  : OuterProductWide2Way<"fmopa_wide_2way",
+      [ScalableVectorOfRankAndLengthAndType<[1], [8], [F16, BF16]>],
+      [nxnxv4f32]> {
+  let summary = "Floating-point sum of 2 outer products and accumulate";
+
+  let description = [{
+    This operation represents a sum of 2 widened outer products. It takes 2 1-D
+    scalable vectors as input and a 2-D scalable vector (ZA tile) as output.
+
+    For example (fp16 to fp32):
+
+    ```mlir
+    %result = arm_sme.fmopa_wide_2way %lhs, %rhs :
+      vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
+    ```
+
+    The `lhs` encodes a matrix of shape SVLSx2 and the `rhs` a matrix of
+    2xSVLS, where SVLS (spec [1], section B2.1) is the number of 32-bit
+    elements in a vector of SVL bits. To illustrate, below is a breakdown of
+    this operation for SVL=128 (i.e., vscale=1):
+
+    ```
+                          LHS                          RHS
+               [A0 A1 A2 A3 A4 A5 A6 A7]    [B0 B1 B2 B3 B4 B5 B6 B7]
+
+    ----------------------------------------------------------------------------
+
+                                  implicit layout
+
+                              [A0 A1]    |
+                              [A2 A3]    |    [B0 B2 B4 B6]
+                              [A4 A5]    |    [B1 B3 B5 B7]
+                              [A6 A7]    |
+
+    ----------------------------------------------------------------------------
+
+                                  2 outer products
+
+                      Acol0 ⊗ Brow0      |           Acol1 ⊗ Brow1
+                      -------------      |           -------------
+                                         |
+                  [B0 B2 B4 B6]          |       [B1 B3 B5 B7]
+                                         |
+             [A0  [A0B0 A0B2 A0B4 A0B6]  |  [A1  [A1B1 A1B3 A1B5 A1B7]
+              A2  [A2B0 A2B2 A2B4 A2B6]  |   A3  [A3B1 A3B3 A3B5 A3B7]
+              A4  [A4B0 A4B2 A4B4 A4B6]  |   A5  [A5B1 A5B3 A5B5 A5B7]
+              A6] [A6B0 A6B2 A6B4 A6B6]  |   A7] [A7B1 A7B3 A7B5 A7B7]
+                                         |
+
+    ----------------------------------------------------------------------------
+
+                              sum of 2 outer products
+
+                           Acol0 ⊗ Brow0 + Acol1 ⊗ Brow1
+
+                 [A0B0 + A1B1 A0B2 + A1B3 A0B4 + A1B5 A0B6 + A1B7]
+                 [A2B0 + A3B1 A2B2 + A3B3 A2B4 + A3B5 A2B6 + A3B7]
+                 [A4B0 + A5B1 A4B2 + A5B3 A4B4 + A5B5 A4B6 + A5B7]
+                 [A6B0 + A7B1 A6B2 + A7B3 A6B4 + A7B5 A6B6 + A7B7]
+
+    ----------------------------------------------------------------------------
+    ```
+
+    This operation enables the folding of 2 outer products chained via the
+    accumulator into a single outer product.
+
+    For example:
+
+    ```mlir
+    %a0_ext = arith.extf %a0 : vector<[4]xf16> to vector<[4]xf32>
+    %b0_ext = arith.extf %b0 : vector<[4]xf16> to vector<[4]xf32>
+    %a1_ext = arith.extf %a1 : vector<[4]xf16> to vector<[4]xf32>
+    %b1_ext = arith.extf %b1 : vector<[4]xf16> to vector<[4]xf32>
+
+    %0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[4]xf32>, vector<[4]xf32>
+    %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) : vector<[4]xf32>, vector<[4]xf32>
+    ```
+
+    The 2 outer products in the example above can be fused into a single outer
+    product as follows:
+
+    ```mlir
+    %undef = llvm.mlir.undef : vector<[8]xf16>
+    %a0_ins = vector.scalable.insert %a0, %undef[0] : vector<[4]xf16> into vector<[8]xf16>
+    %a1_ins = vector.scalable.insert %a1, %undef[0] : vector<[4]xf16> into vector<[8]xf16>
+    %a_packed = "arm_sve.intr.zip1"(%a0_ins, %a1_ins) : (vector<[8]xf16>, vector<[8]xf16>) -> vector<[8]xf16>
+    %b0_ins = vector.scalable.insert %b0, %undef[0] : vector<[4]xf16> into vector<[8]xf16>
+    %b1_ins = vector.scalable.insert %b1, %undef[0] : vector<[4]xf16> into vector<[8]xf16>
+    %b_packed = "arm_sve.intr.zip1"(%b0_ins, %b1_ins) : (vector<[8]xf16>, vector<[8]xf16>) -> vector<[8]xf16>
+    %0 = arm_sme.fmopa_wide_2way %a_packed, %b_packed : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
+    ```
+
+    This is implemented in the `-arm-sme-outer-product-widening` pass.
+
+    Example: FP16 to FP32
+    ```mlir
+    %result = arm_sme.fmopa_wide_2way $lhs, $rhs : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
+    ```
+
+    Example: BF16 to FP32
+    ```mlir
+    %result = arm_sme.fmopa_wide_2way $lhs, $rhs : vector<[8]xbf16>, vector<[8]xbf16> into vector<[4]x[4]xf32>
+    ```
+
+    | Spec | Features |
+    | ---- | -------- |
+    | [FMOPA (widening, 2-way, FP16 to FP32)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/FMOPA--widening--2-way--FP16-to-FP32---Half-precision-floating-point-sum-of-outer-products-and-accumulate-) | +sme |
+    | [BFMOPA (widening, 2-way, BF16 to FP32)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/BFMOPA--widening---BFloat16-sum-of-outer-products-and-accumulate-) | +sme |
+
+    [1] https://developer.arm.com/documentation/ddi0616
+  }];
+}
+
+// TODO: support:
+// - FMOPA 2-way FP8 to FP16
+// - FMOPA 4-way FP16 to FP32
+// once intrinsic support lands in the backend.
+
+def FMopsWide2WayOp
+  : OuterProductWide2Way<"fmops_wide_2way",
+      [ScalableVectorOfRankAndLengthAndType<[1], [8], [F16, BF16]>],
+      [nxnxv4f32]> {
+  let summary = "Floating-point sum of 2 outer products and subtract";
+  let description = [{
+    Equivalent to `fmopa_wide_2way` but outer products are subtracted from
+    destination `result`.
+
+    Example: FP16 to FP32
+    ```mlir
+    %result = arm_sme.fmops_wide_2way $lhs, $rhs : vector<[8]xf16>, vector<[8]xf16> into vector<[4]x[4]xf32>
+    ```
+
+    Example: BF16 to FP32
+    ```mlir
+    %result = arm_sme.fmops_wide_2way $lhs, $rhs : vector<[8]xbf16>, vector<[8]xbf16> into vector<[4]x[4]xf32>
+    ```
+
+    Refer to
+    [fmopa_wide_2way](#arm_smefmopa_wide_2way-arm_smefmopa_wide_2wayop) for a
+    detailed description of 2-way outer products.
+
+    | Spec | Features |
+    | ---- | -------- |
+    | [FMOPS (widening, 2-way, FP16 to FP32)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/FMOPS--widening---Half-precision-floating-point-sum-of-outer-products-and-subtract-) | +sme |
+    | [BFMOPS (widening, 2-way, BF16 to FP32)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/BMOPS--Bitwise-exclusive-NOR-population-count-outer-product-and-subtract-) | +sme |
+  }];
+}
+
+def SMopaWide2WayOp
+  : OuterProductWide2Way<"smopa_wide_2way",
+      [ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
+      [nxnxv4i32]> {
+  let summary = "Signed integer sum of 2 outer products and accumulate";
+  let description = [{
+    Example:
+    ```mlir
+    %result = arm_sme.smopa_wide_2way $lhs, $rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
+    ```
+
+    Refer to
+    [fmopa_wide_2way](#arm_smefmopa_wide_2way-arm_smefmopa_wide_2wayop) for a
+    detailed description of 2-way outer products.
+
+    | Spec | Features |
+    | ---- | -------- |
+    | [SMOPA (2-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/SMOPA--2-way---Signed-integer-sum-of-outer-products-and-accumulate-) | +sme2 |
+  }];
+}
+
+def SMopsWide2WayOp
+  : OuterProductWide2Way<"smops_wide_2way",
+      [ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
+      [nxnxv4i32]> {
+  let summary = "Signed integer sum of 2 outer products and subtract";
+  let description = [{
+    Example:
+    ```mlir
+    %result = arm_sme.smops_wide_2way $lhs, $rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
+    ```
+
+    Refer to
+    [fmopa_wide_2way](#arm_smefmopa_wide_2way-arm_smefmopa_wide_2wayop) for a
+    detailed description of 2-way outer products.
+
+    | Spec | Features |
+    | ---- | -------- |
+    | [SMOPS (2-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/SMOPS--2-way---Signed-integer-sum-of-outer-products-and-subtract-) | +sme2 |
+  }];
+}
+
+def UMopaWide2WayOp
+  : OuterProductWide2Way<"umopa_wide_2way",
+      [ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
+      [nxnxv4i32]> {
+  let summary = "Unsigned integer sum of 2 outer products and accumulate";
+  let description = [{
+    Example:
+    ```mlir
+    %result = arm_sme.umopa_wide_2way $lhs, $rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
+    ```
+
+    Refer to
+    [fmopa_wide_2way](#arm_smefmopa_wide_2way-arm_smefmopa_wide_2wayop) for a
+    detailed description of 2-way outer products.
+
+    | Spec | Features |
+    | ---- | -------- |
+    | [UMOPA (2-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/UMOPA--2-way---Unsigned-integer-sum-of-outer-products-and-accumulate-) | +sme2 |
+  }];
+}
+
+def UMopsWide2WayOp
+  : OuterProductWide2Way<"umops_wide_2way",
+      [ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
+      [nxnxv4i32]> {
+  let summary = "Unsigned integer sum of 2 outer products and subtract";
+  let description = [{
+    Example:
+    ```mlir
+    %result = arm_sme.umops_wide_2way $lhs, $rhs : vector<[8]xi16>, vector<[8]xi16> into vector<[4]x[4]xi32>
+    ```
+
+    Refer to
+    [fmopa_wide_2way](#arm_smefmopa_wide_2way-arm_smefmopa_wide_2wayop) for a
+    detailed description of 2-way outer products.
+
+    | Spec | Features |
+    | ---- | -------- |
+    | [UMOPS (2-way)](https://developer.arm.com/documentation/ddi0602/2023-09/SME-Instructions/UMOPS--2-way---Unsigned-integer-sum-of-outer-products-and-subtract-) | +sme2 |
+  }];
+}
+
+def SMopaWide4WayOp
+  : OuterProductWide4Way<"smopa_wide_4way",
+      [ScalableVectorOfRankAndLengthAndType<[1], [16], [I8]>,
+       ScalableVectorOfRankAndLengthAndType<[1], [8], [I16]>],
+      [nxnxv4i32, nxnxv2i64]> {
+  let summary = "Signed integer sum of 4 outer products and accumulate";
+  let description = [{
+    This operation represents a sum of 4 widened outer products. It takes 2 1-D
+    scalable vectors as input and a 2-D scalable vector (ZA tile) as output.
+
+    For example (i8 to i32):
+
+    ```mlir
+    %result = arm_sme.smopa_wide_4way $lhs, $rhs :
+      vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+    ```
+
+    The `lhs` encodes a matrix of shape SVLSx4 and the `rhs` a matrix of
+    4xSVLS, where SVLS (spec [1], section B2.1) is the number of 32-bit
+    elements in a vector of SVL bits. To illustrate, below is a breakdown of
+    this operation for SVL=128 (i.e., vscale=1):
+
+    ```
+                                        LHS
+              [A0 A1 A2 A3 A4 A5 A6 A7 A8 A9 A10 A11 A12 A13 A14 A15]
+
+                                        RHS
+              [B0 B1 B2 B3 B4 B5 B6 B7 B8 B9 B10 B11 B12 B13 B14 B15]
+
+    ----------------------------------------------------------------------------
+
+                                  implicit layout
+
+                    [A0   A1  A2  A3]    |    [B0 B4  B8 B12]
+                    [A4   A5  A6  A7]    |    [B1 B5  B9 B13]
+                    [A8   A9 A10 A11]    |    [B2 B6 B10 B14]
+                    [A12 A13 A14 A15]    |    [B3 B7 B11 B15]
+
+    ----------------------------------------------------------------------------
+
+                                  4 outer products
+
+                 Acol0 ⊗ Brow0           |            Acol1 ⊗ Brow1
+                 -------------           |            -------------
+                                         |
+             [B0 B4 B8 B12]              |        [B1 B5 B9 B13]
+                                         |
+       [A0   [ A0B0  A0B4  A0B8  A0B12]  |  [A1   [ A1B1  A1B5  A1B9  A1B13]
+        A4   [ A4B0  A4B4  A4B8  A4B12]  |   A5   [ A5B1  A5B5  A5B9  A5B13]
+        A8   [ A8B0  A8B4  A8B8  A8B12]  |   A9   [ A9B1  A9B5  A9B9  A9B13]
+        A12] [A12B0 A12B4 A12B8 A12B12]  |   A13] [A13B1 A13B5 A13B9 A13B13]
+                                         |
+                 Acol2 ⊗ Brow2           |            Acol3 ⊗ Brow3
+                 -------------           |            -------------
+                                         |
+             [B2, B6, B10, B14]          |        [B3 B7 B11 B15]
+                                         |
+       [A2   [ A2B2  A2B6  A2B10  A2B14] |  [A3   [ A3B3  A3B7  A3B11  A3B15]
+        A6   [ A6B2  A6B6  A6B10  A6B14] |   A7   [ A7B3  A7B7  A7B11  A7B15]
+        A10  [A10B2 A10B6 A10B10 A10B14] |   A11  [A11B3 A11B7 A11B11 A11B15]
+        A14] [A14B2 A14B6 A14B10 A14B14] |   A15] [A15B3 A15B7 A15B11 A15B15]
+                                         |
+
+    ----------------------------------------------------------------------------
+
+                              sum of 4 outer products
+
+           Acol0 ⊗ Brow0 + Acol1 ⊗ Brow1 + Acol2 ⊗ Brow2 + Acol3 ⊗ Brow3
+
+     [ A0B0 +  A1B1 +  A2B2 +  A3B3 ... ...  A0B12 +  A1B13 +  A2B14 +  A3B15]
+     [ A4B0 +  A5B1 +  A6B2 +  A7B3 ... ...  A4B12 +  A5B13 +  A6B14 +  A7B15]
+     [ A8B0 +  A9B1 + A10B2 + A11B3 ... ...  A8B12 +  A9B13 + A10B14 + A11B15]
+     [A12B0 + A13B1 + A14B2 + A15B3 ... ... A12B12 + A13B13 + A14B14 + A15B15]
+
+    ----------------------------------------------------------------------------
+    ```
+
+    This operation enables the folding of 4 outer products chained via the
+    accumulator into a single outer product.
+
+    For example:
+
+    ```mlir
+    %a0_ext = arith.extsi %a0 : vector<[4]xi8> to vector<[4]xi32>
+    %b0_ext = arith.extsi %b0 : vector<[4]xi8> to vector<[4]xi32>
+
+    %a1_ext = arith.extsi %a1 : vector<[4]xi8> to vector<[4]xi32>
+    %b1_ext = arith.extsi %b1 : vector<[4]xi8> to vector<[4]xi32>
+
+    %a2_ext = arith.extsi %a2 : vector<[4]xi8> to vector<[4]xi32>
+    %b2_ext = arith.extsi %b2 : vector<[4]xi8> to vector<[4]xi32>
+
+    %a3_ext = arith.extsi %a3 : vector<[4]xi8> to vector<[4]xi32>
+    %b3_ext = arith.extsi %b3 : vector<[4]xi8> to vector<[4]xi32>
+
+    %0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[4]xi32>, vector<[4]xi32>
+    %1 = arm_sme.outerproduct %a1_ext, %b1_ext acc(%0) : vector<[4]xi32>, vector<[4]xi32>
+    %2 = arm_sme.outerproduct %a2_ext, %b2_ext acc(%1) : vector<[4]xi32>, vector<[4]xi32>
+    %3 = arm_sme.outerproduct %a3_ext, %b3_ext acc(%2) : vector<[4]xi32>, vector<[4]xi32>
+    ```
+
+    The 4 outer products in the example above can be fused into a single outer
+    product as follows:
+
+    ```mlir
+    %undef = llvm.mlir.undef : vector<[16]xi8>
+    %a0_ins = vector.scalable.insert %a0, %undef[0] : vector<[4]xi8> into vector<[16]xi8>
+    %a1_ins = vector.scalable.insert %a1, %undef[0] : vector<[4]xi8> into vector<[16]xi8>
+    %a2_ins = vector.scalable.insert %a2, %undef[0] : vector<[4]xi8> into vector<[16]xi8>
+    %a3_ins = vector.scalable.insert %a3, %undef[0] : vector<[4]xi8> into vector<[16]xi8>
+    %lhs0 = "arm_sve.intr.zip1"(%a0_ins, %a2_ins) : (vector<[16]xi8>, vector<[16]xi8>) -> vector<[16]xi8>
+    %lhs1 = "arm_sve.intr.zip1"(%a1_ins, %a3_ins) : (vector<[16]xi8>, vector<[16]xi8>) -> vector<[16]xi8>
+    %lhs = "arm_sve.intr.zip1"(%lhs0, %lhs1) : (vector<[16]xi8>, vector<[16]xi8>) -> vector<[16]xi8>
+
+    %b0_ins = vector.scalable.insert %b0, %undef[0] : vector<[4]xi8> into vector<[16]xi8>
+    %b1_ins = vector.scalable.insert %b1, %undef[0] : vector<[4]xi8> into vector<[16]xi8>
+    %b2_ins = vector.scalable.insert %b2, %undef[0] : vector<[4]xi8> into vector<[16]xi8>
+    %b3_ins = vector.scalable.insert %b3, %undef[0] : vector<[4]xi8> into vector<[16]xi8>
+    %rhs0 = "arm_sve.intr.zip1"(%b0_ins, %b2_ins) : (vector<[16]xi8>, vector<[16]xi8>) -> vector<[16]xi8>
+    %rhs1 = "arm_sve.intr.zip1"(%b1_ins, %b3_ins) : (vector<[16]xi8>, vector<[16]xi8>) -> vector<[16]xi8>
+    %rhs = "arm_sve.intr.zip1"(%rhs0, %rhs1) : (vector<[16]xi8>, vector<[16]xi8>) -> vector<[16]xi8>
+
+    %0 = arm_sme.smopa_wide_4way %lhs, %rhs : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
+    ```
+
+    This is implemented in the `-arm-sme-outer-product-widening` pass.
+
+    Example: I8 to I32
+    ```mlir
+    %result = arm_sme.smopa_wide_4way $lhs, $rhs : vector<[16]xi8>, vector<[16]...
[truncated]

@llvmbot
Member

llvmbot commented Jan 24, 2024

@llvm/pr-subscribers-mlir-vector


@c-rhodes c-rhodes force-pushed the mlir-sme-4way-outerproducts branch 3 times, most recently from 142869d to 33c0dad on January 31, 2024 at 09:27
c-rhodes added a commit to c-rhodes/llvm-project that referenced this pull request Feb 2, 2024
In mixed matmul lowering (e.g., i8 to i32) we're seeing the following
sequence:

  %0 = arith.extsi %src : vector<4x[8]xi8> to vector<4x[8]xi32>
  %1 = vector.extract %0[0] : vector<[8]xi32> from vector<4x[8]xi32>
  %lhs = vector.scalable.extract %1[0] : vector<[4]xi32> from vector<[8]xi32>

  ... (same for rhs)

  %2 = vector.outerproduct %lhs, %rhs, %acc vector<[4]xi32>, vector<[4]xi32>

  // x4 chained by accumulator

This chain of 4 outer products can be fused into a single 4-way widening
variant but the pass doesn't match on the IR, as it expects the source
of the inputs to be an extend and it can't look through the extracts.

This patch fixes this with two rewrites that swap extract(extend) into
extend(extract).

Related to llvm#78975, llvm#79288.
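
Roughly, the swap turns the lhs sequence above into the following (a sketch; types as in the snippet above):

```mlir
// Before: extract from the already-extended (i32) 2-D vector.
%0 = arith.extsi %src : vector<4x[8]xi8> to vector<4x[8]xi32>
%1 = vector.extract %0[0] : vector<[8]xi32> from vector<4x[8]xi32>

// After: extract first, then extend only the slice that is used, so the
// widening pass sees an extend feeding the outer product directly.
%2 = vector.extract %src[0] : vector<[8]xi8> from vector<4x[8]xi8>
%3 = arith.extsi %2 : vector<[8]xi8> to vector<[8]xi32>
```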
Member

@MacDue MacDue left a comment

Few initial comments:

Comment on lines 355 to 403
if (kind == arm_sme::CombiningKind::Add) {
  if (isa<arith::ExtSIOp>(lhsExtOp) && isa<arith::ExtSIOp>(rhsExtOp))
    rewriter.replaceOpWithNewOp<arm_sme::SMopa4WayOp>(
        op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
  else if (isa<arith::ExtUIOp>(lhsExtOp) && isa<arith::ExtUIOp>(rhsExtOp))
    rewriter.replaceOpWithNewOp<arm_sme::UMopa4WayOp>(
        op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
  else if (isa<arith::ExtSIOp>(lhsExtOp) && isa<arith::ExtUIOp>(rhsExtOp))
    rewriter.replaceOpWithNewOp<arm_sme::SuMopa4WayOp>(
        op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
  else if (isa<arith::ExtUIOp>(lhsExtOp) && isa<arith::ExtSIOp>(rhsExtOp))
    rewriter.replaceOpWithNewOp<arm_sme::UsMopa4WayOp>(
        op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
  else
    llvm_unreachable("unexpected extend op!");
} else if (kind == arm_sme::CombiningKind::Sub) {
  if (isa<arith::ExtSIOp>(lhsExtOp) && isa<arith::ExtSIOp>(rhsExtOp))
    rewriter.replaceOpWithNewOp<arm_sme::SMops4WayOp>(
        op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
  else if (isa<arith::ExtUIOp>(lhsExtOp) && isa<arith::ExtUIOp>(rhsExtOp))
    rewriter.replaceOpWithNewOp<arm_sme::UMops4WayOp>(
        op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
  else if (isa<arith::ExtSIOp>(lhsExtOp) && isa<arith::ExtUIOp>(rhsExtOp))
    rewriter.replaceOpWithNewOp<arm_sme::SuMops4WayOp>(
        op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
  else if (isa<arith::ExtUIOp>(lhsExtOp) && isa<arith::ExtSIOp>(rhsExtOp))
    rewriter.replaceOpWithNewOp<arm_sme::UsMops4WayOp>(
        op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
  else
    llvm_unreachable("unexpected extend op!");
} else {
  llvm_unreachable("unexpected arm_sme::CombiningKind!");
}
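
For context, the signedness of the two extends is what selects the op here. As a rough sketch (op names as introduced in this patch), a chain whose lhs is sign-extended and whose rhs is zero-extended maps to `sumopa_wide_4way`:

```mlir
%a0_ext = arith.extsi %a0 : vector<[4]xi8> to vector<[4]xi32>   // signed lhs
%b0_ext = arith.extui %b0 : vector<[4]xi8> to vector<[4]xi32>   // unsigned rhs
%0 = arm_sme.outerproduct %a0_ext, %b0_ext : vector<[4]xi32>, vector<[4]xi32>
// ... three more such outer products chained via acc(...), with CombiningKind::Add, fuse into:
%fused = arm_sme.sumopa_wide_4way %lhs, %rhs : vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
```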
Contributor

Use TypeSwitch? Similarly to 2-way ;-)

Collaborator Author

There are two types to switch on because of the signed-by-unsigned / unsigned-by-signed variants, and I don't believe it's possible to use TypeSwitch for this.

Contributor

Missed that! Not sure how to improve this then :/

c-rhodes added a commit to c-rhodes/llvm-project that referenced this pull request Feb 5, 2024
c-rhodes added a commit that referenced this pull request Feb 5, 2024
@c-rhodes c-rhodes force-pushed the mlir-sme-4way-outerproducts branch from 5e3c649 to 96eecb9 on February 5, 2024 at 14:42
agozillon pushed a commit to agozillon/llvm-project that referenced this pull request Feb 5, 2024
This patch introduces support for 4-way widening outer products. This enables
folding 4 'arm_sme.outerproduct' operations that are chained via the
accumulator into a single widened operation.

Changes:

- Adds the following operations:
  - smopa_4way, smops_4way
  - umopa_4way, umops_4way
  - sumopa_4way, sumops_4way
  - usmopa_4way, usmops_4way
- Implements conversions for the above ops to intrinsics in ArmSMEToLLVM.
- Extends 'arm-sme-outer-product' pass.

For a detailed description of these operations see the
'arm_sme.smopa_4way' description.

Address comments. Changes:

- add common match failures.
- move isCompatible to static function.
- update isCompatible to take optional `rhsExtType`.
- use isCompatible in-place of isWidenable.
- add canFuseOuterProducts for 4-way.
- llvm::hasSingleElement -> hasOneUse.
- op.erase -> rewriter.eraseOp.

Address comments. Changes:

Same as 2-way comments.
- fix check for consistent masking.
- rewrite as loop that walks outer product chain.
- use lambda for match check.
add comment to clarify each variant.
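
For reference, the consistent-masking check concerns the masked form these ops accept (a sketch following the `masks(...)` assembly format in the op definitions; mask shapes mirror the inputs, and both masks must be provided or neither):

```mlir
// %lhs_mask, %rhs_mask : vector<[16]xi1>
%result = arm_sme.smopa_wide_4way %lhs, %rhs masks(%lhs_mask, %rhs_mask) :
  vector<[16]xi8>, vector<[16]xi8> into vector<[4]x[4]xi32>
```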
@c-rhodes c-rhodes force-pushed the mlir-sme-4way-outerproducts branch from 96eecb9 to aa6733d on February 6, 2024 at 10:42
Contributor

@banach-space banach-space left a comment

Some small nits/suggestions, otherwise LGTM, thanks!

Comment on lines 355 to 403
if (kind == arm_sme::CombiningKind::Add) {
if (isa<arith::ExtSIOp>(lhsExtOp) && isa<arith::ExtSIOp>(rhsExtOp))
rewriter.replaceOpWithNewOp<arm_sme::SMopa4WayOp>(
op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
else if (isa<arith::ExtUIOp>(lhsExtOp) && isa<arith::ExtUIOp>(rhsExtOp))
rewriter.replaceOpWithNewOp<arm_sme::UMopa4WayOp>(
op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
else if (isa<arith::ExtSIOp>(lhsExtOp) && isa<arith::ExtUIOp>(rhsExtOp))
rewriter.replaceOpWithNewOp<arm_sme::SuMopa4WayOp>(
op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
else if (isa<arith::ExtUIOp>(lhsExtOp) && isa<arith::ExtSIOp>(rhsExtOp))
rewriter.replaceOpWithNewOp<arm_sme::UsMopa4WayOp>(
op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
else
llvm_unreachable("unexpected extend op!");
} else if (kind == arm_sme::CombiningKind::Sub) {
if (isa<arith::ExtSIOp>(lhsExtOp) && isa<arith::ExtSIOp>(rhsExtOp))
rewriter.replaceOpWithNewOp<arm_sme::SMops4WayOp>(
op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
else if (isa<arith::ExtUIOp>(lhsExtOp) && isa<arith::ExtUIOp>(rhsExtOp))
rewriter.replaceOpWithNewOp<arm_sme::UMops4WayOp>(
op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
else if (isa<arith::ExtSIOp>(lhsExtOp) && isa<arith::ExtUIOp>(rhsExtOp))
rewriter.replaceOpWithNewOp<arm_sme::SuMops4WayOp>(
op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
else if (isa<arith::ExtUIOp>(lhsExtOp) && isa<arith::ExtSIOp>(rhsExtOp))
rewriter.replaceOpWithNewOp<arm_sme::UsMops4WayOp>(
op4, op.getResultType(), lhs, rhs, lhsMask, rhsMask, op1.getAcc());
else
llvm_unreachable("unexpected extend op!");
} else {
llvm_unreachable("unexpected arm_sme::CombiningKind!");
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missed that! Not sure how to improve this then :/

Member

@MacDue MacDue left a comment

LGTM

@c-rhodes c-rhodes merged commit fff86c6 into llvm:main Feb 7, 2024
@c-rhodes c-rhodes deleted the mlir-sme-4way-outerproducts branch February 7, 2024 08:18