[MLIR][ArmSVE] Add an ArmSVE dialect operation which maps to svdupq_lane #135633
Conversation
@llvm/pr-subscribers-mlir-sve @llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-llvm

Author: Momchil Velikov (momchil-velikov)

Changes: Supersedes #135356, this time using user branches.

Full diff: https://github.com/llvm/llvm-project/pull/135633.diff

4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
index cdcf4d8752e87..1a59062ccc93d 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
+++ b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
@@ -61,6 +61,13 @@ class Scalable1DVectorOfLength<int length, list<Type> elementTypes> : ShapedCont
"a 1-D scalable vector with length " # length,
"::mlir::VectorType">;
+def SVEVector : AnyTypeOf<[
+ Scalable1DVectorOfLength<2, [I64, F64]>,
+ Scalable1DVectorOfLength<4, [I32, F32]>,
+ Scalable1DVectorOfLength<8, [I16, F16, BF16]>,
+ Scalable1DVectorOfLength<16, [I8]>],
+ "an SVE vector with element size <= 64-bit">;
+
//===----------------------------------------------------------------------===//
// ArmSVE op definitions
//===----------------------------------------------------------------------===//
@@ -72,14 +79,22 @@ class ArmSVE_IntrOp<string mnemonic,
list<Trait> traits = [],
list<int> overloadedOperands = [],
list<int> overloadedResults = [],
- int numResults = 1> :
+ int numResults = 1,
+ list<int> immArgPositions = [],
+ list<string> immArgAttrNames = []> :
LLVM_IntrOpBase</*Dialect dialect=*/ArmSVE_Dialect,
/*string opName=*/"intr." # mnemonic,
/*string enumName=*/"aarch64_sve_" # !subst(".", "_", mnemonic),
/*list<int> overloadedResults=*/overloadedResults,
/*list<int> overloadedOperands=*/overloadedOperands,
/*list<Trait> traits=*/traits,
- /*int numResults=*/numResults>;
+ /*int numResults=*/numResults,
+ /*bit requiresAccessGroup=*/0,
+ /*bit requiresAliasAnalysis=*/0,
+ /*bit requiresFastmath=*/0,
+ /*bit requiresOpBundles=*/0,
+ /*list<int> immArgPositions=*/immArgPositions,
+ /*list<string> immArgAttrNames=*/immArgAttrNames>;
class ArmSVE_IntrBinaryOverloadedOp<string mnemonic,
list<Trait> traits = []>:
@@ -509,6 +524,42 @@ def ScalableMaskedUDivIOp : ScalableMaskedIOp<"masked.divi_unsigned",
def ScalableMaskedDivFOp : ScalableMaskedFOp<"masked.divf", "division">;
+def DupQLaneOp : ArmSVE_Op<"dupq_lane", [Pure, AllTypesMatch<["src", "dst"]>]> {
+ let summary = "Broadcast indexed 128-bit segment to vector";
+
+ let description = [{
+ This operation fills each 128-bit segment of a vector with the elements
+ from the indexed 128-bit segment of the source vector. If the VL is
+ 128 bits, the operation is a NOP. If the index exceeds the number of
+ 128-bit segments in the vector, the result is an all-zeroes vector.
+
+ Example:
+ ```mlir
+ // VL == 256
+ // %X = [A B C D x x x x]
+ %Y = arm_sve.dupq_lane %X[0] : vector<[4]xi32>
+ // %Y = [A B C D A B C D]
+
+ // %U = [x x x x x x x x A B C D E F G H]
+ %V = arm_sve.dupq_lane %U[1] : vector<[8]xf16>
+ // %V = [A B C D E F G H A B C D E F G H]
+ ```
+ }];
+
+ let arguments = (ins SVEVector:$src,
+ I64Attr:$lane);
+ let results = (outs SVEVector:$dst);
+
+ let builders = [
+ OpBuilder<(ins "Value":$src, "int64_t":$lane), [{
+ build($_builder, $_state, src.getType(), src, lane);
+ }]>];
+
+ let assemblyFormat = [{
+ $src `[` $lane `]` attr-dict `:` type($dst)
+ }];
+}
+
def UmmlaIntrOp :
ArmSVE_IntrBinaryOverloadedOp<"ummla">,
Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
@@ -610,4 +661,14 @@ def WhileLTIntrOp :
/*overloadedResults=*/[0]>,
Arguments<(ins I64:$base, I64:$n)>;
+def DupQLaneIntrOp : ArmSVE_IntrOp<"dupq_lane",
+ /*traits=*/[],
+ /*overloadedOperands=*/[0],
+ /*overloadedResults=*/[],
+ /*numResults=*/1,
+ /*immArgPositions*/[1],
+ /*immArgAttrNames*/["lane"]>,
+ Arguments<(ins Arg<ScalableVectorOfRank<[1]>, "v">:$v,
+ Arg<I64Attr, "lane">:$lane)>;
+
#endif // ARMSVE_OPS
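Aside: the custom builder declared on `DupQLaneOp` pairs with the `AllTypesMatch<["src", "dst"]>` trait, so callers never spell out the result type. A minimal sketch of what that looks like from C++, assuming the usual MLIR builder setup (the helper function and its name are hypothetical, not part of this patch):

```cpp
#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Value.h"

// Hypothetical helper: broadcast the lowest 128-bit segment of `src` to
// every 128-bit segment of the result. The generated build() infers the
// result type from `src` because of AllTypesMatch<["src", "dst"]>.
static mlir::Value broadcastLowSegment(mlir::OpBuilder &builder,
                                       mlir::Location loc, mlir::Value src) {
  return builder.create<mlir::arm_sve::DupQLaneOp>(loc, src, /*lane=*/0);
}
```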
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
index 2bdb640699d03..fe13ed03356b2 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
@@ -24,6 +24,8 @@ using SdotOpLowering = OneToOneConvertToLLVMPattern<SdotOp, SdotIntrOp>;
using SmmlaOpLowering = OneToOneConvertToLLVMPattern<SmmlaOp, SmmlaIntrOp>;
using UdotOpLowering = OneToOneConvertToLLVMPattern<UdotOp, UdotIntrOp>;
using UmmlaOpLowering = OneToOneConvertToLLVMPattern<UmmlaOp, UmmlaIntrOp>;
+using DupQLaneLowering =
+ OneToOneConvertToLLVMPattern<DupQLaneOp, DupQLaneIntrOp>;
using ScalableMaskedAddIOpLowering =
OneToOneConvertToLLVMPattern<ScalableMaskedAddIOp,
ScalableMaskedAddIIntrOp>;
@@ -192,6 +194,7 @@ void mlir::populateArmSVELegalizeForLLVMExportPatterns(
SmmlaOpLowering,
UdotOpLowering,
UmmlaOpLowering,
+ DupQLaneLowering,
ScalableMaskedAddIOpLowering,
ScalableMaskedAddFOpLowering,
ScalableMaskedSubIOpLowering,
@@ -219,6 +222,7 @@ void mlir::configureArmSVELegalizeForExportTarget(
SmmlaIntrOp,
UdotIntrOp,
UmmlaIntrOp,
+ DupQLaneIntrOp,
ScalableMaskedAddIIntrOp,
ScalableMaskedAddFIntrOp,
ScalableMaskedSubIIntrOp,
@@ -238,6 +242,7 @@ void mlir::configureArmSVELegalizeForExportTarget(
SmmlaOp,
UdotOp,
UmmlaOp,
+ DupQLaneOp,
ScalableMaskedAddIOp,
ScalableMaskedAddFOp,
ScalableMaskedSubIOp,
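With the pattern and the intrinsic registered, `arm_sve.dupq_lane` is lowered alongside the other ArmSVE ops by the existing legalization entry points. A minimal sketch of how those entry points are typically driven, assuming a standard dialect-conversion setup (the wrapper function is hypothetical and the exact signatures may differ from this sketch):

```cpp
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"

// Hypothetical driver: run the ArmSVE -> LLVM-intrinsic legalization over
// `root`. configureArmSVELegalizeForExportTarget marks DupQLaneIntrOp legal
// and DupQLaneOp illegal; populateArmSVELegalizeForLLVMExportPatterns
// registers DupQLaneLowering to rewrite one into the other.
static mlir::LogicalResult legalizeArmSVEForExport(mlir::Operation *root) {
  mlir::MLIRContext *ctx = root->getContext();
  mlir::LLVMTypeConverter converter(ctx);
  mlir::RewritePatternSet patterns(ctx);
  mlir::ConversionTarget target(*ctx);
  mlir::configureArmSVELegalizeForExportTarget(target);
  mlir::populateArmSVELegalizeForLLVMExportPatterns(converter, patterns);
  return mlir::applyPartialConversion(root, target, std::move(patterns));
}
```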
diff --git a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
index bdb69a95a52de..5d044517e0ea8 100644
--- a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
@@ -271,3 +271,46 @@ func.func @arm_sve_psel_mixed_predicate_types(%p0: vector<[8]xi1>, %p1: vector<[
%0 = arm_sve.psel %p0, %p1[%index] : vector<[8]xi1>, vector<[16]xi1>
return %0 : vector<[8]xi1>
}
+
+// -----
+
+// CHECK-LABEL: @arm_sve_dupq_lane(
+// CHECK-SAME: %[[A0:[a-z0-9]+]]: vector<[16]xi8>
+// CHECK-SAME: %[[A1:[a-z0-9]+]]: vector<[8]xi16>
+// CHECK-SAME: %[[A2:[a-z0-9]+]]: vector<[8]xf16>
+// CHECK-SAME: %[[A3:[a-z0-9]+]]: vector<[8]xbf16>
+// CHECK-SAME: %[[A4:[a-z0-9]+]]: vector<[4]xi32>
+// CHECK-SAME: %[[A5:[a-z0-9]+]]: vector<[4]xf32>
+// CHECK-SAME: %[[A6:[a-z0-9]+]]: vector<[2]xi64>
+// CHECK-SAME: %[[A7:[a-z0-9]+]]: vector<[2]xf64>
+// CHECK-SAME: -> !llvm.struct<(vector<[16]xi8>, vector<[8]xi16>, vector<[8]xf16>, vector<[8]xbf16>, vector<[4]xi32>, vector<[4]xf32>, vector<[2]xi64>, vector<[2]xf64>)> {
+
+// CHECK: "arm_sve.intr.dupq_lane"(%[[A0]]) <{lane = 0 : i64}> : (vector<[16]xi8>) -> vector<[16]xi8>
+// CHECK: "arm_sve.intr.dupq_lane"(%[[A1]]) <{lane = 1 : i64}> : (vector<[8]xi16>) -> vector<[8]xi16>
+// CHECK: "arm_sve.intr.dupq_lane"(%[[A2]]) <{lane = 2 : i64}> : (vector<[8]xf16>) -> vector<[8]xf16>
+// CHECK: "arm_sve.intr.dupq_lane"(%[[A3]]) <{lane = 3 : i64}> : (vector<[8]xbf16>) -> vector<[8]xbf16>
+// CHECK: "arm_sve.intr.dupq_lane"(%[[A4]]) <{lane = 4 : i64}> : (vector<[4]xi32>) -> vector<[4]xi32>
+// CHECK: "arm_sve.intr.dupq_lane"(%[[A5]]) <{lane = 5 : i64}> : (vector<[4]xf32>) -> vector<[4]xf32>
+// CHECK: "arm_sve.intr.dupq_lane"(%[[A6]]) <{lane = 6 : i64}> : (vector<[2]xi64>) -> vector<[2]xi64>
+// CHECK: "arm_sve.intr.dupq_lane"(%[[A7]]) <{lane = 7 : i64}> : (vector<[2]xf64>) -> vector<[2]xf64>
+func.func @arm_sve_dupq_lane(
+ %v16i8: vector<[16]xi8>, %v8i16: vector<[8]xi16>,
+ %v8f16: vector<[8]xf16>, %v8bf16: vector<[8]xbf16>,
+ %v4i32: vector<[4]xi32>, %v4f32: vector<[4]xf32>,
+ %v2i64: vector<[2]xi64>, %v2f64: vector<[2]xf64>)
+ -> (vector<[16]xi8>, vector<[8]xi16>, vector<[8]xf16>, vector<[8]xbf16>,
+ vector<[4]xi32>, vector<[4]xf32>, vector<[2]xi64>, vector<[2]xf64>) {
+
+ %0 = arm_sve.dupq_lane %v16i8[0] : vector<[16]xi8>
+ %1 = arm_sve.dupq_lane %v8i16[1] : vector<[8]xi16>
+ %2 = arm_sve.dupq_lane %v8f16[2] : vector<[8]xf16>
+ %3 = arm_sve.dupq_lane %v8bf16[3] : vector<[8]xbf16>
+ %4 = arm_sve.dupq_lane %v4i32[4] : vector<[4]xi32>
+ %5 = arm_sve.dupq_lane %v4f32[5] : vector<[4]xf32>
+ %6 = arm_sve.dupq_lane %v2i64[6] : vector<[2]xi64>
+ %7 = arm_sve.dupq_lane %v2f64[7] : vector<[2]xf64>
+
+ return %0, %1, %2, %3, %4, %5, %6, %7
+ : vector<[16]xi8>, vector<[8]xi16>, vector<[8]xf16>, vector<[8]xbf16>,
+ vector<[4]xi32>, vector<[4]xf32>, vector<[2]xi64>, vector<[2]xf64>
+}
diff --git a/mlir/test/Target/LLVMIR/arm-sve.mlir b/mlir/test/Target/LLVMIR/arm-sve.mlir
index ed5a1fc7ba2e4..ced59eb513b57 100644
--- a/mlir/test/Target/LLVMIR/arm-sve.mlir
+++ b/mlir/test/Target/LLVMIR/arm-sve.mlir
@@ -390,3 +390,37 @@ llvm.func @arm_sve_psel(%pn: vector<[16]xi1>, %p1: vector<[2]xi1>, %p2: vector<[
"arm_sve.intr.psel"(%pn, %p4, %index) : (vector<[16]xi1>, vector<[16]xi1>, i32) -> vector<[16]xi1>
llvm.return
}
+
+// CHECK-LABEL: @arm_sve_dupq_lane
+// CHECK-SAME: <vscale x 16 x i8> %0
+// CHECK-SAME: <vscale x 8 x i16> %1
+// CHECK-SAME: <vscale x 8 x half> %2
+// CHECK-SAME: <vscale x 8 x bfloat> %3
+// CHECK-SAME: <vscale x 4 x i32> %4
+// CHECK-SAME: <vscale x 4 x float> %5
+// CHECK-SAME: <vscale x 2 x i64> %6
+// CHECK-SAME: <vscale x 2 x double> %7
+
+
+llvm.func @arm_sve_dupq_lane(%arg0: vector<[16]xi8>, %arg1: vector<[8]xi16>,
+ %arg2: vector<[8]xf16>, %arg3: vector<[8]xbf16>,
+ %arg4: vector<[4]xi32>, %arg5: vector<[4]xf32>,
+ %arg6: vector<[2]xi64>, %arg7: vector<[2]xf64>) {
+ // CHECK: call <vscale x 16 x i8> @llvm.aarch64.sve.dupq.lane.nxv16i8(<vscale x 16 x i8> %0, i64 0)
+ %0 = "arm_sve.intr.dupq_lane"(%arg0) <{lane = 0 : i64}> : (vector<[16]xi8>) -> vector<[16]xi8>
+ // CHECK: call <vscale x 8 x i16> @llvm.aarch64.sve.dupq.lane.nxv8i16(<vscale x 8 x i16> %1, i64 1)
+ %1 = "arm_sve.intr.dupq_lane"(%arg1) <{lane = 1 : i64}> : (vector<[8]xi16>) -> vector<[8]xi16>
+ // CHECK: call <vscale x 8 x half> @llvm.aarch64.sve.dupq.lane.nxv8f16(<vscale x 8 x half> %2, i64 2)
+ %2 = "arm_sve.intr.dupq_lane"(%arg2) <{lane = 2 : i64}> : (vector<[8]xf16>) -> vector<[8]xf16>
+ // CHECK: call <vscale x 8 x bfloat> @llvm.aarch64.sve.dupq.lane.nxv8bf16(<vscale x 8 x bfloat> %3, i64 3)
+ %3 = "arm_sve.intr.dupq_lane"(%arg3) <{lane = 3 : i64}> : (vector<[8]xbf16>) -> vector<[8]xbf16>
+ // CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.dupq.lane.nxv4i32(<vscale x 4 x i32> %4, i64 4)
+ %4 = "arm_sve.intr.dupq_lane"(%arg4) <{lane = 4 : i64}> : (vector<[4]xi32>) -> vector<[4]xi32>
+ // CHECK: call <vscale x 4 x float> @llvm.aarch64.sve.dupq.lane.nxv4f32(<vscale x 4 x float> %5, i64 5)
+ %5 = "arm_sve.intr.dupq_lane"(%arg5) <{lane = 5 : i64}> : (vector<[4]xf32>) -> vector<[4]xf32>
+ // CHECK: call <vscale x 2 x i64> @llvm.aarch64.sve.dupq.lane.nxv2i64(<vscale x 2 x i64> %6, i64 6)
+ %6 = "arm_sve.intr.dupq_lane"(%arg6) <{lane = 6 : i64}> : (vector<[2]xi64>) -> vector<[2]xi64>
+ // CHECK: call <vscale x 2 x double> @llvm.aarch64.sve.dupq.lane.nxv2f64(<vscale x 2 x double> %7, i64 7)
+ %7 = "arm_sve.intr.dupq_lane"(%arg7) <{lane = 7 : i64}> : (vector<[2]xf64>) -> vector<[2]xf64>
+ llvm.return
+}
LGTM % some minor comments. Thanks!
Force-pushed: 02e68ee → 7e72ad9 → dbf1aa0 → 7c36d36