[mlir][ArmSVE] Add arm_sve.zip.x2 and arm_sve.zip.x4 ops #81278
Merged
Conversation
This adds ops for the two- and four-way SME 2 multi-vector zips. See:
https://developer.arm.com/documentation/ddi0602/2023-12/SME-Instructions/ZIP--two-registers---Interleave-elements-from-two-vectors-?lang=en
https://developer.arm.com/documentation/ddi0602/2023-12/SME-Instructions/ZIP--four-registers---Interleave-elements-from-four-vectors-?lang=en
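For orientation, here is a minimal usage sketch of the new ops (the function and value names below are illustrative, not taken from the patch). Each op takes SVE vectors of a single type and returns the interleaved elements split across the same number of result vectors:

```mlir
// Minimal usage sketch; names are illustrative.
func.func @zip_example(%a: vector<[8]xi16>, %b: vector<[8]xi16>,
                       %c: vector<[8]xi16>, %d: vector<[8]xi16>)
    -> (vector<[8]xi16>, vector<[8]xi16>) {
  // Two-way zip: %lo and %hi hold the low and high halves of the
  // interleaving of %a and %b.
  %lo, %hi = arm_sve.zip.x2 %a, %b : vector<[8]xi16>
  // Four-way zip: the interleaving of %a, %b, %c, %d is returned as four
  // vectors, one per quarter of the result.
  %q0, %q1, %q2, %q3 = arm_sve.zip.x4 %a, %b, %c, %d : vector<[8]xi16>
  return %lo, %hi : vector<[8]xi16>, vector<[8]xi16>
}
```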
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-sve
Author: Benjamin Maxwell (MacDue)
Changes
This adds ops for the two- and four-way SME 2 multi-vector zips. See:
Full diff: https://github.com/llvm/llvm-project/pull/81278.diff
5 Files Affected:
diff --git a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
index f237f232487e50..63e70b412d9619 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
+++ b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
@@ -49,6 +49,12 @@ def SVBoolMask : VectorWithTrailingDimScalableOfSizeAndType<
def SVEPredicateMask : VectorWithTrailingDimScalableOfSizeAndType<
[16, 8, 4, 2, 1], [I1]>;
+// A constraint for a 1-D scalable vector of `length`.
+class Scalable1DVectorOfLength<int length, list<Type> elementTypes> : ShapedContainerType<
+ elementTypes, And<[IsVectorOfShape<[length]>, IsVectorTypeWithAnyDimScalablePred]>,
+ "a 1-D scalable vector with length " # length,
+ "::mlir::VectorType">;
+
//===----------------------------------------------------------------------===//
// ArmSVE op definitions
//===----------------------------------------------------------------------===//
@@ -321,6 +327,121 @@ def ConvertToSvboolOp : ArmSVE_Op<"convert_to_svbool",
let assemblyFormat = "$source attr-dict `:` type($source)";
}
+// Inputs valid for the multi-vector zips (not including the 128-bit element zipqs)
+def ZipInputVectorType : AnyTypeOf<[
+ Scalable1DVectorOfLength<2, [I64, F64]>,
+ Scalable1DVectorOfLength<4, [I32, F32]>,
+ Scalable1DVectorOfLength<8, [I16, F16, BF16]>,
+ Scalable1DVectorOfLength<16, [I8]>],
+ "an SVE vector with element size <= 64-bit">;
+
+def ZipX2Op : ArmSVE_Op<"zip.x2", [
+ Pure,
+ AllTypesMatch<["sourceV1", "sourceV2", "resultV1", "resultV2"]>]
+> {
+ let summary = "Multi-vector two-way zip op";
+
+ let description = [{
+ This operation interleaves elements from two input SVE vectors, returning
+ two new SVE vectors (`resultV1` and `resultV2`), which contain the low and
+ high halves of the result respectively.
+
+ Example:
+ ```mlir
+ // sourceV1 = [ A1, A2, A3, ... An ]
+ // sourceV2 = [ B1, B2, B3, ... Bn ]
+ // (resultV1, resultV2) = [ A1, B1, A2, B2, A3, B3, ... An, Bn ]
+ %resultV1, %resultV2 = arm_sve.zip.x2 %sourceV1, %sourceV2 : vector<[16]xi8>
+ ```
+
+ Note: This requires SME 2 (`+sme2` in LLVM target features)
+
+ [Source](https://developer.arm.com/documentation/ddi0602/2023-12/SME-Instructions/ZIP--two-registers---Interleave-elements-from-two-vectors-?lang=en)
+ }];
+
+ let arguments = (ins ZipInputVectorType:$sourceV1,
+ ZipInputVectorType:$sourceV2);
+
+ let results = (outs ZipInputVectorType:$resultV1,
+ ZipInputVectorType:$resultV2);
+
+ let builders = [
+ OpBuilder<(ins "Value":$v1, "Value":$v2), [{
+ build($_builder, $_state, v1.getType(), v1.getType(), v1, v2);
+ }]>];
+
+ let assemblyFormat = "$sourceV1 `,` $sourceV2 attr-dict `:` type($sourceV1)";
+
+ let extraClassDeclaration = [{
+ VectorType getVectorType() {
+ return ::llvm::cast<VectorType>(getSourceV1().getType());
+ }
+ }];
+}
+
+def ZipX4Op : ArmSVE_Op<"zip.x4", [
+ Pure,
+ AllTypesMatch<[
+ "sourceV1", "sourceV2", "sourceV3", "sourceV4",
+ "resultV1", "resultV2", "resultV3", "resultV4"]>]
+> {
+ let summary = "Multi-vector four-way zip op";
+
+ let description = [{
+ This operation interleaves elements from four input SVE vectors, returning
+    four new SVE vectors, each of which contains a quarter of the result. The
+ first quarter will be in `resultV1`, second in `resultV2`, third in
+ `resultV3`, and fourth in `resultV4`.
+
+ ```mlir
+ // sourceV1 = [ A1, A2, ... An ]
+ // sourceV2 = [ B1, B2, ... Bn ]
+ // sourceV3 = [ C1, C2, ... Cn ]
+ // sourceV4 = [ D1, D2, ... Dn ]
+ // (resultV1, resultV2, resultV3, resultV4)
+ // = [ A1, B1, C1, D1, A2, B2, C2, D2, ... An, Bn, Cn, Dn ]
+ %resultV1, %resultV2, %resultV3, %resultV4 = arm_sve.zip.x4
+ %sourceV1, %sourceV2, %sourceV3, %sourceV4 : vector<[16]xi8>
+ ```
+
+  Warning: The result of this op is undefined for 64-bit elements on hardware with
+ less than 256-bit vectors!
+
+ Note: This requires SME 2 (`+sme2` in LLVM target features)
+
+ [Source](https://developer.arm.com/documentation/ddi0602/2023-12/SME-Instructions/ZIP--four-registers---Interleave-elements-from-four-vectors-?lang=en)
+ }];
+
+ let arguments = (ins ZipInputVectorType:$sourceV1,
+ ZipInputVectorType:$sourceV2,
+ ZipInputVectorType:$sourceV3,
+ ZipInputVectorType:$sourceV4);
+
+ let results = (outs ZipInputVectorType:$resultV1,
+ ZipInputVectorType:$resultV2,
+ ZipInputVectorType:$resultV3,
+ ZipInputVectorType:$resultV4);
+
+ let builders = [
+ OpBuilder<(ins "Value":$v1, "Value":$v2, "Value":$v3, "Value":$v4), [{
+ build($_builder, $_state,
+ v1.getType(), v1.getType(),
+ v1.getType(), v1.getType(),
+ v1, v2, v3, v4);
+ }]>];
+
+ let assemblyFormat = [{
+ $sourceV1 `,` $sourceV2 `,` $sourceV3 `,` $sourceV4 attr-dict
+ `:` type($sourceV1)
+ }];
+
+ let extraClassDeclaration = [{
+ VectorType getVectorType() {
+ return ::llvm::cast<VectorType>(getSourceV1().getType());
+ }
+ }];
+}
+
def ScalableMaskedAddIOp : ScalableMaskedIOp<"masked.addi", "addition",
[Commutative]>;
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
index 32c87c1b824074..387937e811ced8 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
@@ -137,6 +137,9 @@ using ConvertToSvboolOpLowering =
using ConvertFromSvboolOpLowering =
SvboolConversionOpLowering<ConvertFromSvboolOp, ConvertFromSvboolIntrOp>;
+using ZipX2OpLowering = OneToOneConvertToLLVMPattern<ZipX2Op, ZipX2IntrOp>;
+using ZipX4OpLowering = OneToOneConvertToLLVMPattern<ZipX4Op, ZipX4IntrOp>;
+
} // namespace
/// Populate the given list with patterns that convert from ArmSVE to LLVM.
@@ -163,7 +166,9 @@ void mlir::populateArmSVELegalizeForLLVMExportPatterns(
ScalableMaskedUDivIOpLowering,
ScalableMaskedDivFOpLowering,
ConvertToSvboolOpLowering,
- ConvertFromSvboolOpLowering>(converter);
+ ConvertFromSvboolOpLowering,
+ ZipX2OpLowering,
+ ZipX4OpLowering>(converter);
// clang-format on
}
@@ -184,7 +189,9 @@ void mlir::configureArmSVELegalizeForExportTarget(
ScalableMaskedUDivIIntrOp,
ScalableMaskedDivFIntrOp,
ConvertToSvboolIntrOp,
- ConvertFromSvboolIntrOp>();
+ ConvertFromSvboolIntrOp,
+ ZipX2IntrOp,
+ ZipX4IntrOp>();
target.addIllegalOp<SdotOp,
SmmlaOp,
UdotOp,
@@ -199,6 +206,8 @@ void mlir::configureArmSVELegalizeForExportTarget(
ScalableMaskedUDivIOp,
ScalableMaskedDivFOp,
ConvertToSvboolOp,
- ConvertFromSvboolOp>();
+ ConvertFromSvboolOp,
+ ZipX2Op,
+ ZipX4Op>();
// clang-format on
}
diff --git a/mlir/test/Dialect/ArmSVE/invalid.mlir b/mlir/test/Dialect/ArmSVE/invalid.mlir
index a1fa0d0292b7b7..1258d3532c049c 100644
--- a/mlir/test/Dialect/ArmSVE/invalid.mlir
+++ b/mlir/test/Dialect/ArmSVE/invalid.mlir
@@ -49,3 +49,18 @@ func.func @arm_sve_convert_to_svbool__bad_mask_scalability(%mask : vector<[4]x[8
}
+// -----
+
+func.func @arm_sve_zip_x2_bad_vector_type(%a : vector<[7]xi8>) {
+ // expected-error@+1 {{op operand #0 must be an SVE vector with element size <= 64-bit, but got 'vector<[7]xi8>'}}
+ arm_sve.zip.x2 %a, %a : vector<[7]xi8>
+ return
+}
+
+// -----
+
+func.func @arm_sve_zip_x4_bad_vector_type(%a : vector<[5]xf64>) {
+ // expected-error@+1 {{op operand #0 must be an SVE vector with element size <= 64-bit, but got 'vector<[5]xf64>'}}
+ arm_sve.zip.x4 %a, %a, %a, %a : vector<[5]xf64>
+ return
+}
diff --git a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
index 8e76fb7119b844..8d11c2bcaa8d51 100644
--- a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
@@ -187,3 +187,27 @@ func.func @convert_2d_mask_from_svbool(%svbool: vector<3x[16]xi1>) -> vector<3x[
// CHECK-NEXT: llvm.return %[[MASK]] : !llvm.array<3 x vector<[1]xi1>>
return %mask : vector<3x[1]xi1>
}
+
+// -----
+
+func.func @arm_sve_zip_x2(%a: vector<[8]xi16>, %b: vector<[8]xi16>)
+ -> (vector<[8]xi16>, vector<[8]xi16>)
+{
+ // CHECK: arm_sve.intr.zip.x2
+ %0, %1 = arm_sve.zip.x2 %a, %b : vector<[8]xi16>
+ return %0, %1 : vector<[8]xi16>, vector<[8]xi16>
+}
+
+// -----
+
+func.func @arm_sve_zip_x4(
+ %a: vector<[16]xi8>,
+ %b: vector<[16]xi8>,
+ %c: vector<[16]xi8>,
+ %d: vector<[16]xi8>
+) -> (vector<[16]xi8>, vector<[16]xi8>, vector<[16]xi8>, vector<[16]xi8>)
+{
+ // CHECK: arm_sve.intr.zip.x4
+ %0, %1, %2, %3 = arm_sve.zip.x4 %a, %b, %c, %d : vector<[16]xi8>
+ return %0, %1, %2, %3 : vector<[16]xi8>, vector<[16]xi8>, vector<[16]xi8>, vector<[16]xi8>
+}
diff --git a/mlir/test/Dialect/ArmSVE/roundtrip.mlir b/mlir/test/Dialect/ArmSVE/roundtrip.mlir
index c9a0b6db8fa803..f7b79aa2f275c4 100644
--- a/mlir/test/Dialect/ArmSVE/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSVE/roundtrip.mlir
@@ -163,3 +163,65 @@ func.func @arm_sve_convert_from_svbool(%a: vector<[16]xi1>,
return
}
+
+// -----
+
+func.func @arm_sve_zip_x2(
+ %v1: vector<[2]xi64>,
+ %v2: vector<[2]xf64>,
+ %v3: vector<[4]xi32>,
+ %v4: vector<[4]xf32>,
+ %v5: vector<[8]xi16>,
+ %v6: vector<[8]xf16>,
+ %v7: vector<[8]xbf16>,
+ %v8: vector<[16]xi8>
+) {
+ // CHECK: arm_sve.zip.x2 %{{.*}} : vector<[2]xi64>
+ %a1, %b1 = arm_sve.zip.x2 %v1, %v1 : vector<[2]xi64>
+ // CHECK: arm_sve.zip.x2 %{{.*}} : vector<[2]xf64>
+ %a2, %b2 = arm_sve.zip.x2 %v2, %v2 : vector<[2]xf64>
+ // CHECK: arm_sve.zip.x2 %{{.*}} : vector<[4]xi32>
+ %a3, %b3 = arm_sve.zip.x2 %v3, %v3 : vector<[4]xi32>
+ // CHECK: arm_sve.zip.x2 %{{.*}} : vector<[4]xf32>
+ %a4, %b4 = arm_sve.zip.x2 %v4, %v4 : vector<[4]xf32>
+ // CHECK: arm_sve.zip.x2 %{{.*}} : vector<[8]xi16>
+ %a5, %b5 = arm_sve.zip.x2 %v5, %v5 : vector<[8]xi16>
+ // CHECK: arm_sve.zip.x2 %{{.*}} : vector<[8]xf16>
+ %a6, %b6 = arm_sve.zip.x2 %v6, %v6 : vector<[8]xf16>
+ // CHECK: arm_sve.zip.x2 %{{.*}} : vector<[8]xbf16>
+ %a7, %b7 = arm_sve.zip.x2 %v7, %v7 : vector<[8]xbf16>
+ // CHECK: arm_sve.zip.x2 %{{.*}} : vector<[16]xi8>
+ %a8, %b8 = arm_sve.zip.x2 %v8, %v8 : vector<[16]xi8>
+ return
+}
+
+// -----
+
+func.func @arm_sve_zip_x4(
+ %v1: vector<[2]xi64>,
+ %v2: vector<[2]xf64>,
+ %v3: vector<[4]xi32>,
+ %v4: vector<[4]xf32>,
+ %v5: vector<[8]xi16>,
+ %v6: vector<[8]xf16>,
+ %v7: vector<[8]xbf16>,
+ %v8: vector<[16]xi8>
+) {
+ // CHECK: arm_sve.zip.x4 %{{.*}} : vector<[2]xi64>
+ %a1, %b1, %c1, %d1 = arm_sve.zip.x4 %v1, %v1, %v1, %v1 : vector<[2]xi64>
+ // CHECK: arm_sve.zip.x4 %{{.*}} : vector<[2]xf64>
+ %a2, %b2, %c2, %d2 = arm_sve.zip.x4 %v2, %v2, %v2, %v2 : vector<[2]xf64>
+ // CHECK: arm_sve.zip.x4 %{{.*}} : vector<[4]xi32>
+ %a3, %b3, %c3, %d3 = arm_sve.zip.x4 %v3, %v3, %v3, %v3 : vector<[4]xi32>
+ // CHECK: arm_sve.zip.x4 %{{.*}} : vector<[4]xf32>
+ %a4, %b4, %c4, %d4 = arm_sve.zip.x4 %v4, %v4, %v4, %v4 : vector<[4]xf32>
+ // CHECK: arm_sve.zip.x4 %{{.*}} : vector<[8]xi16>
+ %a5, %b5, %c5, %d5 = arm_sve.zip.x4 %v5, %v5, %v5, %v5 : vector<[8]xi16>
+ // CHECK: arm_sve.zip.x4 %{{.*}} : vector<[8]xf16>
+ %a6, %b6, %c6, %d6 = arm_sve.zip.x4 %v6, %v6, %v6, %v6 : vector<[8]xf16>
+ // CHECK: arm_sve.zip.x4 %{{.*}} : vector<[8]xbf16>
+ %a7, %b7, %c7, %d7 = arm_sve.zip.x4 %v7, %v7, %v7, %v7 : vector<[8]xbf16>
+ // CHECK: arm_sve.zip.x4 %{{.*}} : vector<[16]xi8>
+ %a8, %b8, %c8, %d8 = arm_sve.zip.x4 %v8, %v8, %v8, %v8 : vector<[16]xi8>
+ return
+}
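To make the quarter-split described for `arm_sve.zip.x4` concrete, here is a worked example, assuming a 128-bit vector length (vscale = 1) so that each `vector<[4]xi32>` holds exactly four elements; the element values are illustrative only:

```mlir
// Worked example of arm_sve.zip.x4, assuming vscale = 1 (128-bit vectors).
//
//   %v1 = [ A1, A2, A3, A4 ]
//   %v2 = [ B1, B2, B3, B4 ]
//   %v3 = [ C1, C2, C3, C4 ]
//   %v4 = [ D1, D2, D3, D4 ]
//
// The full interleaving
//   [ A1, B1, C1, D1, A2, B2, C2, D2, A3, B3, C3, D3, A4, B4, C4, D4 ]
// is split into quarters across the results:
//   %r1 = [ A1, B1, C1, D1 ]
//   %r2 = [ A2, B2, C2, D2 ]
//   %r3 = [ A3, B3, C3, D3 ]
//   %r4 = [ A4, B4, C4, D4 ]
%r1, %r2, %r3, %r4 = arm_sve.zip.x4 %v1, %v2, %v3, %v4 : vector<[4]xi32>
```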
dcaballe approved these changes Feb 9, 2024
c-rhodes reviewed Feb 12, 2024
c-rhodes approved these changes Feb 12, 2024
LGTM cheers