-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][ArmSVE] Add arm_sve.psel
operation
#95764
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
This adds a new operation for the SME/SVE2 psel instruction. This allows selecting a predicate based on a bit within another predicate, essentially allowing for 2-D predication. Informally the semantics are: ```mlir %pd = arm_sve.psel %p1, %p2[%index] : vector<[4]xi1>, vector<[8]xi1> ``` => ``` if p2[index % num_elements(p2)] == 1: pd = p1 : type(p1) else: pd = all-false : type(p1) ```
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-sve Author: Benjamin Maxwell (MacDue) ChangesThis adds a new operation for the SME/SVE2 psel instruction. This allows selecting a predicate based on a bit within another predicate, essentially allowing for 2-D predication. Informally the semantics are: %pd = arm_sve.psel %p1, %p2[%index] : vector<[4]xi1>, vector<[8]xi1> =>
Full diff: https://github.com/llvm/llvm-project/pull/95764.diff 6 Files Affected:
diff --git a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
index aea55830c6607..5b98b21720ada 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
+++ b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
@@ -442,6 +442,43 @@ def ZipX4Op : ArmSVE_Op<"zip.x4", [
}];
}
+def PselOp : ArmSVE_Op<"psel", [
+ Pure,
+ AllTypesMatch<["p1", "result"]>,
+]> {
+ let summary = "Predicate select";
+
+ let description = [{
+ This operation returns the input predicate `p1` or an all-false predicate
+ based on the bit at `p2[index]`. Informally the semantics are:
+ ```
+ if p2[index % num_elements(p2)] == 1:
+ return p1 : type(p1)
+ return all-false : type(p1)
+ ```
+
+ Example:
+ ```mlir
+ // Note: p1 and p2 can have different sizes.
+ %pd = arm_sve.psel %p1, %p2[%index] : vector<[4]xi1>, vector<[8]xi1>
+ ```
+
+ Note: This requires SME or SVE2 (`+sme` or `+sve2` in LLVM target features).
+ }];
+
+ let arguments = (ins SVEPredicate:$p1, SVEPredicate:$p2, Index:$index);
+ let results = (outs SVEPredicate:$result);
+
+ let builders = [
+ OpBuilder<(ins "Value":$p1, "Value":$p2, "Value":$index), [{
+ build($_builder, $_state, p1.getType(), p1, p2, index);
+ }]>];
+
+ let assemblyFormat = [{
+ $p1 `,` $p2 `[` $index `]` attr-dict `:` type($p1) `,` type($p2)
+ }];
+}
+
def ScalableMaskedAddIOp : ScalableMaskedIOp<"masked.addi", "addition",
[Commutative]>;
@@ -552,6 +589,14 @@ def ZipX4IntrOp : ArmSVE_IntrOp<"zip.x4",
Arg<AnyScalableVector, "v3">:$v3,
Arg<AnyScalableVector, "v3">:$v4)>;
+// Note: This intrinsic requires SME or SVE2.
+def PselIntrOp : ArmSVE_IntrOp<"psel",
+ /*traits=*/[Pure, TypeIs<"res", SVBool>],
+ /*overloadedOperands=*/[1]>,
+ Arguments<(ins Arg<SVBool, "p1">:$p1,
+ Arg<SVEPredicate, "p2">:$p2,
+ Arg<I32, "index">:$index)>;
+
def WhileLTIntrOp :
ArmSVE_IntrOp<"whilelt",
[TypeIs<"res", SVEPredicate>, Pure],
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
index ed4f4cc7f0718..10f39a0855f5f 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
@@ -140,6 +140,28 @@ using ConvertFromSvboolOpLowering =
using ZipX2OpLowering = OneToOneConvertToLLVMPattern<ZipX2Op, ZipX2IntrOp>;
using ZipX4OpLowering = OneToOneConvertToLLVMPattern<ZipX4Op, ZipX4IntrOp>;
+/// Lower `arm_sve.psel` to LLVM intrinsics. This is almost a 1-to-1 conversion
+/// but first input (P1) and result predicates need conversion to/from svbool.
+struct PselOpLowering : public ConvertOpToLLVMPattern<PselOp> {
+ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+ LogicalResult
+ matchAndRewrite(PselOp pselOp, PselOp::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto svboolType = VectorType::get(16, rewriter.getI1Type(), true);
+ auto loc = pselOp.getLoc();
+ auto svboolP1 = rewriter.create<ConvertToSvboolIntrOp>(loc, svboolType,
+ adaptor.getP1());
+ auto indexI32 = rewriter.create<arith::IndexCastOp>(
+ loc, rewriter.getI32Type(), pselOp.getIndex());
+ auto pselIntr = rewriter.create<PselIntrOp>(loc, svboolType, svboolP1,
+ pselOp.getP2(), indexI32);
+ rewriter.replaceOpWithNewOp<ConvertFromSvboolIntrOp>(
+ pselOp, adaptor.getP1().getType(), pselIntr);
+ return success();
+ }
+};
+
/// Converts `vector.create_mask` ops that match the size of an SVE predicate
/// to the `whilelt` intrinsic. This produces more canonical codegen than the
/// generic LLVM lowering, see https://github.com/llvm/llvm-project/issues/81840
@@ -202,7 +224,8 @@ void mlir::populateArmSVELegalizeForLLVMExportPatterns(
ConvertToSvboolOpLowering,
ConvertFromSvboolOpLowering,
ZipX2OpLowering,
- ZipX4OpLowering>(converter);
+ ZipX4OpLowering,
+ PselOpLowering>(converter);
// Add vector.create_mask conversion with a high benefit as it produces much
// nicer code than the generic lowering.
patterns.add<CreateMaskOpLowering>(converter, /*benefit=*/4096);
@@ -229,6 +252,7 @@ void mlir::configureArmSVELegalizeForExportTarget(
ConvertFromSvboolIntrOp,
ZipX2IntrOp,
ZipX4IntrOp,
+ PselIntrOp,
WhileLTIntrOp>();
target.addIllegalOp<SdotOp,
SmmlaOp,
diff --git a/mlir/test/Dialect/ArmSVE/invalid.mlir b/mlir/test/Dialect/ArmSVE/invalid.mlir
index 1258d3532c049..27b19f4321f8d 100644
--- a/mlir/test/Dialect/ArmSVE/invalid.mlir
+++ b/mlir/test/Dialect/ArmSVE/invalid.mlir
@@ -64,3 +64,11 @@ func.func @arm_sve_zip_x4_bad_vector_type(%a : vector<[5]xf64>) {
arm_sve.zip.x4 %a, %a, %a, %a : vector<[5]xf64>
return
}
+
+// -----
+
+func.func @arm_sve_psel_bad_vector_type(%a : vector<[7]xi1>, %index: index) {
+ // expected-error@+1 {{op operand #0 must be of ranks 1scalable vector of 1-bit signless integer values of length 16/8/4/2/1, but got 'vector<[7]xi1>'}}
+ arm_sve.psel %a, %a[%index] : vector<[7]xi1>, vector<[7]xi1>
+ return
+}
diff --git a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
index 3fc5e6e9fcc96..ef792fcf988ce 100644
--- a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
@@ -239,3 +239,35 @@ func.func @arm_sve_unsupported_create_masks(%index: index) -> (vector<[1]xi1>, v
%2 = vector.create_mask %index : vector<[32]xi1>
return %0, %1, %2 : vector<[1]xi1>, vector<[7]xi1>, vector<[32]xi1>
}
+
+// -----
+
+// CHECK-LABEL: @arm_sve_psel_matching_predicate_types(
+// CHECK-SAME: %[[P0:[a-z0-9]+]]: vector<[4]xi1>,
+// CHECK-SAME: %[[P1:[a-z0-9]+]]: vector<[4]xi1>,
+// CHECK-SAME: %[[INDEX:[a-z0-9]+]]: i64
+func.func @arm_sve_psel_matching_predicate_types(%a: vector<[4]xi1>, %b: vector<[4]xi1>, %index: index) -> vector<[4]xi1>
+{
+ // CHECK-DAG: %[[INDEX_I32:.*]] = llvm.trunc %[[INDEX]] : i64 to i32
+ // CHECK-DAG: %[[P0_IN:.*]] = "arm_sve.intr.convert.to.svbool"(%[[P0]]) : (vector<[4]xi1>) -> vector<[16]xi1>
+ // CHECK-NEXT: %[[PSEL:.*]] = "arm_sve.intr.psel"(%[[P0_IN]], %[[P1]], %[[INDEX_I32]]) : (vector<[16]xi1>, vector<[4]xi1>, i32) -> vector<[16]xi1>
+ // CHECK-NEXT: %[[RES:.*]] = "arm_sve.intr.convert.from.svbool"(%[[PSEL]]) : (vector<[16]xi1>) -> vector<[4]xi1>
+ %0 = arm_sve.psel %a, %b[%index] : vector<[4]xi1>, vector<[4]xi1>
+ return %0 : vector<[4]xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sve_psel_mixed_predicate_types(
+// CHECK-SAME: %[[P0:[a-z0-9]+]]: vector<[8]xi1>,
+// CHECK-SAME: %[[P1:[a-z0-9]+]]: vector<[16]xi1>,
+// CHECK-SAME: %[[INDEX:[a-z0-9]+]]: i64
+func.func @arm_sve_psel_mixed_predicate_types(%a: vector<[8]xi1>, %b: vector<[16]xi1>, %index: index) -> vector<[8]xi1>
+{
+ // CHECK-DAG: %[[INDEX_I32:.*]] = llvm.trunc %[[INDEX]] : i64 to i32
+ // CHECK-DAG: %[[P0_IN:.*]] = "arm_sve.intr.convert.to.svbool"(%[[P0]]) : (vector<[8]xi1>) -> vector<[16]xi1>
+ // CHECK-NEXT: %[[PSEL:.*]] = "arm_sve.intr.psel"(%[[P0_IN]], %[[P1]], %[[INDEX_I32]]) : (vector<[16]xi1>, vector<[16]xi1>, i32) -> vector<[16]xi1>
+ // CHECK-NEXT: %[[RES:.*]] = "arm_sve.intr.convert.from.svbool"(%[[PSEL]]) : (vector<[16]xi1>) -> vector<[8]xi1>
+ %0 = arm_sve.psel %a, %b[%index] : vector<[8]xi1>, vector<[16]xi1>
+ return %0 : vector<[8]xi1>
+}
diff --git a/mlir/test/Dialect/ArmSVE/roundtrip.mlir b/mlir/test/Dialect/ArmSVE/roundtrip.mlir
index f7b79aa2f275c..0f0c5a8575772 100644
--- a/mlir/test/Dialect/ArmSVE/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSVE/roundtrip.mlir
@@ -225,3 +225,32 @@ func.func @arm_sve_zip_x4(
%a8, %b8, %c8, %d8 = arm_sve.zip.x4 %v8, %v8, %v8, %v8 : vector<[16]xi8>
return
}
+
+// -----
+
+func.func @arm_sve_psel(
+ %p0: vector<[2]xi1>,
+ %p1: vector<[4]xi1>,
+ %p2: vector<[8]xi1>,
+ %p3: vector<[16]xi1>,
+ %index: index
+) {
+ // CHECK: arm_sve.psel %{{.*}}, %{{.*}}[%{{.*}}] : vector<[2]xi1>, vector<[2]xi1>
+ %0 = arm_sve.psel %p0, %p0[%index] : vector<[2]xi1>, vector<[2]xi1>
+ // CHECK: arm_sve.psel %{{.*}}, %{{.*}}[%{{.*}}] : vector<[4]xi1>, vector<[4]xi1>
+ %1 = arm_sve.psel %p1, %p1[%index] : vector<[4]xi1>, vector<[4]xi1>
+ // CHECK: arm_sve.psel %{{.*}}, %{{.*}}[%{{.*}}] : vector<[8]xi1>, vector<[8]xi1>
+ %2 = arm_sve.psel %p2, %p2[%index] : vector<[8]xi1>, vector<[8]xi1>
+ // CHECK: arm_sve.psel %{{.*}}, %{{.*}}[%{{.*}}] : vector<[16]xi1>, vector<[16]xi1>
+ %3 = arm_sve.psel %p3, %p3[%index] : vector<[16]xi1>, vector<[16]xi1>
+ /// Some mixed predicate type examples:
+ // CHECK: arm_sve.psel %{{.*}}, %{{.*}}[%{{.*}}] : vector<[2]xi1>, vector<[4]xi1>
+ %4 = arm_sve.psel %p0, %p1[%index] : vector<[2]xi1>, vector<[4]xi1>
+ // CHECK: arm_sve.psel %{{.*}}, %{{.*}}[%{{.*}}] : vector<[4]xi1>, vector<[8]xi1>
+ %5 = arm_sve.psel %p1, %p2[%index] : vector<[4]xi1>, vector<[8]xi1>
+ // CHECK: arm_sve.psel %{{.*}}, %{{.*}}[%{{.*}}] : vector<[8]xi1>, vector<[16]xi1>
+ %6 = arm_sve.psel %p2, %p3[%index] : vector<[8]xi1>, vector<[16]xi1>
+ // CHECK: arm_sve.psel %{{.*}}, %{{.*}}[%{{.*}}] : vector<[16]xi1>, vector<[2]xi1>
+ %7 = arm_sve.psel %p3, %p0[%index] : vector<[16]xi1>, vector<[2]xi1>
+ return
+}
diff --git a/mlir/test/Target/LLVMIR/arm-sve.mlir b/mlir/test/Target/LLVMIR/arm-sve.mlir
index 34413d46b440e..ed5a1fc7ba2e4 100644
--- a/mlir/test/Target/LLVMIR/arm-sve.mlir
+++ b/mlir/test/Target/LLVMIR/arm-sve.mlir
@@ -371,3 +371,22 @@ llvm.func @arm_sve_whilelt(%base: i64, %n: i64) {
%4 = "arm_sve.intr.whilelt"(%base, %n) : (i64, i64) -> vector<[16]xi1>
llvm.return
}
+
+// CHECK-LABEL: arm_sve_psel(
+// CHECK-SAME: <vscale x 16 x i1> %[[PN:[0-9]+]],
+// CHECK-SAME: <vscale x 2 x i1> %[[P1:[0-9]+]],
+// CHECK-SAME: <vscale x 4 x i1> %[[P2:[0-9]+]],
+// CHECK-SAME: <vscale x 8 x i1> %[[P3:[0-9]+]],
+// CHECK-SAME: <vscale x 16 x i1> %[[P4:[0-9]+]],
+// CHECK-SAME: i32 %[[INDEX:[0-9]+]])
+llvm.func @arm_sve_psel(%pn: vector<[16]xi1>, %p1: vector<[2]xi1>, %p2: vector<[4]xi1>, %p3: vector<[8]xi1>, %p4: vector<[16]xi1>, %index: i32) {
+ // CHECK: call <vscale x 16 x i1> @llvm.aarch64.sve.psel.nxv2i1(<vscale x 16 x i1> %[[PN]], <vscale x 2 x i1> %[[P1]], i32 %[[INDEX]])
+ "arm_sve.intr.psel"(%pn, %p1, %index) : (vector<[16]xi1>, vector<[2]xi1>, i32) -> vector<[16]xi1>
+ // CHECK: call <vscale x 16 x i1> @llvm.aarch64.sve.psel.nxv4i1(<vscale x 16 x i1> %[[PN]], <vscale x 4 x i1> %[[P2]], i32 %[[INDEX]])
+ "arm_sve.intr.psel"(%pn, %p2, %index) : (vector<[16]xi1>, vector<[4]xi1>, i32) -> vector<[16]xi1>
+ // CHECK: call <vscale x 16 x i1> @llvm.aarch64.sve.psel.nxv8i1(<vscale x 16 x i1> %[[PN]], <vscale x 8 x i1> %[[P3]], i32 %[[INDEX]])
+ "arm_sve.intr.psel"(%pn, %p3, %index) : (vector<[16]xi1>, vector<[8]xi1>, i32) -> vector<[16]xi1>
+ // CHECK: call <vscale x 16 x i1> @llvm.aarch64.sve.psel.nxv16i1(<vscale x 16 x i1> %[[PN]], <vscale x 16 x i1> %[[P4]], i32 %[[INDEX]])
+ "arm_sve.intr.psel"(%pn, %p4, %index) : (vector<[16]xi1>, vector<[16]xi1>, i32) -> vector<[16]xi1>
+ llvm.return
+}
|
Co-authored-by: Cullen Rhodes <[email protected]>
%pd = arm_sve.psel %p1, %p2[%index] : vector<[4]xi1>, vector<[8]xi1> | ||
``` | ||
|
||
Note: This requires SME or SVE2.1 (`+sme` or `+sve2p1` in LLVM target features). |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note: This requires SME or SVE2.1 (`+sme` or `+sve2p1` in LLVM target features). | |
Note: This requires SME or SVE2.1 (`+sme` or `+sve2p1` in LLVM target features when lowering to assembly and/or machine code). |
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This would be inconsistent with other ArmSVE ops and is not really true. Feature flags are used much earlier to check things within IREE, for example.
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/116/builds/221 Here is the relevant piece of the build log for the reference:
|
Failure seems to be random CI flake/timeout issue. |
This adds a new operation for the SME/SVE2.1 psel instruction. This allows selecting a predicate based on a bit within another predicate, essentially allowing for 2-D predication. Informally, the semantics are: ```mlir %pd = arm_sve.psel %p1, %p2[%index] : vector<[4]xi1>, vector<[8]xi1> ``` => ``` if p2[index % num_elements(p2)] == 1: pd = p1 : type(p1) else: pd = all-false : type(p1) ```
This adds a new operation for the SME/SVE2.1 psel instruction. This allows selecting a predicate based on a bit within another predicate, essentially allowing for 2-D predication. Informally, the semantics are:
=>