Skip to content

Commit b61144b

Browse files
[AArch64] Allow lowering of more types to GET_ACTIVE_LANE_MASK (#140062)
Adds support for operand promotion and splitting/widening the result of the ISD::GET_ACTIVE_LANE_MASK node. For AArch64, shouldExpandGetActiveLaneMask now returns false for more types which we know can be legalised.
1 parent bf1d422 commit b61144b

File tree

6 files changed

+110
-206
lines changed

6 files changed

+110
-206
lines changed

llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2088,6 +2088,9 @@ bool DAGTypeLegalizer::PromoteIntegerOperand(SDNode *N, unsigned OpNo) {
20882088
case ISD::VECTOR_FIND_LAST_ACTIVE:
20892089
Res = PromoteIntOp_VECTOR_FIND_LAST_ACTIVE(N, OpNo);
20902090
break;
2091+
case ISD::GET_ACTIVE_LANE_MASK:
2092+
Res = PromoteIntOp_GET_ACTIVE_LANE_MASK(N);
2093+
break;
20912094
case ISD::PARTIAL_REDUCE_UMLA:
20922095
case ISD::PARTIAL_REDUCE_SMLA:
20932096
Res = PromoteIntOp_PARTIAL_REDUCE_MLA(N);
@@ -2874,6 +2877,13 @@ SDValue DAGTypeLegalizer::PromoteIntOp_VECTOR_FIND_LAST_ACTIVE(SDNode *N,
28742877
return SDValue(DAG.UpdateNodeOperands(N, NewOps), 0);
28752878
}
28762879

2880+
SDValue DAGTypeLegalizer::PromoteIntOp_GET_ACTIVE_LANE_MASK(SDNode *N) {
2881+
SmallVector<SDValue, 1> NewOps(N->ops());
2882+
NewOps[0] = ZExtPromotedInteger(N->getOperand(0));
2883+
NewOps[1] = ZExtPromotedInteger(N->getOperand(1));
2884+
return SDValue(DAG.UpdateNodeOperands(N, NewOps), 0);
2885+
}
2886+
28772887
SDValue DAGTypeLegalizer::PromoteIntOp_PARTIAL_REDUCE_MLA(SDNode *N) {
28782888
SmallVector<SDValue, 1> NewOps(N->ops());
28792889
if (N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA) {

llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -432,6 +432,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
432432
SDValue PromoteIntOp_VP_SPLICE(SDNode *N, unsigned OpNo);
433433
SDValue PromoteIntOp_VECTOR_HISTOGRAM(SDNode *N, unsigned OpNo);
434434
SDValue PromoteIntOp_VECTOR_FIND_LAST_ACTIVE(SDNode *N, unsigned OpNo);
435+
SDValue PromoteIntOp_GET_ACTIVE_LANE_MASK(SDNode *N);
435436
SDValue PromoteIntOp_PARTIAL_REDUCE_MLA(SDNode *N);
436437

437438
void SExtOrZExtPromotedOperands(SDValue &LHS, SDValue &RHS);
@@ -985,6 +986,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
985986
void SplitVecRes_FP_TO_XINT_SAT(SDNode *N, SDValue &Lo, SDValue &Hi);
986987
void SplitVecRes_VP_REVERSE(SDNode *N, SDValue &Lo, SDValue &Hi);
987988
void SplitVecRes_PARTIAL_REDUCE_MLA(SDNode *N, SDValue &Lo, SDValue &Hi);
989+
void SplitVecRes_GET_ACTIVE_LANE_MASK(SDNode *N, SDValue &Lo, SDValue &Hi);
988990

989991
// Vector Operand Splitting: <128 x ty> -> 2 x <64 x ty>.
990992
bool SplitVectorOperand(SDNode *N, unsigned OpNo);
@@ -1081,6 +1083,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
10811083
SDValue WidenVecRes_UNDEF(SDNode *N);
10821084
SDValue WidenVecRes_VECTOR_SHUFFLE(ShuffleVectorSDNode *N);
10831085
SDValue WidenVecRes_VECTOR_REVERSE(SDNode *N);
1086+
SDValue WidenVecRes_GET_ACTIVE_LANE_MASK(SDNode *N);
10841087

10851088
SDValue WidenVecRes_Ternary(SDNode *N);
10861089
SDValue WidenVecRes_Binary(SDNode *N);

llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1389,6 +1389,9 @@ void DAGTypeLegalizer::SplitVectorResult(SDNode *N, unsigned ResNo) {
13891389
case ISD::PARTIAL_REDUCE_SMLA:
13901390
SplitVecRes_PARTIAL_REDUCE_MLA(N, Lo, Hi);
13911391
break;
1392+
case ISD::GET_ACTIVE_LANE_MASK:
1393+
SplitVecRes_GET_ACTIVE_LANE_MASK(N, Lo, Hi);
1394+
break;
13921395
}
13931396

13941397
// If Lo/Hi is null, the sub-method took care of registering results etc.
@@ -3234,6 +3237,22 @@ void DAGTypeLegalizer::SplitVecRes_PARTIAL_REDUCE_MLA(SDNode *N, SDValue &Lo,
32343237
Hi = DAG.getNode(Opcode, DL, ResultVT, AccHi, Input1Hi, Input2Hi);
32353238
}
32363239

3240+
void DAGTypeLegalizer::SplitVecRes_GET_ACTIVE_LANE_MASK(SDNode *N, SDValue &Lo,
3241+
SDValue &Hi) {
3242+
SDLoc DL(N);
3243+
SDValue Op0 = N->getOperand(0);
3244+
SDValue Op1 = N->getOperand(1);
3245+
EVT OpVT = Op0.getValueType();
3246+
3247+
EVT LoVT, HiVT;
3248+
std::tie(LoVT, HiVT) = DAG.GetSplitDestVTs(N->getValueType(0));
3249+
3250+
Lo = DAG.getNode(ISD::GET_ACTIVE_LANE_MASK, DL, LoVT, Op0, Op1);
3251+
SDValue LoElts = DAG.getElementCount(DL, OpVT, LoVT.getVectorElementCount());
3252+
SDValue HiStartVal = DAG.getNode(ISD::UADDSAT, DL, OpVT, Op0, LoElts);
3253+
Hi = DAG.getNode(ISD::GET_ACTIVE_LANE_MASK, DL, HiVT, HiStartVal, Op1);
3254+
}
3255+
32373256
void DAGTypeLegalizer::SplitVecRes_VECTOR_DEINTERLEAVE(SDNode *N) {
32383257
unsigned Factor = N->getNumOperands();
32393258

@@ -4631,6 +4650,9 @@ void DAGTypeLegalizer::WidenVectorResult(SDNode *N, unsigned ResNo) {
46314650
case ISD::VECTOR_REVERSE:
46324651
Res = WidenVecRes_VECTOR_REVERSE(N);
46334652
break;
4653+
case ISD::GET_ACTIVE_LANE_MASK:
4654+
Res = WidenVecRes_GET_ACTIVE_LANE_MASK(N);
4655+
break;
46344656

46354657
case ISD::ADD: case ISD::VP_ADD:
46364658
case ISD::AND: case ISD::VP_AND:
@@ -6579,6 +6601,11 @@ SDValue DAGTypeLegalizer::WidenVecRes_VECTOR_REVERSE(SDNode *N) {
65796601
Mask);
65806602
}
65816603

6604+
SDValue DAGTypeLegalizer::WidenVecRes_GET_ACTIVE_LANE_MASK(SDNode *N) {
6605+
EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), N->getValueType(0));
6606+
return DAG.getNode(ISD::GET_ACTIVE_LANE_MASK, SDLoc(N), NVT, N->ops());
6607+
}
6608+
65826609
SDValue DAGTypeLegalizer::WidenVecRes_SETCC(SDNode *N) {
65836610
assert(N->getValueType(0).isVector() &&
65846611
N->getOperand(0).getValueType().isVector() &&

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2102,18 +2102,18 @@ void AArch64TargetLowering::addTypeForNEON(MVT VT) {
21022102
bool AArch64TargetLowering::shouldExpandGetActiveLaneMask(EVT ResVT,
21032103
EVT OpVT) const {
21042104
// Only SVE has a 1:1 mapping from intrinsic -> instruction (whilelo).
2105-
if (!Subtarget->hasSVE())
2105+
if (!Subtarget->hasSVE() || ResVT.getVectorElementType() != MVT::i1)
21062106
return true;
21072107

2108-
// We can only support legal predicate result types. We can use the SVE
2109-
// whilelo instruction for generating fixed-width predicates too.
2110-
if (ResVT != MVT::nxv2i1 && ResVT != MVT::nxv4i1 && ResVT != MVT::nxv8i1 &&
2111-
ResVT != MVT::nxv16i1 && ResVT != MVT::v2i1 && ResVT != MVT::v4i1 &&
2112-
ResVT != MVT::v8i1 && ResVT != MVT::v16i1)
2108+
// Only support illegal types if the result is scalable and min elements > 1.
2109+
if (ResVT.getVectorMinNumElements() == 1 ||
2110+
(ResVT.isFixedLengthVector() && (ResVT.getVectorNumElements() > 16 ||
2111+
(OpVT != MVT::i32 && OpVT != MVT::i64))))
21132112
return true;
21142113

2115-
// The whilelo instruction only works with i32 or i64 scalar inputs.
2116-
if (OpVT != MVT::i32 && OpVT != MVT::i64)
2114+
// 32 & 64 bit operands are supported. We can promote anything < 64 bits,
2115+
// but anything larger should be expanded.
2116+
if (OpVT.getFixedSizeInBits() > 64)
21172117
return true;
21182118

21192119
return false;

llvm/test/Analysis/CostModel/AArch64/sve-intrinsics.ll

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -920,8 +920,8 @@ define void @get_lane_mask() #0 {
920920
; CHECK-VSCALE-1-NEXT: Cost Model: Found costs of 1 for: %mask_nxv8i1_i32 = call <vscale x 8 x i1> @llvm.get.active.lane.mask.nxv8i1.i32(i32 undef, i32 undef)
921921
; CHECK-VSCALE-1-NEXT: Cost Model: Found costs of 1 for: %mask_nxv4i1_i32 = call <vscale x 4 x i1> @llvm.get.active.lane.mask.nxv4i1.i32(i32 undef, i32 undef)
922922
; CHECK-VSCALE-1-NEXT: Cost Model: Found costs of 1 for: %mask_nxv2i1_i32 = call <vscale x 2 x i1> @llvm.get.active.lane.mask.nxv2i1.i32(i32 undef, i32 undef)
923-
; CHECK-VSCALE-1-NEXT: Cost Model: Found costs of RThru:48 CodeSize:33 Lat:33 SizeLat:33 for: %mask_nxv32i1_i64 = call <vscale x 32 x i1> @llvm.get.active.lane.mask.nxv32i1.i64(i64 undef, i64 undef)
924-
; CHECK-VSCALE-1-NEXT: Cost Model: Found costs of RThru:6 CodeSize:5 Lat:5 SizeLat:5 for: %mask_nxv16i1_i16 = call <vscale x 16 x i1> @llvm.get.active.lane.mask.nxv16i1.i16(i16 undef, i16 undef)
923+
; CHECK-VSCALE-1-NEXT: Cost Model: Found costs of 2 for: %mask_nxv32i1_i64 = call <vscale x 32 x i1> @llvm.get.active.lane.mask.nxv32i1.i64(i64 undef, i64 undef)
924+
; CHECK-VSCALE-1-NEXT: Cost Model: Found costs of 1 for: %mask_nxv16i1_i16 = call <vscale x 16 x i1> @llvm.get.active.lane.mask.nxv16i1.i16(i16 undef, i16 undef)
925925
; CHECK-VSCALE-1-NEXT: Cost Model: Found costs of 32 for: %mask_v16i1_i64 = call <16 x i1> @llvm.get.active.lane.mask.v16i1.i64(i64 undef, i64 undef)
926926
; CHECK-VSCALE-1-NEXT: Cost Model: Found costs of 16 for: %mask_v8i1_i64 = call <8 x i1> @llvm.get.active.lane.mask.v8i1.i64(i64 undef, i64 undef)
927927
; CHECK-VSCALE-1-NEXT: Cost Model: Found costs of 8 for: %mask_v4i1_i64 = call <4 x i1> @llvm.get.active.lane.mask.v4i1.i64(i64 undef, i64 undef)
@@ -943,8 +943,8 @@ define void @get_lane_mask() #0 {
943943
; CHECK-VSCALE-2-NEXT: Cost Model: Found costs of 1 for: %mask_nxv8i1_i32 = call <vscale x 8 x i1> @llvm.get.active.lane.mask.nxv8i1.i32(i32 undef, i32 undef)
944944
; CHECK-VSCALE-2-NEXT: Cost Model: Found costs of 1 for: %mask_nxv4i1_i32 = call <vscale x 4 x i1> @llvm.get.active.lane.mask.nxv4i1.i32(i32 undef, i32 undef)
945945
; CHECK-VSCALE-2-NEXT: Cost Model: Found costs of 1 for: %mask_nxv2i1_i32 = call <vscale x 2 x i1> @llvm.get.active.lane.mask.nxv2i1.i32(i32 undef, i32 undef)
946-
; CHECK-VSCALE-2-NEXT: Cost Model: Found costs of RThru:48 CodeSize:33 Lat:33 SizeLat:33 for: %mask_nxv32i1_i64 = call <vscale x 32 x i1> @llvm.get.active.lane.mask.nxv32i1.i64(i64 undef, i64 undef)
947-
; CHECK-VSCALE-2-NEXT: Cost Model: Found costs of RThru:6 CodeSize:5 Lat:5 SizeLat:5 for: %mask_nxv16i1_i16 = call <vscale x 16 x i1> @llvm.get.active.lane.mask.nxv16i1.i16(i16 undef, i16 undef)
946+
; CHECK-VSCALE-2-NEXT: Cost Model: Found costs of 2 for: %mask_nxv32i1_i64 = call <vscale x 32 x i1> @llvm.get.active.lane.mask.nxv32i1.i64(i64 undef, i64 undef)
947+
; CHECK-VSCALE-2-NEXT: Cost Model: Found costs of 1 for: %mask_nxv16i1_i16 = call <vscale x 16 x i1> @llvm.get.active.lane.mask.nxv16i1.i16(i16 undef, i16 undef)
948948
; CHECK-VSCALE-2-NEXT: Cost Model: Found costs of 32 for: %mask_v16i1_i64 = call <16 x i1> @llvm.get.active.lane.mask.v16i1.i64(i64 undef, i64 undef)
949949
; CHECK-VSCALE-2-NEXT: Cost Model: Found costs of 16 for: %mask_v8i1_i64 = call <8 x i1> @llvm.get.active.lane.mask.v8i1.i64(i64 undef, i64 undef)
950950
; CHECK-VSCALE-2-NEXT: Cost Model: Found costs of 8 for: %mask_v4i1_i64 = call <4 x i1> @llvm.get.active.lane.mask.v4i1.i64(i64 undef, i64 undef)
@@ -966,8 +966,8 @@ define void @get_lane_mask() #0 {
966966
; TYPE_BASED_ONLY-NEXT: Cost Model: Found costs of 1 for: %mask_nxv8i1_i32 = call <vscale x 8 x i1> @llvm.get.active.lane.mask.nxv8i1.i32(i32 undef, i32 undef)
967967
; TYPE_BASED_ONLY-NEXT: Cost Model: Found costs of 1 for: %mask_nxv4i1_i32 = call <vscale x 4 x i1> @llvm.get.active.lane.mask.nxv4i1.i32(i32 undef, i32 undef)
968968
; TYPE_BASED_ONLY-NEXT: Cost Model: Found costs of 1 for: %mask_nxv2i1_i32 = call <vscale x 2 x i1> @llvm.get.active.lane.mask.nxv2i1.i32(i32 undef, i32 undef)
969-
; TYPE_BASED_ONLY-NEXT: Cost Model: Found costs of RThru:48 CodeSize:33 Lat:33 SizeLat:33 for: %mask_nxv32i1_i64 = call <vscale x 32 x i1> @llvm.get.active.lane.mask.nxv32i1.i64(i64 undef, i64 undef)
970-
; TYPE_BASED_ONLY-NEXT: Cost Model: Found costs of RThru:6 CodeSize:5 Lat:5 SizeLat:5 for: %mask_nxv16i1_i16 = call <vscale x 16 x i1> @llvm.get.active.lane.mask.nxv16i1.i16(i16 undef, i16 undef)
969+
; TYPE_BASED_ONLY-NEXT: Cost Model: Found costs of 2 for: %mask_nxv32i1_i64 = call <vscale x 32 x i1> @llvm.get.active.lane.mask.nxv32i1.i64(i64 undef, i64 undef)
970+
; TYPE_BASED_ONLY-NEXT: Cost Model: Found costs of 1 for: %mask_nxv16i1_i16 = call <vscale x 16 x i1> @llvm.get.active.lane.mask.nxv16i1.i16(i16 undef, i16 undef)
971971
; TYPE_BASED_ONLY-NEXT: Cost Model: Found costs of 32 for: %mask_v16i1_i64 = call <16 x i1> @llvm.get.active.lane.mask.v16i1.i64(i64 undef, i64 undef)
972972
; TYPE_BASED_ONLY-NEXT: Cost Model: Found costs of 16 for: %mask_v8i1_i64 = call <8 x i1> @llvm.get.active.lane.mask.v8i1.i64(i64 undef, i64 undef)
973973
; TYPE_BASED_ONLY-NEXT: Cost Model: Found costs of 8 for: %mask_v4i1_i64 = call <4 x i1> @llvm.get.active.lane.mask.v4i1.i64(i64 undef, i64 undef)

0 commit comments

Comments
 (0)