Commit fa9e411 (parent e5638c5)

[mlir][ArmSVE] Add intrinsics for the SME2 multi-vector zips
These are added to the ArmSVE dialect for consistency with LLVM, which registers SME2 intrinsics that don't require ZA under SVE.
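The new ops mirror LLVM's multi-result intrinsics, which return all result vectors at once as a literal struct. As a minimal sketch of intended use (the function @zip_pair and the choice to keep only the first result are illustrative, not part of this commit; extraction uses the LLVM dialect's llvm.extractvalue):

  llvm.func @zip_pair(%a: vector<[4]xi32>, %b: vector<[4]xi32>) -> vector<[4]xi32> {
    // Zip two scalable vectors; both result vectors come back in one struct.
    %zipped = "arm_sve.intr.zip.x2"(%a, %b)
        : (vector<[4]xi32>, vector<[4]xi32>)
        -> !llvm.struct<(vector<[4]xi32>, vector<[4]xi32>)>
    // Pull the first result vector out of the returned struct.
    %lo = llvm.extractvalue %zipped[0] : !llvm.struct<(vector<[4]xi32>, vector<[4]xi32>)>
    llvm.return %lo : vector<[4]xi32>
  }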

2 files changed: +65 -2 lines

mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td

Lines changed: 23 additions & 2 deletions

@@ -59,14 +59,15 @@ class ArmSVE_Op<string mnemonic, list<Trait> traits = []> :
 class ArmSVE_IntrOp<string mnemonic,
                     list<Trait> traits = [],
                     list<int> overloadedOperands = [],
-                    list<int> overloadedResults = []> :
+                    list<int> overloadedResults = [],
+                    int numResults = 1> :
   LLVM_IntrOpBase</*Dialect dialect=*/ArmSVE_Dialect,
                   /*string opName=*/"intr." # mnemonic,
                   /*string enumName=*/"aarch64_sve_" # !subst(".", "_", mnemonic),
                   /*list<int> overloadedResults=*/overloadedResults,
                   /*list<int> overloadedOperands=*/overloadedOperands,
                   /*list<Trait> traits=*/traits,
-                  /*int numResults=*/1>;
+                  /*int numResults=*/numResults>;

 class ArmSVE_IntrBinaryOverloadedOp<string mnemonic,
                                     list<Trait> traits = []>:

@@ -410,4 +411,24 @@ def ConvertToSvboolIntrOp :
     /*overloadedResults=*/[]>,
   Arguments<(ins SVEPredicate:$mask)>;

+// Note: This multi-vector intrinsic requires SME2.
+def ZipX2IntrOp : ArmSVE_IntrOp<"zip.x2",
+    /*traits=*/[],
+    /*overloadedOperands=*/[0],
+    /*overloadedResults=*/[],
+    /*numResults=*/2>,
+  Arguments<(ins Arg<AnyScalableVector, "v1">:$v1,
+                 Arg<AnyScalableVector, "v2">:$v2)>;
+
+// Note: This multi-vector intrinsic requires SME2.
+def ZipX4IntrOp : ArmSVE_IntrOp<"zip.x4",
+    /*traits=*/[],
+    /*overloadedOperands=*/[0],
+    /*overloadedResults=*/[],
+    /*numResults=*/4>,
+  Arguments<(ins Arg<AnyScalableVector, "v1">:$v1,
+                 Arg<AnyScalableVector, "v2">:$v2,
+                 Arg<AnyScalableVector, "v3">:$v3,
+                 Arg<AnyScalableVector, "v4">:$v4)>;
+
 #endif // ARMSVE_OPS
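Since numResults defaults to 1, every existing ArmSVE_IntrOp user is unaffected; only definitions that opt in get a multi-result signature. As a purely hypothetical illustration (not part of this commit), a second two-result SME2 intrinsic could now be declared the same way, assuming LLVM registers a matching aarch64_sve_uzp_x2 intrinsic:

  // Sketch only: an unzip analogue of ZipX2IntrOp.
  def UzpX2IntrOp : ArmSVE_IntrOp<"uzp.x2",
      /*traits=*/[],
      /*overloadedOperands=*/[0],
      /*overloadedResults=*/[],
      /*numResults=*/2>,
    Arguments<(ins Arg<AnyScalableVector, "v1">:$v1,
                   Arg<AnyScalableVector, "v2">:$v2)>;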

mlir/test/Target/LLVMIR/arm-sve.mlir

Lines changed: 42 additions & 0 deletions

@@ -314,3 +314,45 @@ llvm.func @arm_sve_convert_to_svbool(
     : (vector<[1]xi1>) -> vector<[16]xi1>
   llvm.return
 }
+
+// CHECK-LABEL: arm_sve_zip_x2(
+// CHECK-SAME:  <vscale x 16 x i8> %[[V1:[0-9]+]],
+// CHECK-SAME:  <vscale x 8 x i16> %[[V2:[0-9]+]],
+// CHECK-SAME:  <vscale x 4 x i32> %[[V3:[0-9]+]],
+// CHECK-SAME:  <vscale x 2 x i64> %[[V4:[0-9]+]])
+llvm.func @arm_sve_zip_x2(%nxv16i8: vector<[16]xi8>, %nxv8i16: vector<[8]xi16>, %nxv4i32: vector<[4]xi32>, %nxv2i64: vector<[2]xi64>) {
+  // CHECK: call { <vscale x 16 x i8>, <vscale x 16 x i8> } @llvm.aarch64.sve.zip.x2.nxv16i8(<vscale x 16 x i8> %[[V1]], <vscale x 16 x i8> %[[V1]])
+  %0 = "arm_sve.intr.zip.x2"(%nxv16i8, %nxv16i8) : (vector<[16]xi8>, vector<[16]xi8>)
+    -> !llvm.struct<(vector<[16]xi8>, vector<[16]xi8>)>
+  // CHECK: call { <vscale x 8 x i16>, <vscale x 8 x i16> } @llvm.aarch64.sve.zip.x2.nxv8i16(<vscale x 8 x i16> %[[V2]], <vscale x 8 x i16> %[[V2]])
+  %1 = "arm_sve.intr.zip.x2"(%nxv8i16, %nxv8i16) : (vector<[8]xi16>, vector<[8]xi16>)
+    -> !llvm.struct<(vector<[8]xi16>, vector<[8]xi16>)>
+  // CHECK: call { <vscale x 4 x i32>, <vscale x 4 x i32> } @llvm.aarch64.sve.zip.x2.nxv4i32(<vscale x 4 x i32> %[[V3]], <vscale x 4 x i32> %[[V3]])
+  %2 = "arm_sve.intr.zip.x2"(%nxv4i32, %nxv4i32) : (vector<[4]xi32>, vector<[4]xi32>)
+    -> !llvm.struct<(vector<[4]xi32>, vector<[4]xi32>)>
+  // CHECK: call { <vscale x 2 x i64>, <vscale x 2 x i64> } @llvm.aarch64.sve.zip.x2.nxv2i64(<vscale x 2 x i64> %[[V4]], <vscale x 2 x i64> %[[V4]])
+  %3 = "arm_sve.intr.zip.x2"(%nxv2i64, %nxv2i64) : (vector<[2]xi64>, vector<[2]xi64>)
+    -> !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>)>
+  llvm.return
+}
+
+// CHECK-LABEL: arm_sve_zip_x4(
+// CHECK-SAME:  <vscale x 16 x i8> %[[V1:[0-9]+]],
+// CHECK-SAME:  <vscale x 8 x i16> %[[V2:[0-9]+]],
+// CHECK-SAME:  <vscale x 4 x i32> %[[V3:[0-9]+]],
+// CHECK-SAME:  <vscale x 2 x i64> %[[V4:[0-9]+]])
+llvm.func @arm_sve_zip_x4(%nxv16i8: vector<[16]xi8>, %nxv8i16: vector<[8]xi16>, %nxv4i32: vector<[4]xi32>, %nxv2i64: vector<[2]xi64>) {
+  // CHECK: call { <vscale x 16 x i8>, <vscale x 16 x i8>, <vscale x 16 x i8>, <vscale x 16 x i8> } @llvm.aarch64.sve.zip.x4.nxv16i8(<vscale x 16 x i8> %[[V1]], <vscale x 16 x i8> %[[V1]], <vscale x 16 x i8> %[[V1]], <vscale x 16 x i8> %[[V1]])
+  %0 = "arm_sve.intr.zip.x4"(%nxv16i8, %nxv16i8, %nxv16i8, %nxv16i8) : (vector<[16]xi8>, vector<[16]xi8>, vector<[16]xi8>, vector<[16]xi8>)
+    -> !llvm.struct<(vector<[16]xi8>, vector<[16]xi8>, vector<[16]xi8>, vector<[16]xi8>)>
+  // CHECK: call { <vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i16>, <vscale x 8 x i16> } @llvm.aarch64.sve.zip.x4.nxv8i16(<vscale x 8 x i16> %[[V2]], <vscale x 8 x i16> %[[V2]], <vscale x 8 x i16> %[[V2]], <vscale x 8 x i16> %[[V2]])
+  %1 = "arm_sve.intr.zip.x4"(%nxv8i16, %nxv8i16, %nxv8i16, %nxv8i16) : (vector<[8]xi16>, vector<[8]xi16>, vector<[8]xi16>, vector<[8]xi16>)
+    -> !llvm.struct<(vector<[8]xi16>, vector<[8]xi16>, vector<[8]xi16>, vector<[8]xi16>)>
+  // CHECK: call { <vscale x 4 x i32>, <vscale x 4 x i32>, <vscale x 4 x i32>, <vscale x 4 x i32> } @llvm.aarch64.sve.zip.x4.nxv4i32(<vscale x 4 x i32> %[[V3]], <vscale x 4 x i32> %[[V3]], <vscale x 4 x i32> %[[V3]], <vscale x 4 x i32> %[[V3]])
+  %2 = "arm_sve.intr.zip.x4"(%nxv4i32, %nxv4i32, %nxv4i32, %nxv4i32) : (vector<[4]xi32>, vector<[4]xi32>, vector<[4]xi32>, vector<[4]xi32>)
+    -> !llvm.struct<(vector<[4]xi32>, vector<[4]xi32>, vector<[4]xi32>, vector<[4]xi32>)>
+  // CHECK: call { <vscale x 2 x i64>, <vscale x 2 x i64>, <vscale x 2 x i64>, <vscale x 2 x i64> } @llvm.aarch64.sve.zip.x4.nxv2i64(<vscale x 2 x i64> %[[V4]], <vscale x 2 x i64> %[[V4]], <vscale x 2 x i64> %[[V4]], <vscale x 2 x i64> %[[V4]])
+  %3 = "arm_sve.intr.zip.x4"(%nxv2i64, %nxv2i64, %nxv2i64, %nxv2i64) : (vector<[2]xi64>, vector<[2]xi64>, vector<[2]xi64>, vector<[2]xi64>)
+    -> !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>, vector<[2]xi64>, vector<[2]xi64>)>
+  llvm.return
+}
