
[mlir][ArmSVE] Lower predicate-sized vector.create_masks to whilelt #95531


Merged
merged 3 commits into llvm:main
Jun 17, 2024

Conversation


@MacDue MacDue commented Jun 14, 2024

This produces better/more canonical codegen than the generic LLVM lowering, which emits a pattern the backend currently does not recognize. See: #81840.


llvmbot commented Jun 14, 2024

@llvm/pr-subscribers-mlir-sve
@llvm/pr-subscribers-mlir-llvm

@llvm/pr-subscribers-mlir

Author: Benjamin Maxwell (MacDue)

Changes

This produces better/more canonical codegen than the generic LLVM lowering, which emits a pattern the backend currently does not recognize. See: #81840.
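
For reference, a minimal MLIR-level before/after sketch (adapted from the tests in this patch; value names are illustrative). A predicate-sized mask such as

%0 = vector.create_mask %n : vector<[16]xi1>

now legalizes to the SVE whilelt intrinsic rather than the generic stepvector-and-compare expansion (here %n is the index operand after conversion to i64):

%zero = llvm.mlir.zero : i64
%0 = "arm_sve.intr.whilelt"(%zero, %n) : (i64, i64) -> vector<[16]xi1>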


Full diff: https://github.com/llvm/llvm-project/pull/95531.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td (+7)
  • (modified) mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp (+39-1)
  • (modified) mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir (+29-1)
  • (modified) mlir/test/Target/LLVMIR/arm-sve.mlir (+15)
diff --git a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
index f2d330c98e7d6..aea55830c6607 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
+++ b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
@@ -552,4 +552,11 @@ def ZipX4IntrOp : ArmSVE_IntrOp<"zip.x4",
                    Arg<AnyScalableVector, "v3">:$v3,
                    Arg<AnyScalableVector, "v3">:$v4)>;
 
+def WhileLTIntrOp :
+  ArmSVE_IntrOp<"whilelt",
+    [TypeIs<"res", SVEPredicate>, Pure],
+    /*overloadedOperands=*/[0],
+    /*overloadedResults=*/[0]>,
+  Arguments<(ins I64:$base, I64:$n)>;
+
 #endif // ARMSVE_OPS
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
index 387937e811ced..7facb3f6b9da0 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
@@ -140,6 +140,40 @@ using ConvertFromSvboolOpLowering =
 using ZipX2OpLowering = OneToOneConvertToLLVMPattern<ZipX2Op, ZipX2IntrOp>;
 using ZipX4OpLowering = OneToOneConvertToLLVMPattern<ZipX4Op, ZipX4IntrOp>;
 
+/// Converts `vector.create_mask` ops that match the size of an SVE predicate
+/// to the `whilelt` intrinsic. This produces more canonical codegen than the
+/// generic LLVM lowering, see https://github.com/llvm/llvm-project/issues/81840
+/// for more details. Note that we can't use (the more general) get.active.lane.mask
+/// as its semantics don't neatly map onto `vector.create_mask`, as it does an
+/// unsigned comparison (whereas `create_mask` is signed), and is UB/poison if
+/// `n` is zero (whereas `create_mask` just returns an all-false mask).
+struct PredicateCreateMaskOpLowering
+    : public ConvertOpToLLVMPattern<vector::CreateMaskOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::CreateMaskOp createMaskOp,
+                  vector::CreateMaskOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto maskType = createMaskOp.getVectorType();
+    if (maskType.getRank() != 1 || !maskType.isScalable())
+      return failure();
+
+    // TODO: Support masks which are multiples of SVE predicates.
+    auto maskBaseSize = maskType.getDimSize(0);
+    if (maskBaseSize < 2 || maskBaseSize > 16 ||
+        !llvm::isPowerOf2_32(uint32_t(maskBaseSize)))
+      return failure();
+
+    auto loc = createMaskOp.getLoc();
+    auto zero = rewriter.create<LLVM::ZeroOp>(
+        loc, typeConverter->convertType(rewriter.getI64Type()));
+    rewriter.replaceOpWithNewOp<WhileLTIntrOp>(createMaskOp, maskType, zero,
+                                               adaptor.getOperands()[0]);
+    return success();
+  }
+};
+
 } // namespace
 
 /// Populate the given list with patterns that convert from ArmSVE to LLVM.
@@ -169,6 +203,9 @@ void mlir::populateArmSVELegalizeForLLVMExportPatterns(
                ConvertFromSvboolOpLowering,
                ZipX2OpLowering,
                ZipX4OpLowering>(converter);
+  // Add predicate conversion with a high benefit as it produces much nicer code
+  // than the generic lowering.
+  patterns.add<PredicateCreateMaskOpLowering>(converter, /*benefit=*/4096);
   // clang-format on
 }
 
@@ -191,7 +228,8 @@ void mlir::configureArmSVELegalizeForExportTarget(
                     ConvertToSvboolIntrOp,
                     ConvertFromSvboolIntrOp,
                     ZipX2IntrOp,
-                    ZipX4IntrOp>();
+                    ZipX4IntrOp,
+                    WhileLTIntrOp>();
   target.addIllegalOp<SdotOp,
                       SmmlaOp,
                       UdotOp,
diff --git a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
index 8d11c2bcaa8d5..3fc5e6e9fcc96 100644
--- a/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
+++ b/mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -convert-vector-to-llvm="enable-arm-sve" -convert-func-to-llvm -reconcile-unrealized-casts -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -convert-vector-to-llvm="enable-arm-sve" -convert-func-to-llvm -cse -reconcile-unrealized-casts -split-input-file %s | FileCheck %s
 
 func.func @arm_sve_sdot(%a: vector<[16]xi8>,
                    %b: vector<[16]xi8>,
@@ -211,3 +211,31 @@ func.func @arm_sve_zip_x4(
   %0, %1, %2, %3 = arm_sve.zip.x4 %a, %b, %c, %d : vector<[16]xi8>
   return %0, %1, %2, %3 : vector<[16]xi8>, vector<[16]xi8>, vector<[16]xi8>, vector<[16]xi8>
 }
+
+// -----
+
+// CHECK-LABEL: @arm_sve_predicate_sized_create_masks(
+// CHECK-SAME:                                        %[[INDEX:.*]]: i64
+func.func @arm_sve_predicate_sized_create_masks(%index: index) -> (vector<[2]xi1>, vector<[4]xi1>, vector<[8]xi1>, vector<[16]xi1>) {
+  // CHECK: %[[ZERO:.*]] = llvm.mlir.zero : i64
+  // CHECK: %[[P2:.*]] = "arm_sve.intr.whilelt"(%[[ZERO]], %[[INDEX]]) : (i64, i64) -> vector<[2]xi1>
+  %0 = vector.create_mask %index : vector<[2]xi1>
+  // CHECK: %[[P4:.*]] = "arm_sve.intr.whilelt"(%[[ZERO]], %[[INDEX]]) : (i64, i64) -> vector<[4]xi1>
+  %1 = vector.create_mask %index : vector<[4]xi1>
+  // CHECK: %[[P8:.*]] = "arm_sve.intr.whilelt"(%[[ZERO]], %[[INDEX]]) : (i64, i64) -> vector<[8]xi1>
+  %2 = vector.create_mask %index : vector<[8]xi1>
+  // CHECK: %[[P16:.*]] = "arm_sve.intr.whilelt"(%[[ZERO]], %[[INDEX]]) : (i64, i64) -> vector<[16]xi1>
+  %3 = vector.create_mask %index : vector<[16]xi1>
+  return %0, %1, %2, %3 : vector<[2]xi1>, vector<[4]xi1>, vector<[8]xi1>, vector<[16]xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @arm_sve_unsupported_create_masks
+func.func @arm_sve_unsupported_create_masks(%index: index) -> (vector<[1]xi1>, vector<[7]xi1>, vector<[32]xi1>)  {
+  // CHECK-NOT: arm_sve.intr.whilelt
+  %0 = vector.create_mask %index : vector<[1]xi1>
+  %1 = vector.create_mask %index : vector<[7]xi1>
+  %2 = vector.create_mask %index : vector<[32]xi1>
+  return %0, %1, %2 : vector<[1]xi1>, vector<[7]xi1>, vector<[32]xi1>
+}
diff --git a/mlir/test/Target/LLVMIR/arm-sve.mlir b/mlir/test/Target/LLVMIR/arm-sve.mlir
index c7cd1b74ccdb5..34413d46b440e 100644
--- a/mlir/test/Target/LLVMIR/arm-sve.mlir
+++ b/mlir/test/Target/LLVMIR/arm-sve.mlir
@@ -356,3 +356,18 @@ llvm.func @arm_sve_zip_x4(%nxv16i8: vector<[16]xi8>, %nxv8i16: vector<[8]xi16>,
      -> !llvm.struct<(vector<[2]xi64>, vector<[2]xi64>, vector<[2]xi64>, vector<[2]xi64>)>
   llvm.return
 }
+
+// CHECK-LABEL: arm_sve_whilelt(
+// CHECK-SAME:                  i64 %[[BASE:[0-9]+]],
+// CHECK-SAME:                  i64 %[[N:[0-9]+]]
+llvm.func @arm_sve_whilelt(%base: i64, %n: i64) {
+  // CHECK: call <vscale x 2 x i1> @llvm.aarch64.sve.whilelt.nxv2i1.i64(i64 %[[BASE]], i64 %[[N]])
+  %1 = "arm_sve.intr.whilelt"(%base, %n) : (i64, i64) -> vector<[2]xi1>
+  // CHECK: call <vscale x 4 x i1> @llvm.aarch64.sve.whilelt.nxv4i1.i64(i64 %[[BASE]], i64 %[[N]])
+  %2 = "arm_sve.intr.whilelt"(%base, %n) : (i64, i64) -> vector<[4]xi1>
+  // CHECK: call <vscale x 8 x i1> @llvm.aarch64.sve.whilelt.nxv8i1.i64(i64 %[[BASE]], i64 %[[N]])
+  %3 = "arm_sve.intr.whilelt"(%base, %n) : (i64, i64) -> vector<[8]xi1>
+  // CHECK: call <vscale x 16 x i1> @llvm.aarch64.sve.whilelt.nxv16i1.i64(i64 %[[BASE]], i64 %[[N]])
+  %4 = "arm_sve.intr.whilelt"(%base, %n) : (i64, i64) -> vector<[16]xi1>
+  llvm.return
+}


MacDue commented Jun 14, 2024

As an example:

func.func @mask(%a: index) -> vector<[16]xi1> {
  %0 = vector.create_mask %a : vector<[16]xi1>
  return %0 : vector<[16]xi1>
}

Currently lowers to:

	index	z0.s, #0, #1
	mov	z1.s, w0
	ptrue	p0.s
	mov	z2.d, z0.d
	mov	z3.d, z0.d
	cmpgt	p3.s, p0/z, z1.s, z0.s
	incw	z2.s
	incw	z3.s, all, mul #2
	mov	z4.d, z2.d
	cmpgt	p1.s, p0/z, z1.s, z3.s
	incw	z4.s, all, mul #2
	cmpgt	p2.s, p0/z, z1.s, z4.s
	cmpgt	p0.s, p0/z, z1.s, z2.s
	uzp1	p1.h, p1.h, p2.h
	uzp1	p0.h, p3.h, p0.h
	uzp1	p0.b, p0.b, p1.b
	ret

With this patch it lowers to:

	whilelt	p0.b, xzr, x0
	ret


@c-rhodes c-rhodes left a comment


Thanks for the patch Ben, this is a nice improvement!

Left some comments, mostly minor. It also occurred to me that this is the first Vector -> ArmSVE conversion, but we're doing it during LLVM conversion. I don't think we want to prematurely introduce a vector-to-arm-sve conversion pass, but it's something to bear in mind if this grows.


MacDue commented Jun 14, 2024

Left some comments, mostly minor. It also occurred to me that this is the first Vector -> ArmSVE conversion, but we're doing it during LLVM conversion.

It's Vector -> LLVM SVE intrinsics (not high-level ArmSVE operations). I'm not sure a general Vector -> ArmSVE conversion would make much sense, as I don't think it'd be much more expressive than the vector dialect (and there are no problems needing solving there, e.g. tile allocation for SME). Also, with some later patches in mind, the ArmSVE ops I use have specific requirements (like +sme or +sve2), so a general conversion pass would be fairly limited in what it could do.


@banach-space banach-space left a comment


Nice, thanks!

What was the lowering path before? As in, we'd start with vector.create_mask and then ...?


MacDue commented Jun 17, 2024

What was the lowering path before? As in, we'd start with vector.create_mask and then ...?

It's shown in the linked issue (cmp(stepvector, splat)). See #81840 for more details.
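
For context, a rough MLIR-level sketch of that generic cmp(stepvector, splat) expansion (op spellings are illustrative and not taken verbatim from the issue; assume %n, the mask bound, has already been cast to i32):

%step  = llvm.intr.experimental.stepvector : vector<[16]xi32>  // <0, 1, 2, ...>
%bound = vector.splat %n : vector<[16]xi32>                    // broadcast n to every lane
%mask  = arith.cmpi slt, %step, %bound : vector<[16]xi32>      // lane i is active iff i < n (signed)

The backend currently fails to fold that compare back into a single whilelt, which is what produces the long cmpgt/uzp1 sequence shown above.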

@MacDue MacDue merged commit 657ec73 into llvm:main Jun 17, 2024
4 of 6 checks passed
@MacDue MacDue deleted the sve_createmask branch June 17, 2024 09:28