Skip to content

Commit fdf206c

Browse files
[LLVM][SVE] Improve legalisation of fixed length get.active.lane.mask (#90213)
We are effectively performing type and operation legalisation very early within the code generation flow. This results in worse code quality because the DAG is not in canonical form, which DAGCombiner corrects through the introduction of operations that are not legal. This patchs splits and moves the code to where type and operation legalisation is typically implemented.
1 parent 9bebf25 commit fdf206c

File tree

2 files changed

+61
-55
lines changed

2 files changed

+61
-55
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 43 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1677,6 +1677,9 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
16771677
setOperationPromotedToType(ISD::VECTOR_SPLICE, MVT::nxv16i1, MVT::nxv16i8);
16781678

16791679
setOperationAction(ISD::VSCALE, MVT::i32, Custom);
1680+
1681+
for (auto VT : {MVT::v16i1, MVT::v8i1, MVT::v4i1, MVT::v2i1})
1682+
setOperationAction(ISD::INTRINSIC_WO_CHAIN, VT, Custom);
16801683
}
16811684

16821685
if (Subtarget->hasMOPS() && Subtarget->hasMTE()) {
@@ -5748,8 +5751,24 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
57485751
case Intrinsic::get_active_lane_mask: {
57495752
SDValue ID =
57505753
DAG.getTargetConstant(Intrinsic::aarch64_sve_whilelo, dl, MVT::i64);
5751-
return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, Op.getValueType(), ID,
5752-
Op.getOperand(1), Op.getOperand(2));
5754+
5755+
EVT VT = Op.getValueType();
5756+
if (VT.isScalableVector())
5757+
return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT, ID, Op.getOperand(1),
5758+
Op.getOperand(2));
5759+
5760+
// We can use the SVE whilelo instruction to lower this intrinsic by
5761+
// creating the appropriate sequence of scalable vector operations and
5762+
// then extracting a fixed-width subvector from the scalable vector.
5763+
5764+
EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT);
5765+
EVT WhileVT = ContainerVT.changeElementType(MVT::i1);
5766+
5767+
SDValue Mask = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, WhileVT, ID,
5768+
Op.getOperand(1), Op.getOperand(2));
5769+
SDValue MaskAsInt = DAG.getNode(ISD::SIGN_EXTEND, dl, ContainerVT, Mask);
5770+
return DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, VT, MaskAsInt,
5771+
DAG.getVectorIdxConstant(0, dl));
57535772
}
57545773
case Intrinsic::aarch64_neon_uaddlv: {
57555774
EVT OpVT = Op.getOperand(1).getValueType();
@@ -20530,39 +20549,6 @@ static SDValue performIntrinsicCombine(SDNode *N,
2053020549
switch (IID) {
2053120550
default:
2053220551
break;
20533-
case Intrinsic::get_active_lane_mask: {
20534-
SDValue Res = SDValue();
20535-
EVT VT = N->getValueType(0);
20536-
if (VT.isFixedLengthVector()) {
20537-
// We can use the SVE whilelo instruction to lower this intrinsic by
20538-
// creating the appropriate sequence of scalable vector operations and
20539-
// then extracting a fixed-width subvector from the scalable vector.
20540-
20541-
SDLoc DL(N);
20542-
SDValue ID =
20543-
DAG.getTargetConstant(Intrinsic::aarch64_sve_whilelo, DL, MVT::i64);
20544-
20545-
EVT WhileVT = EVT::getVectorVT(
20546-
*DAG.getContext(), MVT::i1,
20547-
ElementCount::getScalable(VT.getVectorNumElements()));
20548-
20549-
// Get promoted scalable vector VT, i.e. promote nxv4i1 -> nxv4i32.
20550-
EVT PromVT = getPromotedVTForPredicate(WhileVT);
20551-
20552-
// Get the fixed-width equivalent of PromVT for extraction.
20553-
EVT ExtVT =
20554-
EVT::getVectorVT(*DAG.getContext(), PromVT.getVectorElementType(),
20555-
VT.getVectorElementCount());
20556-
20557-
Res = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, WhileVT, ID,
20558-
N->getOperand(1), N->getOperand(2));
20559-
Res = DAG.getNode(ISD::SIGN_EXTEND, DL, PromVT, Res);
20560-
Res = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ExtVT, Res,
20561-
DAG.getConstant(0, DL, MVT::i64));
20562-
Res = DAG.getNode(ISD::TRUNCATE, DL, VT, Res);
20563-
}
20564-
return Res;
20565-
}
2056620552
case Intrinsic::aarch64_neon_vcvtfxs2fp:
2056720553
case Intrinsic::aarch64_neon_vcvtfxu2fp:
2056820554
return tryCombineFixedPointConvert(N, DCI, DAG);
@@ -25636,15 +25622,15 @@ void AArch64TargetLowering::ReplaceNodeResults(
2563625622
return;
2563725623
case ISD::INTRINSIC_WO_CHAIN: {
2563825624
EVT VT = N->getValueType(0);
25639-
assert((VT == MVT::i8 || VT == MVT::i16) &&
25640-
"custom lowering for unexpected type");
2564125625

2564225626
Intrinsic::ID IntID =
2564325627
static_cast<Intrinsic::ID>(N->getConstantOperandVal(0));
2564425628
switch (IntID) {
2564525629
default:
2564625630
return;
2564725631
case Intrinsic::aarch64_sve_clasta_n: {
25632+
assert((VT == MVT::i8 || VT == MVT::i16) &&
25633+
"custom lowering for unexpected type");
2564825634
SDLoc DL(N);
2564925635
auto Op2 = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i32, N->getOperand(2));
2565025636
auto V = DAG.getNode(AArch64ISD::CLASTA_N, DL, MVT::i32,
@@ -25653,6 +25639,8 @@ void AArch64TargetLowering::ReplaceNodeResults(
2565325639
return;
2565425640
}
2565525641
case Intrinsic::aarch64_sve_clastb_n: {
25642+
assert((VT == MVT::i8 || VT == MVT::i16) &&
25643+
"custom lowering for unexpected type");
2565625644
SDLoc DL(N);
2565725645
auto Op2 = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i32, N->getOperand(2));
2565825646
auto V = DAG.getNode(AArch64ISD::CLASTB_N, DL, MVT::i32,
@@ -25661,19 +25649,37 @@ void AArch64TargetLowering::ReplaceNodeResults(
2566125649
return;
2566225650
}
2566325651
case Intrinsic::aarch64_sve_lasta: {
25652+
assert((VT == MVT::i8 || VT == MVT::i16) &&
25653+
"custom lowering for unexpected type");
2566425654
SDLoc DL(N);
2566525655
auto V = DAG.getNode(AArch64ISD::LASTA, DL, MVT::i32,
2566625656
N->getOperand(1), N->getOperand(2));
2566725657
Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, VT, V));
2566825658
return;
2566925659
}
2567025660
case Intrinsic::aarch64_sve_lastb: {
25661+
assert((VT == MVT::i8 || VT == MVT::i16) &&
25662+
"custom lowering for unexpected type");
2567125663
SDLoc DL(N);
2567225664
auto V = DAG.getNode(AArch64ISD::LASTB, DL, MVT::i32,
2567325665
N->getOperand(1), N->getOperand(2));
2567425666
Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, VT, V));
2567525667
return;
2567625668
}
25669+
case Intrinsic::get_active_lane_mask: {
25670+
if (!VT.isFixedLengthVector() || VT.getVectorElementType() != MVT::i1)
25671+
return;
25672+
25673+
// NOTE: Only trivial type promotion is supported.
25674+
EVT NewVT = getTypeToTransformTo(*DAG.getContext(), VT);
25675+
if (NewVT.getVectorNumElements() != VT.getVectorNumElements())
25676+
return;
25677+
25678+
SDLoc DL(N);
25679+
auto V = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, NewVT, N->ops());
25680+
Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, VT, V));
25681+
return;
25682+
}
2567725683
}
2567825684
}
2567925685
case ISD::READ_REGISTER: {

llvm/test/CodeGen/AArch64/active_lane_mask.ll

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -353,9 +353,9 @@ define <16 x i1> @lane_mask_v16i1_i32(i32 %index, i32 %TC) {
353353
define <8 x i1> @lane_mask_v8i1_i32(i32 %index, i32 %TC) {
354354
; CHECK-LABEL: lane_mask_v8i1_i32:
355355
; CHECK: // %bb.0:
356-
; CHECK-NEXT: whilelo p0.h, w0, w1
357-
; CHECK-NEXT: mov z0.h, p0/z, #-1 // =0xffffffffffffffff
358-
; CHECK-NEXT: xtn v0.8b, v0.8h
356+
; CHECK-NEXT: whilelo p0.b, w0, w1
357+
; CHECK-NEXT: mov z0.b, p0/z, #-1 // =0xffffffffffffffff
358+
; CHECK-NEXT: // kill: def $d0 killed $d0 killed $z0
359359
; CHECK-NEXT: ret
360360
%active.lane.mask = call <8 x i1> @llvm.get.active.lane.mask.v8i1.i32(i32 %index, i32 %TC)
361361
ret <8 x i1> %active.lane.mask
@@ -364,9 +364,9 @@ define <8 x i1> @lane_mask_v8i1_i32(i32 %index, i32 %TC) {
364364
define <4 x i1> @lane_mask_v4i1_i32(i32 %index, i32 %TC) {
365365
; CHECK-LABEL: lane_mask_v4i1_i32:
366366
; CHECK: // %bb.0:
367-
; CHECK-NEXT: whilelo p0.s, w0, w1
368-
; CHECK-NEXT: mov z0.s, p0/z, #-1 // =0xffffffffffffffff
369-
; CHECK-NEXT: xtn v0.4h, v0.4s
367+
; CHECK-NEXT: whilelo p0.h, w0, w1
368+
; CHECK-NEXT: mov z0.h, p0/z, #-1 // =0xffffffffffffffff
369+
; CHECK-NEXT: // kill: def $d0 killed $d0 killed $z0
370370
; CHECK-NEXT: ret
371371
%active.lane.mask = call <4 x i1> @llvm.get.active.lane.mask.v4i1.i32(i32 %index, i32 %TC)
372372
ret <4 x i1> %active.lane.mask
@@ -375,9 +375,9 @@ define <4 x i1> @lane_mask_v4i1_i32(i32 %index, i32 %TC) {
375375
define <2 x i1> @lane_mask_v2i1_i32(i32 %index, i32 %TC) {
376376
; CHECK-LABEL: lane_mask_v2i1_i32:
377377
; CHECK: // %bb.0:
378-
; CHECK-NEXT: whilelo p0.d, w0, w1
379-
; CHECK-NEXT: mov z0.d, p0/z, #-1 // =0xffffffffffffffff
380-
; CHECK-NEXT: xtn v0.2s, v0.2d
378+
; CHECK-NEXT: whilelo p0.s, w0, w1
379+
; CHECK-NEXT: mov z0.s, p0/z, #-1 // =0xffffffffffffffff
380+
; CHECK-NEXT: // kill: def $d0 killed $d0 killed $z0
381381
; CHECK-NEXT: ret
382382
%active.lane.mask = call <2 x i1> @llvm.get.active.lane.mask.v2i1.i32(i32 %index, i32 %TC)
383383
ret <2 x i1> %active.lane.mask
@@ -397,9 +397,9 @@ define <16 x i1> @lane_mask_v16i1_i64(i64 %index, i64 %TC) {
397397
define <8 x i1> @lane_mask_v8i1_i64(i64 %index, i64 %TC) {
398398
; CHECK-LABEL: lane_mask_v8i1_i64:
399399
; CHECK: // %bb.0:
400-
; CHECK-NEXT: whilelo p0.h, x0, x1
401-
; CHECK-NEXT: mov z0.h, p0/z, #-1 // =0xffffffffffffffff
402-
; CHECK-NEXT: xtn v0.8b, v0.8h
400+
; CHECK-NEXT: whilelo p0.b, x0, x1
401+
; CHECK-NEXT: mov z0.b, p0/z, #-1 // =0xffffffffffffffff
402+
; CHECK-NEXT: // kill: def $d0 killed $d0 killed $z0
403403
; CHECK-NEXT: ret
404404
%active.lane.mask = call <8 x i1> @llvm.get.active.lane.mask.v8i1.i64(i64 %index, i64 %TC)
405405
ret <8 x i1> %active.lane.mask
@@ -408,9 +408,9 @@ define <8 x i1> @lane_mask_v8i1_i64(i64 %index, i64 %TC) {
408408
define <4 x i1> @lane_mask_v4i1_i64(i64 %index, i64 %TC) {
409409
; CHECK-LABEL: lane_mask_v4i1_i64:
410410
; CHECK: // %bb.0:
411-
; CHECK-NEXT: whilelo p0.s, x0, x1
412-
; CHECK-NEXT: mov z0.s, p0/z, #-1 // =0xffffffffffffffff
413-
; CHECK-NEXT: xtn v0.4h, v0.4s
411+
; CHECK-NEXT: whilelo p0.h, x0, x1
412+
; CHECK-NEXT: mov z0.h, p0/z, #-1 // =0xffffffffffffffff
413+
; CHECK-NEXT: // kill: def $d0 killed $d0 killed $z0
414414
; CHECK-NEXT: ret
415415
%active.lane.mask = call <4 x i1> @llvm.get.active.lane.mask.v4i1.i64(i64 %index, i64 %TC)
416416
ret <4 x i1> %active.lane.mask
@@ -419,9 +419,9 @@ define <4 x i1> @lane_mask_v4i1_i64(i64 %index, i64 %TC) {
419419
define <2 x i1> @lane_mask_v2i1_i64(i64 %index, i64 %TC) {
420420
; CHECK-LABEL: lane_mask_v2i1_i64:
421421
; CHECK: // %bb.0:
422-
; CHECK-NEXT: whilelo p0.d, x0, x1
423-
; CHECK-NEXT: mov z0.d, p0/z, #-1 // =0xffffffffffffffff
424-
; CHECK-NEXT: xtn v0.2s, v0.2d
422+
; CHECK-NEXT: whilelo p0.s, x0, x1
423+
; CHECK-NEXT: mov z0.s, p0/z, #-1 // =0xffffffffffffffff
424+
; CHECK-NEXT: // kill: def $d0 killed $d0 killed $z0
425425
; CHECK-NEXT: ret
426426
%active.lane.mask = call <2 x i1> @llvm.get.active.lane.mask.v2i1.i64(i64 %index, i64 %TC)
427427
ret <2 x i1> %active.lane.mask

0 commit comments

Comments
 (0)