[mlir][ArmSVE] Lower predicate-sized vector.create_masks to whilelt #95531
Conversation
This produces better/more canonical codegen than the generic LLVM lowering, which emits a pattern the backend currently does not recognize. See: llvm#81840.
@llvm/pr-subscribers-mlir-sve @llvm/pr-subscribers-mlir
Author: Benjamin Maxwell (MacDue)
Changes: This produces better/more canonical codegen than the generic LLVM lowering, which emits a pattern the backend currently does not recognize. See: #81840.
Full diff: https://github.com/llvm/llvm-project/pull/95531.diff
4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
index f2d330c98e7d6..aea55830c6607 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
+++ b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
@@ -552,4 +552,11 @@ def ZipX4IntrOp : ArmSVE_IntrOp<"zip.x4",
Arg<AnyScalableVector, "v3">:$v3,
Arg<AnyScalableVector, "v3">:$v4)>;
+def WhileLTIntrOp :
+ ArmSVE_IntrOp<"whilelt",
+ [TypeIs<"res", SVEPredicate>, Pure],
+ /*overloadedOperands=*/[0],
+ /*overloadedResults=*/[0]>,
+ Arguments<(ins I64:$base, I64:$n)>;
+
#endif // ARMSVE_OPS
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
index 387937e811ced..7facb3f6b9da0 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
@@ -140,6 +140,40 @@ using ConvertFromSvboolOpLowering =
using ZipX2OpLowering = OneToOneConvertToLLVMPattern<ZipX2Op, ZipX2IntrOp>;
using ZipX4OpLowering = OneToOneConvertToLLVMPattern<ZipX4Op, ZipX4IntrOp>;
+/// Converts `vector.create_mask` ops that match the size of an SVE predicate
+/// to the `whilelt` intrinsic. This produces more canonical codegen than the
+/// generic LLVM lowering, see https://github.com/llvm/llvm-project/issues/81840
+/// for more details. Note that we can't use the (more general)
+/// `get.active.lane.mask` intrinsic, as its semantics don't neatly map onto
+/// `vector.create_mask`: it does an unsigned comparison (whereas
+/// `create_mask` is signed), and it is UB/poison if `n` is zero (whereas
+/// `create_mask` just returns an all-false mask).
+struct PredicateCreateMaskOpLowering
+ : public ConvertOpToLLVMPattern<vector::CreateMaskOp> {
+ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+ LogicalResult
+ matchAndRewrite(vector::CreateMaskOp createMaskOp,
+ vector::CreateMaskOp::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto maskType = createMaskOp.getVectorType();
+ if (maskType.getRank() != 1 || !maskType.isScalable())
+ return failure();
+
+ // TODO: Support masks which are multiples of SVE predicates.
+ auto maskBaseSize = maskType.getDimSize(0);
+ if (maskBaseSize < 2 || maskBaseSize > 16 ||
+ !llvm::isPowerOf2_32(uint32_t(maskBaseSize)))
+ return failure();
+
+ auto loc = createMaskOp.getLoc();
+ auto zero = rewriter.create<LLVM::ZeroOp>(
+ loc, typeConverter->convertType(rewriter.getI64Type()));
+ rewriter.replaceOpWithNewOp<WhileLTIntrOp>(createMaskOp, maskType, zero,
+ adaptor.getOperands()[0]);
+ return success();
+ }
+};
+
} // namespace
/// Populate the given list with patterns that convert from ArmSVE to LLVM.
@@ -169,6 +203,9 @@ void mlir::populateArmSVELegalizeForLLVMExportPatterns(
ConvertFromSvboolOpLowering,
ZipX2OpLowering,
ZipX4OpLowering>(converter);
+ // Add predicate conversion with a high benefit as it produces much nicer code
+ // than the generic lowering.
+  patterns.add<PredicateCreateMaskOpLowering>(converter, /*benefit=*/4096);
// clang-format on
}
@@ -191,7 +228,8 @@ void mlir::configureArmSVELegalizeForExportTarget(
ConvertToSvboolIntrOp,
ConvertFromSvboolIntrOp,
ZipX2IntrOp,
- ZipX4IntrOp>();
+ ZipX4IntrOp,
+ WhileLTIntrOp>();
target.addIllegalOp<SdotOp,
SmmlaOp,
UdotOp,
diff --git a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
index 8d11c2bcaa8d5..3fc5e6e9fcc96 100644
--- a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -convert-vector-to-llvm="enable-arm-sve" -convert-func-to-llvm -reconcile-unrealized-casts -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -convert-vector-to-llvm="enable-arm-sve" -convert-func-to-llvm -cse -reconcile-unrealized-casts -split-input-file %s | FileCheck %s
func.func @arm_sve_sdot(%a: vector<[16]xi8>,
%b: vector<[16]xi8>,
@@ -211,3 +211,31 @@ func.func @arm_sve_zip_x4(
%0, %1, %2, %3 = arm_sve.zip.x4 %a, %b, %c, %d : vector<[16]xi8>
return %0, %1, %2, %3 : vector<[16]xi8>, vector<[16]xi8>, vector<[16]xi8>, vector<[16]xi8>
}
+
+// -----
+
+// CHECK-LABEL: @arm_sve_predicate_sized_create_masks(
+// CHECK-SAME: %[[INDEX:.*]]: i64
+func.func @arm_sve_predicate_sized_create_masks(%index: index) -> (vector<[2]xi1>, vector<[4]xi1>, vector<[8]xi1>, vector<[16]xi1>) {
+ // CHECK: %[[ZERO:.*]] = llvm.mlir.zero : i64
+ // CHECK: %[[P2:.*]] = "arm_sve.intr.whilelt"(%[[ZERO]], %[[INDEX]]) : (i64, i64) -> vector<[2]xi1>
+ %0 = vector.create_mask %index : vector<[2]xi1>
+ // CHECK: %[[P4:.*]] = "arm_sve.intr.whilelt"(%[[ZERO]], %[[INDEX]]) : (i64, i64) -> vector<[4]xi1>
+ %1 = vector.create_mask %index : vector<[4]xi1>
+ // CHECK: %[[P8:.*]] = "arm_sve.intr.whilelt"(%[[ZERO]], %[[INDEX]]) : (i64, i64) -> vector<[8]xi1>
+ %2 = vector.create_mask %index : vector<[8]xi1>
+ // CHECK: %[[P16:.*]] = "arm_sve.intr.whilelt"(%[[ZERO]], %[[INDEX]]) : (i64, i64) -> vector<[16]xi1>
+ %3 = vector.create_mask %index : vector<[16]xi1>
+ return %0, %1, %2, %3 : vector<[2]xi1>, vector<[4]xi1>, vector<[8]xi1>, vector<[16]xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sve_unsupported_create_masks
+func.func @arm_sve_unsupported_create_masks(%index: index) -> (vector<[1]xi1>, vector<[7]xi1>, vector<[32]xi1>) {
+ // CHECK-NOT: arm_sve.intr.whilelt
+ %0 = vector.create_mask %index : vector<[1]xi1>
+ %1 = vector.create_mask %index : vector<[7]xi1>
+ %2 = vector.create_mask %index : vector<[32]xi1>
+ return %0, %1, %2 : vector<[1]xi1>, vector<[7]xi1>, vector<[32]xi1>
+}
diff --git a/mlir/test/Target/LLVMIR/arm-sve.mlir b/mlir/test/Target/LLVMIR/arm-sve.mlir
index c7cd1b74ccdb5..34413d46b440e 100644
--- a/mlir/test/Target/LLVMIR/arm-sve.mlir
+++ b/mlir/test/Target/LLVMIR/arm-sve.mlir
@@ -356,3 +356,18 @@ llvm.func @arm_sve_zip_x4(%nxv16i8: vector<[16]xi8>, %nxv8i16: vector<[8]xi16>,
-> !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>, vector<[2]xi64>, vector<[2]xi64>)>
llvm.return
}
+
+// CHECK-LABEL: arm_sve_whilelt(
+// CHECK-SAME: i64 %[[BASE:[0-9]+]],
+// CHECK-SAME: i64 %[[N:[0-9]+]]
+llvm.func @arm_sve_whilelt(%base: i64, %n: i64) {
+  // CHECK: call <vscale x 2 x i1> @llvm.aarch64.sve.whilelt.nxv2i1.i64(i64 %[[BASE]], i64 %[[N]])
+  %1 = "arm_sve.intr.whilelt"(%base, %n) : (i64, i64) -> vector<[2]xi1>
+  // CHECK: call <vscale x 4 x i1> @llvm.aarch64.sve.whilelt.nxv4i1.i64(i64 %[[BASE]], i64 %[[N]])
+  %2 = "arm_sve.intr.whilelt"(%base, %n) : (i64, i64) -> vector<[4]xi1>
+  // CHECK: call <vscale x 8 x i1> @llvm.aarch64.sve.whilelt.nxv8i1.i64(i64 %[[BASE]], i64 %[[N]])
+  %3 = "arm_sve.intr.whilelt"(%base, %n) : (i64, i64) -> vector<[8]xi1>
+  // CHECK: call <vscale x 16 x i1> @llvm.aarch64.sve.whilelt.nxv16i1.i64(i64 %[[BASE]], i64 %[[N]])
+  %4 = "arm_sve.intr.whilelt"(%base, %n) : (i64, i64) -> vector<[16]xi1>
+ llvm.return
+}
As an example:
Currently lowers to:
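The original snippet isn't captured on this page. As a rough, hypothetical sketch (assuming a `vector.create_mask %n : vector<[4]xi1>` with the bound `n` in x0; register names are illustrative), the generic lowering builds a step vector, splats the bound, and does a signed compare, which the backend currently expands into a multi-instruction sequence along these lines:

    ptrue   p0.d                     // all-true governing predicate
    index   z0.d, #0, #1             // step vector 0, 1, 2, ...
    mov     z1.d, x0                 // splat the mask bound n
    cmpgt   p1.d, p0/z, z1.d, z0.d   // lane i active if n > i
    ...                              // plus further compares/uzp1s to narrow
                                     // the i64-element result to a .s predicate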
With this patch it lowers to:
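Again as a hypothetical sketch for the same assumed `vector<[4]xi1>` example, the `whilelt` intrinsic maps to a single instruction:

    whilelt p0.s, xzr, x0            // p0[i] active while 0 + i < n (signed)

This is the form the backend already knows how to select for predicate generation, which is the point of the patch.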
Thanks for the patch Ben, this is a nice improvement!
Left some comments, mostly minor. It also occurred to me that this is the first Vector -> ArmSVE conversion, but we're doing it during LLVM conversion. I don't think we want to prematurely introduce a vector-to-arm-sve conversion pass, but it's something to bear in mind if this grows.
It's Vector -> LLVM SVE intrinsics (not high-level ArmSVE operations). I'm not sure a general Vector -> ArmSVE conversion would make much sense, as I don't think it'd be much more expressive than the vector dialect (and there are no problems needing solving there, like tile allocation for SME). Also, with some later patches in mind, the ArmSVE ops I use have specific requirements (like …)
Nice, thanks!
What was the lowering path before? As in, we'd start with vector.create_mask
and then ...?
It's shown in the linked issue (#81840).
This produces better/more canonical codegen than the generic LLVM lowering, which emits a pattern the backend currently does not recognize. See: #81840.