[AArch64] Add custom lowering of nxv32i1 get.active.lane.mask nodes #141969

Merged · 5 commits · Jun 6, 2025 · Changes shown from all commits
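In IR terms, this patch targets nxv32i1 lane masks like the one in the new test added below; with SVE2p1, or with SME2 in streaming mode, the mask is now produced by a single paired whilelo instruction rather than being expanded. A minimal example, copied from the test file in this diff:

; From test_2x16bit_mask_with_32bit_index_and_trip_count below: this now
; lowers to "whilelo { p0.b, p1.b }, x9, x8" on SVE2p1 / streaming SME2.
%r = call <vscale x 32 x i1> @llvm.get.active.lane.mask.nxv32i1.i32(i32 %i, i32 %n)
%v0 = call <vscale x 16 x i1> @llvm.vector.extract.nxv16i1.nxv32i1.i64(<vscale x 32 x i1> %r, i64 0)
%v1 = call <vscale x 16 x i1> @llvm.vector.extract.nxv16i1.nxv32i1.i64(<vscale x 32 x i1> %r, i64 16)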
40 changes: 39 additions & 1 deletion llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1501,6 +1501,10 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::GET_ACTIVE_LANE_MASK, VT, Legal);
}

if (Subtarget->hasSVE2p1() ||
(Subtarget->hasSME2() && Subtarget->isStreaming()))
setOperationAction(ISD::GET_ACTIVE_LANE_MASK, MVT::nxv32i1, Custom);

for (auto VT : {MVT::v16i8, MVT::v8i8, MVT::v4i16, MVT::v2i32})
setOperationAction(ISD::GET_ACTIVE_LANE_MASK, VT, Custom);
}
@@ -18165,7 +18169,7 @@ performActiveLaneMaskCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
/*IsEqual=*/false))
return While;

if (!ST->hasSVE2p1())
if (!ST->hasSVE2p1() && !(ST->hasSME2() && ST->isStreaming()))
return SDValue();

if (!N->hasNUsesOfValue(2, 0))
@@ -27328,6 +27332,37 @@ void AArch64TargetLowering::ReplaceExtractSubVectorResults(
Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, VT, Half));
}

void AArch64TargetLowering::ReplaceGetActiveLaneMaskResults(
Collaborator: This should protect the setOperationAction() call and then be an assert here, that way the function is only called when the necessary instructions are available.

Do you mind also extending the PR to cover Subtarget.hasSME2() && Subtarget.isStreaming()?

SDNode *N, SmallVectorImpl<SDValue> &Results, SelectionDAG &DAG) const {
assert((Subtarget->hasSVE2p1() ||
(Subtarget->hasSME2() && Subtarget->isStreaming())) &&
"Custom lower of get.active.lane.mask missing required feature.");

assert(N->getValueType(0) == MVT::nxv32i1 &&
"Unexpected result type for get.active.lane.mask");
Contributor: Is it worth having a test with an i128 index type to test the trunc case?

Contributor (author): Looking at this again, I think there should also be an assert for the operand types here. The reason is that shouldExpandGetActiveLaneMask returns true when the operand types are bigger than i64, so we should never reach this function if Idx/TC are something like i128.
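For illustration, the case being discussed would look like the following in IR (a hypothetical example, not a test in this PR; the function name is made up). Because shouldExpandGetActiveLaneMask returns true for operand types wider than i64, such a call is expanded generically and never reaches ReplaceGetActiveLaneMaskResults, which is what makes the operand-type assert below safe:

; Hypothetical: i128 index and trip count. Operand types wider than i64
; take the generic expand path, so the custom lowering below never sees them.
define <vscale x 32 x i1> @lane_mask_i128_example(i128 %i, i128 %n) {
  %m = call <vscale x 32 x i1> @llvm.get.active.lane.mask.nxv32i1.i128(i128 %i, i128 %n)
  ret <vscale x 32 x i1> %m
}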


SDLoc DL(N);
SDValue Idx = N->getOperand(0);
SDValue TC = N->getOperand(1);
Contributor: It looks like this code is expecting the result type to be <vscale x 32 x i1>. Is it worth adding an assert for this?


assert(Idx.getValueType().getFixedSizeInBits() <= 64 &&
"Unexpected operand type for get.active.lane.mask");

if (Idx.getValueType() != MVT::i64) {
Idx = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i64, Idx);
TC = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i64, TC);
}

SDValue ID =
DAG.getTargetConstant(Intrinsic::aarch64_sve_whilelo_x2, DL, MVT::i64);
EVT HalfVT = N->getValueType(0).getHalfNumVectorElementsVT(*DAG.getContext());
auto WideMask =
DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, {HalfVT, HalfVT}, {ID, Idx, TC});

Results.push_back(DAG.getNode(ISD::CONCAT_VECTORS, DL, N->getValueType(0),
{WideMask.getValue(0), WideMask.getValue(1)}));
}

// Create an even/odd pair of X registers holding integer value V.
static SDValue createGPRPairNode(SelectionDAG &DAG, SDValue V) {
SDLoc dl(V.getNode());
@@ -27714,6 +27749,9 @@ void AArch64TargetLowering::ReplaceNodeResults(
// CONCAT_VECTORS -- but delegate to common code for result type
// legalisation
return;
case ISD::GET_ACTIVE_LANE_MASK:
ReplaceGetActiveLaneMaskResults(N, Results, DAG);
return;
case ISD::INTRINSIC_WO_CHAIN: {
EVT VT = N->getValueType(0);

3 changes: 3 additions & 0 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -1318,6 +1318,9 @@ class AArch64TargetLowering : public TargetLowering {
void ReplaceExtractSubVectorResults(SDNode *N,
SmallVectorImpl<SDValue> &Results,
SelectionDAG &DAG) const;
void ReplaceGetActiveLaneMaskResults(SDNode *N,
SmallVectorImpl<SDValue> &Results,
SelectionDAG &DAG) const;

bool shouldNormalizeToSelectSequence(LLVMContext &, EVT) const override;

137 changes: 107 additions & 30 deletions llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll
@@ -1,6 +1,7 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4
; RUN: llc -mattr=+sve < %s | FileCheck %s -check-prefix CHECK-SVE
; RUN: llc -mattr=+sve2p1 < %s | FileCheck %s -check-prefix CHECK-SVE2p1
; RUN: llc -mattr=+sve2p1 < %s | FileCheck %s -check-prefix CHECK-SVE2p1-SME2 -check-prefix CHECK-SVE2p1
; RUN: llc -mattr=+sve -mattr=+sme2 -force-streaming < %s | FileCheck %s -check-prefix CHECK-SVE2p1-SME2 -check-prefix CHECK-SME2
Collaborator: Is -mattr=+sve required for the SME2 test?

Contributor (author): It shouldn't be needed for the get_active_lane_mask, but without SVE the tests fail with "Don't know how to legalize this scalable vector type" because of the extract_subvectors.

Contributor (author): To follow up on this, the assert I mentioned above was only happening when I did not also pass the -force-streaming flag. Dropping only -mattr=+sve just results in poor codegen because the get.active.lane.mask is expanded.

target triple = "aarch64-linux"

; Test combining of getActiveLaneMask with a pair of extract_vector operations.
@@ -13,12 +14,12 @@ define void @test_2x8bit_mask_with_32bit_index_and_trip_count(i32 %i, i32 %n) #0
; CHECK-SVE-NEXT: punpkhi p1.h, p1.b
; CHECK-SVE-NEXT: b use
;
; CHECK-SVE2p1-LABEL: test_2x8bit_mask_with_32bit_index_and_trip_count:
; CHECK-SVE2p1: // %bb.0:
; CHECK-SVE2p1-NEXT: mov w8, w1
; CHECK-SVE2p1-NEXT: mov w9, w0
; CHECK-SVE2p1-NEXT: whilelo { p0.h, p1.h }, x9, x8
; CHECK-SVE2p1-NEXT: b use
; CHECK-SVE2p1-SME2-LABEL: test_2x8bit_mask_with_32bit_index_and_trip_count:
; CHECK-SVE2p1-SME2: // %bb.0:
; CHECK-SVE2p1-SME2-NEXT: mov w8, w1
; CHECK-SVE2p1-SME2-NEXT: mov w9, w0
; CHECK-SVE2p1-SME2-NEXT: whilelo { p0.h, p1.h }, x9, x8
; CHECK-SVE2p1-SME2-NEXT: b use
%r = call <vscale x 16 x i1> @llvm.get.active.lane.mask.nxv16i1.i32(i32 %i, i32 %n)
%v0 = call <vscale x 8 x i1> @llvm.vector.extract.nxv8i1.nxv16i1.i64(<vscale x 16 x i1> %r, i64 0)
%v1 = call <vscale x 8 x i1> @llvm.vector.extract.nxv8i1.nxv16i1.i64(<vscale x 16 x i1> %r, i64 8)
@@ -34,10 +35,10 @@ define void @test_2x8bit_mask_with_64bit_index_and_trip_count(i64 %i, i64 %n) #0
; CHECK-SVE-NEXT: punpkhi p1.h, p1.b
; CHECK-SVE-NEXT: b use
;
; CHECK-SVE2p1-LABEL: test_2x8bit_mask_with_64bit_index_and_trip_count:
; CHECK-SVE2p1: // %bb.0:
; CHECK-SVE2p1-NEXT: whilelo { p0.h, p1.h }, x0, x1
; CHECK-SVE2p1-NEXT: b use
; CHECK-SVE2p1-SME2-LABEL: test_2x8bit_mask_with_64bit_index_and_trip_count:
; CHECK-SVE2p1-SME2: // %bb.0:
; CHECK-SVE2p1-SME2-NEXT: whilelo { p0.h, p1.h }, x0, x1
; CHECK-SVE2p1-SME2-NEXT: b use
%r = call <vscale x 16 x i1> @llvm.get.active.lane.mask.nxv16i1.i64(i64 %i, i64 %n)
%v0 = call <vscale x 8 x i1> @llvm.vector.extract.nxv8i1.nxv16i1.i64(<vscale x 16 x i1> %r, i64 0)
%v1 = call <vscale x 8 x i1> @llvm.vector.extract.nxv8i1.nxv16i1.i64(<vscale x 16 x i1> %r, i64 8)
@@ -53,12 +54,12 @@ define void @test_edge_case_2x1bit_mask(i64 %i, i64 %n) #0 {
; CHECK-SVE-NEXT: punpkhi p1.h, p1.b
; CHECK-SVE-NEXT: b use
;
; CHECK-SVE2p1-LABEL: test_edge_case_2x1bit_mask:
; CHECK-SVE2p1: // %bb.0:
; CHECK-SVE2p1-NEXT: whilelo p1.d, x0, x1
; CHECK-SVE2p1-NEXT: punpklo p0.h, p1.b
; CHECK-SVE2p1-NEXT: punpkhi p1.h, p1.b
; CHECK-SVE2p1-NEXT: b use
; CHECK-SVE2p1-SME2-LABEL: test_edge_case_2x1bit_mask:
; CHECK-SVE2p1-SME2: // %bb.0:
; CHECK-SVE2p1-SME2-NEXT: whilelo p1.d, x0, x1
; CHECK-SVE2p1-SME2-NEXT: punpklo p0.h, p1.b
; CHECK-SVE2p1-SME2-NEXT: punpkhi p1.h, p1.b
; CHECK-SVE2p1-SME2-NEXT: b use
%r = call <vscale x 2 x i1> @llvm.get.active.lane.mask.nxv2i1.i64(i64 %i, i64 %n)
%v0 = call <vscale x 1 x i1> @llvm.vector.extract.nxv1i1.nxv2i1.i64(<vscale x 2 x i1> %r, i64 0)
%v1 = call <vscale x 1 x i1> @llvm.vector.extract.nxv1i1.nxv2i1.i64(<vscale x 2 x i1> %r, i64 1)
@@ -74,10 +75,10 @@ define void @test_boring_case_2x2bit_mask(i64 %i, i64 %n) #0 {
; CHECK-SVE-NEXT: punpkhi p1.h, p1.b
; CHECK-SVE-NEXT: b use
;
; CHECK-SVE2p1-LABEL: test_boring_case_2x2bit_mask:
; CHECK-SVE2p1: // %bb.0:
; CHECK-SVE2p1-NEXT: whilelo { p0.d, p1.d }, x0, x1
; CHECK-SVE2p1-NEXT: b use
; CHECK-SVE2p1-SME2-LABEL: test_boring_case_2x2bit_mask:
; CHECK-SVE2p1-SME2: // %bb.0:
; CHECK-SVE2p1-SME2-NEXT: whilelo { p0.d, p1.d }, x0, x1
; CHECK-SVE2p1-SME2-NEXT: b use
%r = call <vscale x 4 x i1> @llvm.get.active.lane.mask.nxv4i1.i64(i64 %i, i64 %n)
%v0 = call <vscale x 2 x i1> @llvm.vector.extract.nxv2i1.nxv4i1.i64(<vscale x 4 x i1> %r, i64 0)
%v1 = call <vscale x 2 x i1> @llvm.vector.extract.nxv2i1.nxv4i1.i64(<vscale x 4 x i1> %r, i64 2)
@@ -96,22 +97,22 @@ define void @test_partial_extract(i64 %i, i64 %n) #0 {
; CHECK-SVE-NEXT: punpklo p1.h, p2.b
; CHECK-SVE-NEXT: b use
;
; CHECK-SVE2p1-LABEL: test_partial_extract:
; CHECK-SVE2p1: // %bb.0:
; CHECK-SVE2p1-NEXT: whilelo p0.h, x0, x1
; CHECK-SVE2p1-NEXT: punpklo p1.h, p0.b
; CHECK-SVE2p1-NEXT: punpkhi p2.h, p0.b
; CHECK-SVE2p1-NEXT: punpklo p0.h, p1.b
; CHECK-SVE2p1-NEXT: punpklo p1.h, p2.b
; CHECK-SVE2p1-NEXT: b use
; CHECK-SVE2p1-SME2-LABEL: test_partial_extract:
; CHECK-SVE2p1-SME2: // %bb.0:
; CHECK-SVE2p1-SME2-NEXT: whilelo p0.h, x0, x1
; CHECK-SVE2p1-SME2-NEXT: punpklo p1.h, p0.b
; CHECK-SVE2p1-SME2-NEXT: punpkhi p2.h, p0.b
; CHECK-SVE2p1-SME2-NEXT: punpklo p0.h, p1.b
; CHECK-SVE2p1-SME2-NEXT: punpklo p1.h, p2.b
; CHECK-SVE2p1-SME2-NEXT: b use
%r = call <vscale x 8 x i1> @llvm.get.active.lane.mask.nxv8i1.i64(i64 %i, i64 %n)
%v0 = call <vscale x 2 x i1> @llvm.vector.extract.nxv2i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 0)
%v1 = call <vscale x 2 x i1> @llvm.vector.extract.nxv2i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 4)
tail call void @use(<vscale x 2 x i1> %v0, <vscale x 2 x i1> %v1)
ret void
}

;; Negative test for when extracting a fixed-length vector.
; Negative test for when extracting a fixed-length vector.
define void @test_fixed_extract(i64 %i, i64 %n) #0 {
; CHECK-SVE-LABEL: test_fixed_extract:
; CHECK-SVE: // %bb.0:
@@ -144,13 +145,89 @@ define void @test_fixed_extract(i64 %i, i64 %n) #0 {
; CHECK-SVE2p1-NEXT: mov v1.s[1], w11
; CHECK-SVE2p1-NEXT: // kill: def $d1 killed $d1 killed $q1
; CHECK-SVE2p1-NEXT: b use
;
; CHECK-SME2-LABEL: test_fixed_extract:
; CHECK-SME2: // %bb.0:
; CHECK-SME2-NEXT: whilelo p0.h, x0, x1
; CHECK-SME2-NEXT: cset w8, mi
; CHECK-SME2-NEXT: mov z0.h, p0/z, #1 // =0x1
; CHECK-SME2-NEXT: mov z1.h, z0.h[1]
; CHECK-SME2-NEXT: mov z2.h, z0.h[5]
; CHECK-SME2-NEXT: mov z3.h, z0.h[4]
; CHECK-SME2-NEXT: fmov s0, w8
; CHECK-SME2-NEXT: zip1 z0.s, z0.s, z1.s
; CHECK-SME2-NEXT: zip1 z1.s, z3.s, z2.s
; CHECK-SME2-NEXT: // kill: def $d0 killed $d0 killed $z0
; CHECK-SME2-NEXT: // kill: def $d1 killed $d1 killed $z1
; CHECK-SME2-NEXT: b use
%r = call <vscale x 8 x i1> @llvm.get.active.lane.mask.nxv8i1.i64(i64 %i, i64 %n)
%v0 = call <2 x i1> @llvm.vector.extract.v2i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 0)
%v1 = call <2 x i1> @llvm.vector.extract.v2i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 4)
tail call void @use(<2 x i1> %v0, <2 x i1> %v1)
ret void
}

; Illegal Types

define void @test_2x16bit_mask_with_32bit_index_and_trip_count(i32 %i, i32 %n) #0 {
; CHECK-SVE-LABEL: test_2x16bit_mask_with_32bit_index_and_trip_count:
; CHECK-SVE: // %bb.0:
; CHECK-SVE-NEXT: rdvl x8, #1
; CHECK-SVE-NEXT: adds w8, w0, w8
; CHECK-SVE-NEXT: csinv w8, w8, wzr, lo
; CHECK-SVE-NEXT: whilelo p0.b, w0, w1
; CHECK-SVE-NEXT: whilelo p1.b, w8, w1
; CHECK-SVE-NEXT: b use
;
; CHECK-SVE2p1-SME2-LABEL: test_2x16bit_mask_with_32bit_index_and_trip_count:
; CHECK-SVE2p1-SME2: // %bb.0:
; CHECK-SVE2p1-SME2-NEXT: mov w8, w1
; CHECK-SVE2p1-SME2-NEXT: mov w9, w0
; CHECK-SVE2p1-SME2-NEXT: whilelo { p0.b, p1.b }, x9, x8
; CHECK-SVE2p1-SME2-NEXT: b use
%r = call <vscale x 32 x i1> @llvm.get.active.lane.mask.nxv32i1.i32(i32 %i, i32 %n)
%v0 = call <vscale x 16 x i1> @llvm.vector.extract.nxv16i1.nxv32i1.i64(<vscale x 32 x i1> %r, i64 0)
%v1 = call <vscale x 16 x i1> @llvm.vector.extract.nxv16i1.nxv32i1.i64(<vscale x 32 x i1> %r, i64 16)
tail call void @use(<vscale x 16 x i1> %v0, <vscale x 16 x i1> %v1)
ret void
}

define void @test_2x32bit_mask_with_32bit_index_and_trip_count(i32 %i, i32 %n) #0 {
; CHECK-SVE-LABEL: test_2x32bit_mask_with_32bit_index_and_trip_count:
; CHECK-SVE: // %bb.0:
; CHECK-SVE-NEXT: rdvl x8, #2
; CHECK-SVE-NEXT: rdvl x9, #1
; CHECK-SVE-NEXT: adds w8, w0, w8
; CHECK-SVE-NEXT: csinv w8, w8, wzr, lo
; CHECK-SVE-NEXT: adds w10, w8, w9
; CHECK-SVE-NEXT: csinv w10, w10, wzr, lo
; CHECK-SVE-NEXT: whilelo p3.b, w10, w1
; CHECK-SVE-NEXT: adds w9, w0, w9
; CHECK-SVE-NEXT: csinv w9, w9, wzr, lo
; CHECK-SVE-NEXT: whilelo p0.b, w0, w1
; CHECK-SVE-NEXT: whilelo p1.b, w9, w1
; CHECK-SVE-NEXT: whilelo p2.b, w8, w1
; CHECK-SVE-NEXT: b use
;
; CHECK-SVE2p1-SME2-LABEL: test_2x32bit_mask_with_32bit_index_and_trip_count:
; CHECK-SVE2p1-SME2: // %bb.0:
; CHECK-SVE2p1-SME2-NEXT: rdvl x8, #2
; CHECK-SVE2p1-SME2-NEXT: mov w9, w1
; CHECK-SVE2p1-SME2-NEXT: mov w10, w0
; CHECK-SVE2p1-SME2-NEXT: adds w8, w0, w8
; CHECK-SVE2p1-SME2-NEXT: csinv w8, w8, wzr, lo
; CHECK-SVE2p1-SME2-NEXT: whilelo { p0.b, p1.b }, x10, x9
; CHECK-SVE2p1-SME2-NEXT: whilelo { p2.b, p3.b }, x8, x9
; CHECK-SVE2p1-SME2-NEXT: b use
%r = call <vscale x 64 x i1> @llvm.get.active.lane.mask.nxv64i1.i32(i32 %i, i32 %n)
%v0 = call <vscale x 16 x i1> @llvm.vector.extract.nxv16i1.nxv64i1.i64(<vscale x 64 x i1> %r, i64 0)
%v1 = call <vscale x 16 x i1> @llvm.vector.extract.nxv16i1.nxv64i1.i64(<vscale x 64 x i1> %r, i64 16)
%v2 = call <vscale x 16 x i1> @llvm.vector.extract.nxv16i1.nxv64i1.i64(<vscale x 64 x i1> %r, i64 32)
%v3 = call <vscale x 16 x i1> @llvm.vector.extract.nxv16i1.nxv64i1.i64(<vscale x 64 x i1> %r, i64 48)
tail call void @use(<vscale x 16 x i1> %v0, <vscale x 16 x i1> %v1, <vscale x 16 x i1> %v2, <vscale x 16 x i1> %v3)
ret void
}

declare void @use(...)

attributes #0 = { nounwind }