[mlir][ArmSVE] Add arm_sve.zip.x2 and arm_sve.zip.x4 ops #81278


Merged — 2 commits merged into llvm:main on Feb 16, 2024

Conversation

llvmbot (Member) commented on Feb 9, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-sve

Author: Benjamin Maxwell (MacDue)

Changes

This adds ops for the two- and four-way SME 2 multi-vector zips.
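The interleaving semantics of the two ops can be sketched on plain Python lists (illustrative only; the real `arm_sve.zip.x2`/`zip.x4` ops operate on scalable vectors whose length is only known at runtime, and the helper names below are hypothetical):

```python
# Sketch of the two- and four-way multi-vector zip semantics.

def zip_x2(v1, v2):
    # Interleave two equal-length vectors, then split the result into
    # low and high halves (resultV1, resultV2).
    interleaved = [x for pair in zip(v1, v2) for x in pair]
    half = len(interleaved) // 2
    return interleaved[:half], interleaved[half:]

def zip_x4(v1, v2, v3, v4):
    # Round-robin interleave four equal-length vectors, then split the
    # result into four quarters (resultV1 .. resultV4).
    interleaved = [x for group in zip(v1, v2, v3, v4) for x in group]
    n = len(interleaved) // 4
    return tuple(interleaved[i * n:(i + 1) * n] for i in range(4))
```

For example, `zip_x2([A1, A2], [B1, B2])` yields `([A1, B1], [A2, B2])`, matching the low-half/high-half split described in the op documentation below.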

See:


Full diff: https://github.com/llvm/llvm-project/pull/81278.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td (+121)
  • (modified) mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp (+12-3)
  • (modified) mlir/test/Dialect/ArmSVE/invalid.mlir (+15)
  • (modified) mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir (+24)
  • (modified) mlir/test/Dialect/ArmSVE/roundtrip.mlir (+62)
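The `ZipInputVectorType` constraint added in the TableGen changes below admits exactly the 1-D scalable shapes where the known (minimum) length times the element width equals 128 bits, i.e. one full SVE vector granule. A quick sanity check of that invariant (the table below is a restatement of the constraint, not part of the patch):

```python
# Each allowed (scalable length, element bits) pair fills one 128-bit
# SVE granule: 2 x 64, 4 x 32, 8 x 16, 16 x 8.
ALLOWED = {
    2: [64],        # i64, f64
    4: [32],        # i32, f32
    8: [16],        # i16, f16, bf16
    16: [8],        # i8
}
assert all(length * bits == 128
           for length, sizes in ALLOWED.items()
           for bits in sizes)
```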
diff --git a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
index f237f232487e50..63e70b412d9619 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
+++ b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
@@ -49,6 +49,12 @@ def SVBoolMask : VectorWithTrailingDimScalableOfSizeAndType<
 def SVEPredicateMask : VectorWithTrailingDimScalableOfSizeAndType<
   [16, 8, 4, 2, 1], [I1]>;
 
+// A constraint for a 1-D scalable vector of `length`.
+class Scalable1DVectorOfLength<int length, list<Type> elementTypes> : ShapedContainerType<
+  elementTypes, And<[IsVectorOfShape<[length]>, IsVectorTypeWithAnyDimScalablePred]>,
+  "a 1-D scalable vector with length " # length,
+  "::mlir::VectorType">;
+
 //===----------------------------------------------------------------------===//
 // ArmSVE op definitions
 //===----------------------------------------------------------------------===//
@@ -321,6 +327,121 @@ def ConvertToSvboolOp : ArmSVE_Op<"convert_to_svbool",
   let assemblyFormat = "$source attr-dict `:` type($source)";
 }
 
+// Inputs valid for the multi-vector zips (not including the 128-bit element zipqs)
+def ZipInputVectorType : AnyTypeOf<[
+  Scalable1DVectorOfLength<2, [I64, F64]>,
+  Scalable1DVectorOfLength<4, [I32, F32]>,
+  Scalable1DVectorOfLength<8, [I16, F16, BF16]>,
+  Scalable1DVectorOfLength<16, [I8]>],
+  "an SVE vector with element size <= 64-bit">;
+
+def ZipX2Op  : ArmSVE_Op<"zip.x2", [
+  Pure,
+  AllTypesMatch<["sourceV1", "sourceV2", "resultV1", "resultV2"]>]
+> {
+  let summary = "Multi-vector two-way zip op";
+
+  let description = [{
+    This operation interleaves elements from two input SVE vectors, returning
+    two new SVE vectors (`resultV1` and `resultV2`), which contain the low and
+    high halves of the result respectively.
+
+    Example:
+    ```mlir
+    // sourceV1 = [ A1, A2, A3, ... An ]
+    // sourceV2 = [ B1, B2, B3, ... Bn ]
+    // (resultV1, resultV2) = [ A1, B1, A2, B2, A3, B3, ... An, Bn ]
+    %resultV1, %resultV2 = arm_sve.zip.x2 %sourceV1, %sourceV2 : vector<[16]xi8>
+    ```
+
+    Note: This requires SME 2 (`+sme2` in LLVM target features)
+
+    [Source](https://developer.arm.com/documentation/ddi0602/2023-12/SME-Instructions/ZIP--two-registers---Interleave-elements-from-two-vectors-?lang=en)
+  }];
+
+  let arguments = (ins ZipInputVectorType:$sourceV1,
+                       ZipInputVectorType:$sourceV2);
+
+  let results = (outs ZipInputVectorType:$resultV1,
+                      ZipInputVectorType:$resultV2);
+
+  let builders = [
+    OpBuilder<(ins "Value":$v1, "Value":$v2), [{
+      build($_builder, $_state, v1.getType(), v1.getType(), v1, v2);
+  }]>];
+
+  let assemblyFormat = "$sourceV1 `,` $sourceV2 attr-dict `:` type($sourceV1)";
+
+  let extraClassDeclaration = [{
+    VectorType getVectorType() {
+      return ::llvm::cast<VectorType>(getSourceV1().getType());
+    }
+  }];
+}
+
+def ZipX4Op  : ArmSVE_Op<"zip.x4", [
+  Pure,
+  AllTypesMatch<[
+    "sourceV1", "sourceV2", "sourceV3", "sourceV4",
+    "resultV1", "resultV2", "resultV3", "resultV4"]>]
+> {
+  let summary = "Multi-vector four-way zip op";
+
+  let description = [{
+    This operation interleaves elements from four input SVE vectors, returning
+    four new SVE vectors, each of which contains a quarter of the result. The
+    first quarter will be in `resultV1`, second in `resultV2`, third in
+    `resultV3`, and fourth in `resultV4`.
+
+    ```mlir
+    // sourceV1 = [ A1, A2, ... An ]
+    // sourceV2 = [ B1, B2, ... Bn ]
+    // sourceV3 = [ C1, C2, ... Cn ]
+    // sourceV4 = [ D1, D2, ... Dn ]
+    // (resultV1, resultV2, resultV3, resultV4)
+    //   = [ A1, B1, C1, D1, A2, B2, C2, D2, ... An, Bn, Cn, Dn ]
+    %resultV1, %resultV2, %resultV3, %resultV4 = arm_sve.zip.x4
+      %sourceV1, %sourceV2, %sourceV3, %sourceV4 : vector<[16]xi8>
+    ```
+
+    Warning: The result of this op is undefined for 64-bit elements on hardware
+    with vectors shorter than 256 bits!
+
+    Note: This requires SME 2 (`+sme2` in LLVM target features)
+
+    [Source](https://developer.arm.com/documentation/ddi0602/2023-12/SME-Instructions/ZIP--four-registers---Interleave-elements-from-four-vectors-?lang=en)
+  }];
+
+  let arguments = (ins ZipInputVectorType:$sourceV1,
+                       ZipInputVectorType:$sourceV2,
+                       ZipInputVectorType:$sourceV3,
+                       ZipInputVectorType:$sourceV4);
+
+  let results = (outs ZipInputVectorType:$resultV1,
+                      ZipInputVectorType:$resultV2,
+                      ZipInputVectorType:$resultV3,
+                      ZipInputVectorType:$resultV4);
+
+  let builders = [
+    OpBuilder<(ins "Value":$v1, "Value":$v2, "Value":$v3, "Value":$v4), [{
+      build($_builder, $_state,
+        v1.getType(), v1.getType(),
+        v1.getType(), v1.getType(),
+        v1, v2, v3, v4);
+  }]>];
+
+  let assemblyFormat = [{
+    $sourceV1 `,` $sourceV2 `,` $sourceV3 `,` $sourceV4 attr-dict
+      `:` type($sourceV1)
+  }];
+
+  let extraClassDeclaration = [{
+    VectorType getVectorType() {
+      return ::llvm::cast<VectorType>(getSourceV1().getType());
+    }
+  }];
+}
+
 def ScalableMaskedAddIOp : ScalableMaskedIOp<"masked.addi", "addition",
                                              [Commutative]>;
 
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
index 32c87c1b824074..387937e811ced8 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
@@ -137,6 +137,9 @@ using ConvertToSvboolOpLowering =
 using ConvertFromSvboolOpLowering =
     SvboolConversionOpLowering<ConvertFromSvboolOp, ConvertFromSvboolIntrOp>;
 
+using ZipX2OpLowering = OneToOneConvertToLLVMPattern<ZipX2Op, ZipX2IntrOp>;
+using ZipX4OpLowering = OneToOneConvertToLLVMPattern<ZipX4Op, ZipX4IntrOp>;
+
 } // namespace
 
 /// Populate the given list with patterns that convert from ArmSVE to LLVM.
@@ -163,7 +166,9 @@ void mlir::populateArmSVELegalizeForLLVMExportPatterns(
                ScalableMaskedUDivIOpLowering,
                ScalableMaskedDivFOpLowering,
                ConvertToSvboolOpLowering,
-               ConvertFromSvboolOpLowering>(converter);
+               ConvertFromSvboolOpLowering,
+               ZipX2OpLowering,
+               ZipX4OpLowering>(converter);
   // clang-format on
 }
 
@@ -184,7 +189,9 @@ void mlir::configureArmSVELegalizeForExportTarget(
                     ScalableMaskedUDivIIntrOp,
                     ScalableMaskedDivFIntrOp,
                     ConvertToSvboolIntrOp,
-                    ConvertFromSvboolIntrOp>();
+                    ConvertFromSvboolIntrOp,
+                    ZipX2IntrOp,
+                    ZipX4IntrOp>();
   target.addIllegalOp<SdotOp,
                       SmmlaOp,
                       UdotOp,
@@ -199,6 +206,8 @@ void mlir::configureArmSVELegalizeForExportTarget(
                       ScalableMaskedUDivIOp,
                       ScalableMaskedDivFOp,
                       ConvertToSvboolOp,
-                      ConvertFromSvboolOp>();
+                      ConvertFromSvboolOp,
+                      ZipX2Op,
+                      ZipX4Op>();
   // clang-format on
 }
diff --git a/mlir/test/Dialect/ArmSVE/invalid.mlir b/mlir/test/Dialect/ArmSVE/invalid.mlir
index a1fa0d0292b7b7..1258d3532c049c 100644
--- a/mlir/test/Dialect/ArmSVE/invalid.mlir
+++ b/mlir/test/Dialect/ArmSVE/invalid.mlir
@@ -49,3 +49,18 @@ func.func @arm_sve_convert_to_svbool__bad_mask_scalability(%mask : vector<[4]x[8
 }
 
 
+// -----
+
+func.func @arm_sve_zip_x2_bad_vector_type(%a : vector<[7]xi8>) {
+  // expected-error@+1 {{op operand #0 must be an SVE vector with element size <= 64-bit, but got 'vector<[7]xi8>'}}
+  arm_sve.zip.x2 %a, %a : vector<[7]xi8>
+  return
+}
+
+// -----
+
+func.func @arm_sve_zip_x4_bad_vector_type(%a : vector<[5]xf64>) {
+  // expected-error@+1 {{op operand #0 must be an SVE vector with element size <= 64-bit, but got 'vector<[5]xf64>'}}
+  arm_sve.zip.x4 %a, %a, %a, %a : vector<[5]xf64>
+  return
+}
diff --git a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
index 8e76fb7119b844..8d11c2bcaa8d51 100644
--- a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
@@ -187,3 +187,27 @@ func.func @convert_2d_mask_from_svbool(%svbool: vector<3x[16]xi1>) -> vector<3x[
   // CHECK-NEXT: llvm.return %[[MASK]] : !llvm.array<3 x vector<[1]xi1>>
   return %mask : vector<3x[1]xi1>
 }
+
+// -----
+
+func.func @arm_sve_zip_x2(%a: vector<[8]xi16>, %b: vector<[8]xi16>)
+  -> (vector<[8]xi16>, vector<[8]xi16>)
+{
+  // CHECK: arm_sve.intr.zip.x2
+  %0, %1 = arm_sve.zip.x2 %a, %b : vector<[8]xi16>
+  return %0, %1 : vector<[8]xi16>, vector<[8]xi16>
+}
+
+// -----
+
+func.func @arm_sve_zip_x4(
+  %a: vector<[16]xi8>,
+  %b: vector<[16]xi8>,
+  %c: vector<[16]xi8>,
+  %d: vector<[16]xi8>
+) -> (vector<[16]xi8>, vector<[16]xi8>, vector<[16]xi8>, vector<[16]xi8>)
+{
+  // CHECK: arm_sve.intr.zip.x4
+  %0, %1, %2, %3 = arm_sve.zip.x4 %a, %b, %c, %d : vector<[16]xi8>
+  return %0, %1, %2, %3 : vector<[16]xi8>, vector<[16]xi8>, vector<[16]xi8>, vector<[16]xi8>
+}
diff --git a/mlir/test/Dialect/ArmSVE/roundtrip.mlir b/mlir/test/Dialect/ArmSVE/roundtrip.mlir
index c9a0b6db8fa803..f7b79aa2f275c4 100644
--- a/mlir/test/Dialect/ArmSVE/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSVE/roundtrip.mlir
@@ -163,3 +163,65 @@ func.func @arm_sve_convert_from_svbool(%a: vector<[16]xi1>,
 
   return
 }
+
+// -----
+
+func.func @arm_sve_zip_x2(
+  %v1: vector<[2]xi64>,
+  %v2: vector<[2]xf64>,
+  %v3: vector<[4]xi32>,
+  %v4: vector<[4]xf32>,
+  %v5: vector<[8]xi16>,
+  %v6: vector<[8]xf16>,
+  %v7: vector<[8]xbf16>,
+  %v8: vector<[16]xi8>
+) {
+  // CHECK: arm_sve.zip.x2 %{{.*}} : vector<[2]xi64>
+  %a1, %b1 = arm_sve.zip.x2 %v1, %v1 : vector<[2]xi64>
+  // CHECK: arm_sve.zip.x2 %{{.*}} : vector<[2]xf64>
+  %a2, %b2 = arm_sve.zip.x2 %v2, %v2 : vector<[2]xf64>
+  // CHECK: arm_sve.zip.x2 %{{.*}} : vector<[4]xi32>
+  %a3, %b3 = arm_sve.zip.x2 %v3, %v3 : vector<[4]xi32>
+  // CHECK: arm_sve.zip.x2 %{{.*}} : vector<[4]xf32>
+  %a4, %b4 = arm_sve.zip.x2 %v4, %v4 : vector<[4]xf32>
+  // CHECK: arm_sve.zip.x2 %{{.*}} : vector<[8]xi16>
+  %a5, %b5 = arm_sve.zip.x2 %v5, %v5 : vector<[8]xi16>
+  // CHECK: arm_sve.zip.x2 %{{.*}} : vector<[8]xf16>
+  %a6, %b6 = arm_sve.zip.x2 %v6, %v6 : vector<[8]xf16>
+  // CHECK: arm_sve.zip.x2 %{{.*}} : vector<[8]xbf16>
+  %a7, %b7 = arm_sve.zip.x2 %v7, %v7 : vector<[8]xbf16>
+  // CHECK: arm_sve.zip.x2 %{{.*}} : vector<[16]xi8>
+  %a8, %b8 = arm_sve.zip.x2 %v8, %v8 : vector<[16]xi8>
+  return
+}
+
+// -----
+
+func.func @arm_sve_zip_x4(
+  %v1: vector<[2]xi64>,
+  %v2: vector<[2]xf64>,
+  %v3: vector<[4]xi32>,
+  %v4: vector<[4]xf32>,
+  %v5: vector<[8]xi16>,
+  %v6: vector<[8]xf16>,
+  %v7: vector<[8]xbf16>,
+  %v8: vector<[16]xi8>
+) {
+  // CHECK: arm_sve.zip.x4 %{{.*}} : vector<[2]xi64>
+  %a1, %b1, %c1, %d1 = arm_sve.zip.x4 %v1, %v1, %v1, %v1 : vector<[2]xi64>
+  // CHECK: arm_sve.zip.x4 %{{.*}} : vector<[2]xf64>
+  %a2, %b2, %c2, %d2 = arm_sve.zip.x4 %v2, %v2, %v2, %v2 : vector<[2]xf64>
+  // CHECK: arm_sve.zip.x4 %{{.*}} : vector<[4]xi32>
+  %a3, %b3, %c3, %d3 = arm_sve.zip.x4 %v3, %v3, %v3, %v3 : vector<[4]xi32>
+  // CHECK: arm_sve.zip.x4 %{{.*}} : vector<[4]xf32>
+  %a4, %b4, %c4, %d4 = arm_sve.zip.x4 %v4, %v4, %v4, %v4 : vector<[4]xf32>
+  // CHECK: arm_sve.zip.x4 %{{.*}} : vector<[8]xi16>
+  %a5, %b5, %c5, %d5 = arm_sve.zip.x4 %v5, %v5, %v5, %v5 : vector<[8]xi16>
+  // CHECK: arm_sve.zip.x4 %{{.*}} : vector<[8]xf16>
+  %a6, %b6, %c6, %d6 = arm_sve.zip.x4 %v6, %v6, %v6, %v6 : vector<[8]xf16>
+  // CHECK: arm_sve.zip.x4 %{{.*}} : vector<[8]xbf16>
+  %a7, %b7, %c7, %d7 = arm_sve.zip.x4 %v7, %v7, %v7, %v7 : vector<[8]xbf16>
+  // CHECK: arm_sve.zip.x4 %{{.*}} : vector<[16]xi8>
+  %a8, %b8, %c8, %d8 = arm_sve.zip.x4 %v8, %v8, %v8, %v8 : vector<[16]xi8>
+  return
+}

c-rhodes (Collaborator) left a comment:

LGTM cheers

MacDue merged commit 7dcca62 into llvm:main on Feb 16, 2024
MacDue deleted the multi_vector_zip_ops branch on February 16, 2024 at 11:34