
[LLVM][SVE] Improve legalisation of fixed length get.active.lane.mask #90213

Merged
80 changes: 43 additions & 37 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1674,6 +1674,9 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
     setOperationPromotedToType(ISD::VECTOR_SPLICE, MVT::nxv16i1, MVT::nxv16i8);
 
     setOperationAction(ISD::VSCALE, MVT::i32, Custom);
+
+    for (auto VT : {MVT::v16i1, MVT::v8i1, MVT::v4i1, MVT::v2i1})
+      setOperationAction(ISD::INTRINSIC_WO_CHAIN, VT, Custom);
   }
 
   if (Subtarget->hasMOPS() && Subtarget->hasMTE()) {
@@ -5686,8 +5689,24 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
   case Intrinsic::get_active_lane_mask: {
     SDValue ID =
         DAG.getTargetConstant(Intrinsic::aarch64_sve_whilelo, dl, MVT::i64);
-    return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, Op.getValueType(), ID,
-                       Op.getOperand(1), Op.getOperand(2));
+
+    EVT VT = Op.getValueType();
+    if (VT.isScalableVector())
+      return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT, ID, Op.getOperand(1),
+                         Op.getOperand(2));
+
+    // We can use the SVE whilelo instruction to lower this intrinsic by
+    // creating the appropriate sequence of scalable vector operations and
+    // then extracting a fixed-width subvector from the scalable vector.
+
+    EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT);
+    EVT WhileVT = ContainerVT.changeElementType(MVT::i1);
+
+    SDValue Mask = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, WhileVT, ID,
+                               Op.getOperand(1), Op.getOperand(2));
+    SDValue MaskAsInt = DAG.getNode(ISD::SIGN_EXTEND, dl, ContainerVT, Mask);
+    return DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, VT, MaskAsInt,
+                       DAG.getVectorIdxConstant(0, dl));
   }
   case Intrinsic::aarch64_neon_uaddlv: {
     EVT OpVT = Op.getOperand(1).getValueType();
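
For a concrete picture of what this hunk builds: a fixed-length request such as <8 x i1> reaches this code after ReplaceNodeResults (see the change further down) has promoted it to <8 x i8>. Assuming v8i8 maps to an nxv16i8 container, which matches the whilelo p0.b output in the updated tests below, the node sequence is whilelo (nxv16i1), sign_extend (nxv16i8), extract_subvector (v8i8). A minimal IR reproducer, mirroring the existing tests in llvm/test/CodeGen/AArch64/active_lane_mask.ll (function name is illustrative):

; Assumed lowering chain: whilelo nxv16i1 -> sign_extend nxv16i8
; -> extract_subvector v8i8 -> truncate back to v8i1.
define <8 x i1> @sketch_lane_mask_v8i1(i32 %index, i32 %tc) {
  %mask = call <8 x i1> @llvm.get.active.lane.mask.v8i1.i32(i32 %index, i32 %tc)
  ret <8 x i1> %mask
}
declare <8 x i1> @llvm.get.active.lane.mask.v8i1.i32(i32, i32)
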
@@ -20462,39 +20481,6 @@ static SDValue performIntrinsicCombine(SDNode *N,
   switch (IID) {
   default:
     break;
-  case Intrinsic::get_active_lane_mask: {
-    SDValue Res = SDValue();
-    EVT VT = N->getValueType(0);
-    if (VT.isFixedLengthVector()) {
-      // We can use the SVE whilelo instruction to lower this intrinsic by
-      // creating the appropriate sequence of scalable vector operations and
-      // then extracting a fixed-width subvector from the scalable vector.
-
-      SDLoc DL(N);
-      SDValue ID =
-          DAG.getTargetConstant(Intrinsic::aarch64_sve_whilelo, DL, MVT::i64);
-
-      EVT WhileVT = EVT::getVectorVT(
-          *DAG.getContext(), MVT::i1,
-          ElementCount::getScalable(VT.getVectorNumElements()));
-
-      // Get promoted scalable vector VT, i.e. promote nxv4i1 -> nxv4i32.
-      EVT PromVT = getPromotedVTForPredicate(WhileVT);
-
-      // Get the fixed-width equivalent of PromVT for extraction.
-      EVT ExtVT =
-          EVT::getVectorVT(*DAG.getContext(), PromVT.getVectorElementType(),
-                           VT.getVectorElementCount());
-
-      Res = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, WhileVT, ID,
-                        N->getOperand(1), N->getOperand(2));
-      Res = DAG.getNode(ISD::SIGN_EXTEND, DL, PromVT, Res);
-      Res = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ExtVT, Res,
-                        DAG.getConstant(0, DL, MVT::i64));
-      Res = DAG.getNode(ISD::TRUNCATE, DL, VT, Res);
-    }
-    return Res;
-  }
   case Intrinsic::aarch64_neon_vcvtfxs2fp:
   case Intrinsic::aarch64_neon_vcvtfxu2fp:
     return tryCombineFixedPointConvert(N, DCI, DAG);
@@ -25568,15 +25554,15 @@ void AArch64TargetLowering::ReplaceNodeResults(
     return;
   case ISD::INTRINSIC_WO_CHAIN: {
     EVT VT = N->getValueType(0);
-    assert((VT == MVT::i8 || VT == MVT::i16) &&
-           "custom lowering for unexpected type");
 
     Intrinsic::ID IntID =
         static_cast<Intrinsic::ID>(N->getConstantOperandVal(0));
     switch (IntID) {
     default:
       return;
     case Intrinsic::aarch64_sve_clasta_n: {
+      assert((VT == MVT::i8 || VT == MVT::i16) &&
+             "custom lowering for unexpected type");
       SDLoc DL(N);
       auto Op2 = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i32, N->getOperand(2));
       auto V = DAG.getNode(AArch64ISD::CLASTA_N, DL, MVT::i32,
@@ -25585,6 +25571,8 @@ void AArch64TargetLowering::ReplaceNodeResults(
       return;
     }
     case Intrinsic::aarch64_sve_clastb_n: {
+      assert((VT == MVT::i8 || VT == MVT::i16) &&
+             "custom lowering for unexpected type");
       SDLoc DL(N);
       auto Op2 = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i32, N->getOperand(2));
       auto V = DAG.getNode(AArch64ISD::CLASTB_N, DL, MVT::i32,
@@ -25593,19 +25581,37 @@ void AArch64TargetLowering::ReplaceNodeResults(
       return;
     }
     case Intrinsic::aarch64_sve_lasta: {
+      assert((VT == MVT::i8 || VT == MVT::i16) &&
+             "custom lowering for unexpected type");
       SDLoc DL(N);
       auto V = DAG.getNode(AArch64ISD::LASTA, DL, MVT::i32,
                            N->getOperand(1), N->getOperand(2));
       Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, VT, V));
       return;
     }
     case Intrinsic::aarch64_sve_lastb: {
+      assert((VT == MVT::i8 || VT == MVT::i16) &&
+             "custom lowering for unexpected type");
       SDLoc DL(N);
       auto V = DAG.getNode(AArch64ISD::LASTB, DL, MVT::i32,
                            N->getOperand(1), N->getOperand(2));
       Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, VT, V));
       return;
     }
+    case Intrinsic::get_active_lane_mask: {
+      if (!VT.isFixedLengthVector() || VT.getVectorElementType() != MVT::i1)
+        return;
+

Review thread on this hunk:

Contributor: Maybe it looks neater to combine these two if statements into one?

Collaborator (Author): Makes sense given they are related. I'll post an update, likely tomorrow.

Collaborator (Author): Done

+      // NOTE: Only trivial type promotion is supported.
+      EVT NewVT = getTypeToTransformTo(*DAG.getContext(), VT);
+      if (NewVT.getVectorNumElements() != VT.getVectorNumElements())
+        return;
+
+      SDLoc DL(N);
+      auto V = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, NewVT, N->ops());
+      Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, VT, V));
+      return;
+    }
     }
   }
   case ISD::READ_REGISTER: {
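
To make the "only trivial type promotion" note concrete: the promotions AArch64 applies to these predicate types appear to keep the lane count (<8 x i1> -> <8 x i8>, <4 x i1> -> <4 x i16>, <2 x i1> -> <2 x i32> — an assumption inferred from the whilelo element sizes in the updated tests below), so the guard passes, the intrinsic is rebuilt at the wider type, and the result is truncated back. A hedged sketch of the <4 x i1> case:

; Assumed flow: <4 x i1> is promoted to <4 x i16> by ReplaceNodeResults,
; lowered via whilelo on an nxv8i16 container, then truncated back.
define <4 x i1> @sketch_lane_mask_v4i1(i64 %index, i64 %tc) {
  %mask = call <4 x i1> @llvm.get.active.lane.mask.v4i1.i64(i64 %index, i64 %tc)
  ret <4 x i1> %mask
}
declare <4 x i1> @llvm.get.active.lane.mask.v4i1.i64(i64, i64)
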
36 changes: 18 additions & 18 deletions llvm/test/CodeGen/AArch64/active_lane_mask.ll
@@ -353,9 +353,9 @@ define <16 x i1> @lane_mask_v16i1_i32(i32 %index, i32 %TC) {
 define <8 x i1> @lane_mask_v8i1_i32(i32 %index, i32 %TC) {
 ; CHECK-LABEL: lane_mask_v8i1_i32:
 ; CHECK: // %bb.0:
-; CHECK-NEXT: whilelo p0.h, w0, w1
-; CHECK-NEXT: mov z0.h, p0/z, #-1 // =0xffffffffffffffff
-; CHECK-NEXT: xtn v0.8b, v0.8h
+; CHECK-NEXT: whilelo p0.b, w0, w1
+; CHECK-NEXT: mov z0.b, p0/z, #-1 // =0xffffffffffffffff
+; CHECK-NEXT: // kill: def $d0 killed $d0 killed $z0
 ; CHECK-NEXT: ret
   %active.lane.mask = call <8 x i1> @llvm.get.active.lane.mask.v8i1.i32(i32 %index, i32 %TC)
   ret <8 x i1> %active.lane.mask
@@ -364,9 +364,9 @@ define <8 x i1> @lane_mask_v8i1_i32(i32 %index, i32 %TC) {
 define <4 x i1> @lane_mask_v4i1_i32(i32 %index, i32 %TC) {
 ; CHECK-LABEL: lane_mask_v4i1_i32:
 ; CHECK: // %bb.0:
-; CHECK-NEXT: whilelo p0.s, w0, w1
-; CHECK-NEXT: mov z0.s, p0/z, #-1 // =0xffffffffffffffff
-; CHECK-NEXT: xtn v0.4h, v0.4s
+; CHECK-NEXT: whilelo p0.h, w0, w1
+; CHECK-NEXT: mov z0.h, p0/z, #-1 // =0xffffffffffffffff
+; CHECK-NEXT: // kill: def $d0 killed $d0 killed $z0
 ; CHECK-NEXT: ret
   %active.lane.mask = call <4 x i1> @llvm.get.active.lane.mask.v4i1.i32(i32 %index, i32 %TC)
   ret <4 x i1> %active.lane.mask
@@ -375,9 +375,9 @@ define <4 x i1> @lane_mask_v4i1_i32(i32 %index, i32 %TC) {
 define <2 x i1> @lane_mask_v2i1_i32(i32 %index, i32 %TC) {
 ; CHECK-LABEL: lane_mask_v2i1_i32:
 ; CHECK: // %bb.0:
-; CHECK-NEXT: whilelo p0.d, w0, w1
-; CHECK-NEXT: mov z0.d, p0/z, #-1 // =0xffffffffffffffff
-; CHECK-NEXT: xtn v0.2s, v0.2d
+; CHECK-NEXT: whilelo p0.s, w0, w1
+; CHECK-NEXT: mov z0.s, p0/z, #-1 // =0xffffffffffffffff
+; CHECK-NEXT: // kill: def $d0 killed $d0 killed $z0
 ; CHECK-NEXT: ret
   %active.lane.mask = call <2 x i1> @llvm.get.active.lane.mask.v2i1.i32(i32 %index, i32 %TC)
   ret <2 x i1> %active.lane.mask
@@ -397,9 +397,9 @@ define <16 x i1> @lane_mask_v16i1_i64(i64 %index, i64 %TC) {
 define <8 x i1> @lane_mask_v8i1_i64(i64 %index, i64 %TC) {
 ; CHECK-LABEL: lane_mask_v8i1_i64:
 ; CHECK: // %bb.0:
-; CHECK-NEXT: whilelo p0.h, x0, x1
-; CHECK-NEXT: mov z0.h, p0/z, #-1 // =0xffffffffffffffff
-; CHECK-NEXT: xtn v0.8b, v0.8h
+; CHECK-NEXT: whilelo p0.b, x0, x1
+; CHECK-NEXT: mov z0.b, p0/z, #-1 // =0xffffffffffffffff
+; CHECK-NEXT: // kill: def $d0 killed $d0 killed $z0
 ; CHECK-NEXT: ret
   %active.lane.mask = call <8 x i1> @llvm.get.active.lane.mask.v8i1.i64(i64 %index, i64 %TC)
   ret <8 x i1> %active.lane.mask
@@ -408,9 +408,9 @@ define <8 x i1> @lane_mask_v8i1_i64(i64 %index, i64 %TC) {
 define <4 x i1> @lane_mask_v4i1_i64(i64 %index, i64 %TC) {
 ; CHECK-LABEL: lane_mask_v4i1_i64:
 ; CHECK: // %bb.0:
-; CHECK-NEXT: whilelo p0.s, x0, x1
-; CHECK-NEXT: mov z0.s, p0/z, #-1 // =0xffffffffffffffff
-; CHECK-NEXT: xtn v0.4h, v0.4s
+; CHECK-NEXT: whilelo p0.h, x0, x1
+; CHECK-NEXT: mov z0.h, p0/z, #-1 // =0xffffffffffffffff
+; CHECK-NEXT: // kill: def $d0 killed $d0 killed $z0
 ; CHECK-NEXT: ret
   %active.lane.mask = call <4 x i1> @llvm.get.active.lane.mask.v4i1.i64(i64 %index, i64 %TC)
   ret <4 x i1> %active.lane.mask
@@ -419,9 +419,9 @@ define <4 x i1> @lane_mask_v4i1_i64(i64 %index, i64 %TC) {
 define <2 x i1> @lane_mask_v2i1_i64(i64 %index, i64 %TC) {
 ; CHECK-LABEL: lane_mask_v2i1_i64:
 ; CHECK: // %bb.0:
-; CHECK-NEXT: whilelo p0.d, x0, x1
-; CHECK-NEXT: mov z0.d, p0/z, #-1 // =0xffffffffffffffff
-; CHECK-NEXT: xtn v0.2s, v0.2d
+; CHECK-NEXT: whilelo p0.s, x0, x1
+; CHECK-NEXT: mov z0.s, p0/z, #-1 // =0xffffffffffffffff
+; CHECK-NEXT: // kill: def $d0 killed $d0 killed $z0
 ; CHECK-NEXT: ret
   %active.lane.mask = call <2 x i1> @llvm.get.active.lane.mask.v2i1.i64(i64 %index, i64 %TC)
   ret <2 x i1> %active.lane.mask
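
Across all six updated functions the pattern is the same: the old lowering ran whilelo at the i1 type's natural element width and needed an extra xtn to narrow the expanded mask, while the new lowering runs whilelo directly at the promoted element width and simply reads the low 64-bit d-register, one instruction shorter per mask. A hedged standalone reproducer (the test file's actual RUN line may differ):

; RUN: llc -mtriple=aarch64 -mattr=+sve < %s | FileCheck %s
define <2 x i1> @sketch_lane_mask_v2i1(i64 %index, i64 %tc) {
; CHECK-LABEL: sketch_lane_mask_v2i1:
; CHECK: whilelo p0.s, x0, x1
  %mask = call <2 x i1> @llvm.get.active.lane.mask.v2i1.i64(i64 %index, i64 %tc)
  ret <2 x i1> %mask
}
declare <2 x i1> @llvm.get.active.lane.mask.v2i1.i64(i64, i64)
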