Skip to content

Commit ae438d1

Browse files
committed
Lower scalar FP converts to SVE
1 parent ba1deef commit ae438d1

7 files changed

+880
-294
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 93 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1454,8 +1454,12 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
14541454
setOperationAction(ISD::INSERT_SUBVECTOR, VT, Custom);
14551455
setOperationAction(ISD::UINT_TO_FP, VT, Custom);
14561456
setOperationAction(ISD::SINT_TO_FP, VT, Custom);
1457+
setOperationAction(ISD::STRICT_UINT_TO_FP, VT, Custom);
1458+
setOperationAction(ISD::STRICT_SINT_TO_FP, VT, Custom);
14571459
setOperationAction(ISD::FP_TO_UINT, VT, Custom);
14581460
setOperationAction(ISD::FP_TO_SINT, VT, Custom);
1461+
setOperationAction(ISD::STRICT_FP_TO_UINT, VT, Custom);
1462+
setOperationAction(ISD::STRICT_FP_TO_SINT, VT, Custom);
14591463
setOperationAction(ISD::MLOAD, VT, Custom);
14601464
setOperationAction(ISD::MUL, VT, Custom);
14611465
setOperationAction(ISD::MULHS, VT, Custom);
@@ -2138,6 +2142,8 @@ void AArch64TargetLowering::addTypeForFixedLengthSVE(MVT VT) {
21382142
setOperationAction(ISD::FP_ROUND, VT, Default);
21392143
setOperationAction(ISD::FP_TO_SINT, VT, Default);
21402144
setOperationAction(ISD::FP_TO_UINT, VT, Default);
2145+
setOperationAction(ISD::STRICT_FP_TO_SINT, VT, Default);
2146+
setOperationAction(ISD::STRICT_FP_TO_UINT, VT, Default);
21412147
setOperationAction(ISD::FRINT, VT, Default);
21422148
setOperationAction(ISD::LRINT, VT, Default);
21432149
setOperationAction(ISD::LLRINT, VT, Default);
@@ -2164,6 +2170,7 @@ void AArch64TargetLowering::addTypeForFixedLengthSVE(MVT VT) {
21642170
setOperationAction(ISD::SIGN_EXTEND, VT, Default);
21652171
setOperationAction(ISD::SIGN_EXTEND_INREG, VT, Default);
21662172
setOperationAction(ISD::SINT_TO_FP, VT, Default);
2173+
setOperationAction(ISD::STRICT_SINT_TO_FP, VT, Default);
21672174
setOperationAction(ISD::SMAX, VT, Default);
21682175
setOperationAction(ISD::SMIN, VT, Default);
21692176
setOperationAction(ISD::SPLAT_VECTOR, VT, Default);
@@ -2174,6 +2181,7 @@ void AArch64TargetLowering::addTypeForFixedLengthSVE(MVT VT) {
21742181
setOperationAction(ISD::TRUNCATE, VT, Default);
21752182
setOperationAction(ISD::UDIV, VT, Default);
21762183
setOperationAction(ISD::UINT_TO_FP, VT, Default);
2184+
setOperationAction(ISD::STRICT_UINT_TO_FP, VT, Default);
21772185
setOperationAction(ISD::UMAX, VT, Default);
21782186
setOperationAction(ISD::UMIN, VT, Default);
21792187
setOperationAction(ISD::VECREDUCE_ADD, VT, Default);
@@ -4550,9 +4558,10 @@ SDValue AArch64TargetLowering::LowerVectorFP_TO_INT(SDValue Op,
45504558
EVT VT = Op.getValueType();
45514559

45524560
if (VT.isScalableVector()) {
4553-
unsigned Opcode = Op.getOpcode() == ISD::FP_TO_UINT
4554-
? AArch64ISD::FCVTZU_MERGE_PASSTHRU
4555-
: AArch64ISD::FCVTZS_MERGE_PASSTHRU;
4561+
unsigned Opc = Op.getOpcode();
4562+
bool IsSigned = Opc == ISD::FP_TO_SINT || Opc == ISD::STRICT_FP_TO_SINT;
4563+
unsigned Opcode = IsSigned ? AArch64ISD::FCVTZS_MERGE_PASSTHRU
4564+
: AArch64ISD::FCVTZU_MERGE_PASSTHRU;
45564565
return LowerToPredicatedOp(Op, DAG, Opcode);
45574566
}
45584567

@@ -4628,6 +4637,51 @@ SDValue AArch64TargetLowering::LowerVectorFP_TO_INT(SDValue Op,
46284637
return Op;
46294638
}
46304639

4640+
/// Returns true if \p VT is one of the scalar types accepted for lowering a
/// scalar FP<->int conversion through SVE: i32, i64, f16, f32 or f64.
/// Any value type that is not "simple" is rejected.
static bool CanLowerToScalarSVEFPIntConversion(EVT VT) {
4641+
// Only simple value types can be compared against the MVT list below
// (getSimpleVT() is called on VT further down).
if (!VT.isSimple())
4642+
return false;
4643+
// There are SVE instructions that can convert to/from all pairs of these int
4644+
// and float types. Note: We don't bother with i8 or i16 as those are illegal
4645+
// types for scalars.
4646+
return is_contained({MVT::i32, MVT::i64, MVT::f16, MVT::f32, MVT::f64},
4647+
VT.getSimpleVT().SimpleTy);
4648+
}
4649+
4650+
/// Lowers a scalar FP conversion (to/from) int to SVE.
///
/// The scalar source is inserted into element 0 of an otherwise undefined
/// scalable vector, the same opcode as \p Op is re-emitted on that vector, and
/// element 0 of the vector result is extracted as the scalar result.
///
/// \param Op  The scalar conversion node. For strict FP opcodes the value to
///            convert is read from operand 1 and operand 0 is forwarded to the
///            vector node (presumably the chain -- see the MVT::Other result
///            requested below); otherwise the value is operand 0.
/// \param DAG The SelectionDAG used to build the replacement nodes.
/// \returns   The extracted scalar for non-strict opcodes; for strict opcodes
///            a merge of the extracted scalar and result #1 of the vector
///            conversion node.
static SDValue LowerScalarFPConversionToSVE(SDValue Op, SelectionDAG &DAG) {
4652+
bool IsStrict = Op->isStrictFPOpcode();
4653+
// Strict nodes carry an extra leading operand, so the value being converted
// sits at index 1 instead of 0.
SDValue SrcVal = Op.getOperand(IsStrict ? 1 : 0);
4654+
EVT SrcTy = SrcVal.getValueType();
4655+
EVT DestTy = Op.getValueType();
4656+
EVT SrcVecTy;
4657+
EVT DestVecTy;
4658+
// Use a packed vector for the larger type.
4659+
// Note: For conversions such as FCVTZS_ZPmZ_DtoS, and UCVTF_ZPmZ_StoD that
4660+
// notionally take or return a nxv2i32 type we must instead use a nxv4i32, as
4661+
// (unlike floats) nxv2i32 is an illegal unpacked type.
4662+
if (DestTy.bitsGT(SrcTy)) {
4663+
// Widening conversion: the destination picks the packed vector type; the
// source reuses that vector shape with its own element type, except for an
// i32 source, which gets its own packed type (see the note above).
DestVecTy = getPackedSVEVectorVT(DestTy);
4664+
SrcVecTy = SrcTy == MVT::i32 ? getPackedSVEVectorVT(SrcTy)
4665+
: DestVecTy.changeVectorElementType(SrcTy);
4666+
} else {
4667+
// Same-width or narrowing conversion: mirror image of the branch above,
// with the source choosing the packed vector type.
SrcVecTy = getPackedSVEVectorVT(SrcTy);
4668+
DestVecTy = DestTy == MVT::i32 ? getPackedSVEVectorVT(DestTy)
4669+
: SrcVecTy.changeVectorElementType(DestTy);
4670+
}
4671+
SDLoc dl(Op);
4672+
SDValue ZeroIdx = DAG.getVectorIdxConstant(0, dl);
4673+
// Place the scalar in element 0; every other element is left undefined.
SDValue Vec = DAG.getNode(ISD::INSERT_VECTOR_ELT, dl, SrcVecTy,
4674+
DAG.getUNDEF(SrcVecTy), SrcVal, ZeroIdx);
4675+
// Re-issue the original opcode on the vector. The strict form additionally
// requests an MVT::Other result and passes through Op's operand 0.
Vec = IsStrict ? DAG.getNode(Op.getOpcode(), dl, {DestVecTy, MVT::Other},
4676+
{Op.getOperand(0), Vec})
4677+
: DAG.getNode(Op.getOpcode(), dl, DestVecTy, Vec);
4678+
// Only element 0 holds the converted value, so extract it as the scalar result.
SDValue Scalar =
4679+
DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, Op.getValueType(), Vec, ZeroIdx);
4680+
// Strict nodes produce two results, so return the scalar together with
// result #1 of the vector conversion node.
if (IsStrict)
4681+
return DAG.getMergeValues({Scalar, Vec.getValue(1)}, dl);
4682+
return Scalar;
4683+
}
4684+
46314685
SDValue AArch64TargetLowering::LowerFP_TO_INT(SDValue Op,
46324686
SelectionDAG &DAG) const {
46334687
bool IsStrict = Op->isStrictFPOpcode();
@@ -4636,6 +4690,12 @@ SDValue AArch64TargetLowering::LowerFP_TO_INT(SDValue Op,
46364690
if (SrcVal.getValueType().isVector())
46374691
return LowerVectorFP_TO_INT(Op, DAG);
46384692

4693+
if (!Subtarget->isNeonAvailable() &&
4694+
Subtarget->isSVEorStreamingSVEAvailable() &&
4695+
CanLowerToScalarSVEFPIntConversion(SrcVal.getValueType()) &&
4696+
CanLowerToScalarSVEFPIntConversion(Op.getValueType()))
4697+
return LowerScalarFPConversionToSVE(Op, DAG);
4698+
46394699
// f16 conversions are promoted to f32 when full fp16 is not supported.
46404700
if ((SrcVal.getValueType() == MVT::f16 && !Subtarget->hasFullFP16()) ||
46414701
SrcVal.getValueType() == MVT::bf16) {
@@ -4939,6 +4999,12 @@ SDValue AArch64TargetLowering::LowerINT_TO_FP(SDValue Op,
49394999
bool IsStrict = Op->isStrictFPOpcode();
49405000
SDValue SrcVal = Op.getOperand(IsStrict ? 1 : 0);
49415001

5002+
if (!Subtarget->isNeonAvailable() &&
5003+
Subtarget->isSVEorStreamingSVEAvailable() &&
5004+
CanLowerToScalarSVEFPIntConversion(SrcVal.getValueType()) &&
5005+
CanLowerToScalarSVEFPIntConversion(Op.getValueType()))
5006+
return LowerScalarFPConversionToSVE(Op, DAG);
5007+
49425008
bool IsSigned = Op->getOpcode() == ISD::STRICT_SINT_TO_FP ||
49435009
Op->getOpcode() == ISD::SINT_TO_FP;
49445010

@@ -28295,7 +28361,21 @@ SDValue AArch64TargetLowering::LowerToPredicatedOp(SDValue Op,
2829528361
unsigned NewOp) const {
2829628362
EVT VT = Op.getValueType();
2829728363
SDLoc DL(Op);
28298-
auto Pg = getPredicateForVector(DAG, DL, VT);
28364+
SDValue Pg;
28365+
28366+
// FCVTZS_ZPmZ_DtoS and FCVTZU_ZPmZ_DtoS are special cases. These operations
28367+
// return nxv4i32 rather than the correct nxv2i32, as nxv2i32 is an illegal
28368+
// unpacked type. So, in this case, we take the predicate size from the
28369+
// operand.
28370+
SDValue LastOp{};
28371+
if ((NewOp == AArch64ISD::FCVTZU_MERGE_PASSTHRU ||
28372+
NewOp == AArch64ISD::FCVTZS_MERGE_PASSTHRU) &&
28373+
VT == MVT::nxv4i32 &&
28374+
(LastOp = Op->ops().back().get()).getValueType() == MVT::nxv2f64) {
28375+
Pg = getPredicateForVector(DAG, DL, LastOp.getValueType());
28376+
} else {
28377+
Pg = getPredicateForVector(DAG, DL, VT);
28378+
}
2829928379

2830028380
if (VT.isFixedLengthVector()) {
2830128381
assert(isTypeLegal(VT) && "Expected only legal fixed-width types");
@@ -28331,7 +28411,12 @@ SDValue AArch64TargetLowering::LowerToPredicatedOp(SDValue Op,
2833128411
assert(VT.isScalableVector() && "Only expect to lower scalable vector op!");
2833228412

2833328413
SmallVector<SDValue, 4> Operands = {Pg};
28414+
SDValue Chain{};
2833428415
for (const SDValue &V : Op->op_values()) {
28416+
if (!isa<CondCodeSDNode>(V) && V.getValueType() == MVT::Other) {
28417+
Chain = V;
28418+
continue;
28419+
}
2833528420
assert((!V.getValueType().isVector() ||
2833628421
V.getValueType().isScalableVector()) &&
2833728422
"Only scalable vectors are supported!");
@@ -28341,7 +28426,10 @@ SDValue AArch64TargetLowering::LowerToPredicatedOp(SDValue Op,
2834128426
if (isMergePassthruOpcode(NewOp))
2834228427
Operands.push_back(DAG.getUNDEF(VT));
2834328428

28344-
return DAG.getNode(NewOp, DL, VT, Operands, Op->getFlags());
28429+
auto NewNode = DAG.getNode(NewOp, DL, VT, Operands, Op->getFlags());
28430+
if (Chain)
28431+
return DAG.getMergeValues({NewNode, Chain}, DL);
28432+
return NewNode;
2834528433
}
2834628434

2834728435
// If a fixed length vector operation has no side effects when applied to

llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td

Lines changed: 2 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -2338,8 +2338,8 @@ let Predicates = [HasSVEorSME] in {
23382338
defm UCVTF_ZPmZ_DtoH : sve_fp_2op_p_zd< 0b0110111, "ucvtf", ZPR64, ZPR16, int_aarch64_sve_ucvtf_f16i64, AArch64ucvtf_mt, nxv2f16, nxv2i1, nxv2i64, ElementSizeD>;
23392339
defm SCVTF_ZPmZ_DtoD : sve_fp_2op_p_zd< 0b1110110, "scvtf", ZPR64, ZPR64, null_frag, AArch64scvtf_mt, nxv2f64, nxv2i1, nxv2i64, ElementSizeD>;
23402340
defm UCVTF_ZPmZ_DtoD : sve_fp_2op_p_zd< 0b1110111, "ucvtf", ZPR64, ZPR64, null_frag, AArch64ucvtf_mt, nxv2f64, nxv2i1, nxv2i64, ElementSizeD>;
2341-
defm FCVTZS_ZPmZ_DtoS : sve_fp_2op_p_zd< 0b1111000, "fcvtzs", ZPR64, ZPR32, int_aarch64_sve_fcvtzs_i32f64, null_frag, nxv4i32, nxv2i1, nxv2f64, ElementSizeD>;
2342-
defm FCVTZU_ZPmZ_DtoS : sve_fp_2op_p_zd< 0b1111001, "fcvtzu", ZPR64, ZPR32, int_aarch64_sve_fcvtzu_i32f64, null_frag, nxv4i32, nxv2i1, nxv2f64, ElementSizeD>;
2341+
defm FCVTZS_ZPmZ_DtoS : sve_fp_2op_p_zd< 0b1111000, "fcvtzs", ZPR64, ZPR32, int_aarch64_sve_fcvtzs_i32f64, AArch64fcvtzs_mt, nxv4i32, nxv2i1, nxv2f64, ElementSizeD>;
2342+
defm FCVTZU_ZPmZ_DtoS : sve_fp_2op_p_zd< 0b1111001, "fcvtzu", ZPR64, ZPR32, int_aarch64_sve_fcvtzu_i32f64, AArch64fcvtzu_mt, nxv4i32, nxv2i1, nxv2f64, ElementSizeD>;
23432343
defm FCVTZS_ZPmZ_StoD : sve_fp_2op_p_zd< 0b1111100, "fcvtzs", ZPR32, ZPR64, int_aarch64_sve_fcvtzs_i64f32, AArch64fcvtzs_mt, nxv2i64, nxv2i1, nxv2f32, ElementSizeD>;
23442344
defm FCVTZS_ZPmZ_HtoS : sve_fp_2op_p_zd< 0b0111100, "fcvtzs", ZPR16, ZPR32, int_aarch64_sve_fcvtzs_i32f16, AArch64fcvtzs_mt, nxv4i32, nxv4i1, nxv4f16, ElementSizeS>;
23452345
defm FCVTZS_ZPmZ_HtoD : sve_fp_2op_p_zd< 0b0111110, "fcvtzs", ZPR16, ZPR64, int_aarch64_sve_fcvtzs_i64f16, AArch64fcvtzs_mt, nxv2i64, nxv2i1, nxv2f16, ElementSizeD>;
@@ -2421,42 +2421,6 @@ let Predicates = [HasSVEorSME] in {
24212421
defm FSQRT_ZPmZ : sve_fp_2op_p_zd_HSD<0b01101, "fsqrt", AArch64fsqrt_mt>;
24222422
} // End HasSVEorSME
24232423

2424-
// Helper for creating scalar fp -> int -> fp conversions using SVE.
2425-
class sve_scalar_fp_int_fp_cvt
2426-
<Instruction PTRUE, Instruction FROM_INT, Instruction TO_INT, SubRegIndex sub>
2427-
: OutPatFrag<(ops node: $Rn),
2428-
(EXTRACT_SUBREG
2429-
(FROM_INT (IMPLICIT_DEF), (PTRUE 1),
2430-
(TO_INT (IMPLICIT_DEF), (PTRUE 1),
2431-
(INSERT_SUBREG (IMPLICIT_DEF), $Rn, sub))), sub)>;
2432-
2433-
// Some scalar float -> int -> float conversion patterns where we want to keep
2434-
// the int values in FP registers to avoid costly GPR <-> FPR register
2435-
// transfers using SVE instructions. Only used when NEON is not available (e.g.
2436-
// in streaming functions).
2437-
// TODO: When +sme2p2 is available Neon single-element vectors should be preferred.
2438-
def HasNoNEON : Predicate<"!Subtarget->isNeonAvailable()">;
2439-
let Predicates = [HasSVEorSME, HasNoNEON] in {
2440-
def : Pat<
2441-
(f64 (sint_to_fp (i64 (fp_to_sint f64:$Rn)))),
2442-
(sve_scalar_fp_int_fp_cvt<PTRUE_D, SCVTF_ZPmZ_DtoD, FCVTZS_ZPmZ_DtoD, dsub> $Rn)>;
2443-
def : Pat<
2444-
(f64 (uint_to_fp (i64 (fp_to_uint f64:$Rn)))),
2445-
(sve_scalar_fp_int_fp_cvt<PTRUE_D, UCVTF_ZPmZ_DtoD, FCVTZU_ZPmZ_DtoD, dsub> $Rn)>;
2446-
def : Pat<
2447-
(f32 (sint_to_fp (i32 (fp_to_sint f32:$Rn)))),
2448-
(sve_scalar_fp_int_fp_cvt<PTRUE_S, SCVTF_ZPmZ_StoS, FCVTZS_ZPmZ_StoS, ssub> $Rn)>;
2449-
def : Pat<
2450-
(f32 (uint_to_fp (i32 (fp_to_uint f32:$Rn)))),
2451-
(sve_scalar_fp_int_fp_cvt<PTRUE_S, UCVTF_ZPmZ_StoS, FCVTZU_ZPmZ_StoS, ssub> $Rn)>;
2452-
def : Pat<
2453-
(f16 (sint_to_fp (i32 (fp_to_sint f16:$Rn)))),
2454-
(sve_scalar_fp_int_fp_cvt<PTRUE_H, SCVTF_ZPmZ_HtoH, FCVTZS_ZPmZ_HtoH, hsub> $Rn)>;
2455-
def : Pat<
2456-
(f16 (uint_to_fp (i32 (fp_to_uint f16:$Rn)))),
2457-
(sve_scalar_fp_int_fp_cvt<PTRUE_H, UCVTF_ZPmZ_HtoH, FCVTZU_ZPmZ_HtoH, hsub> $Rn)>;
2458-
} // End HasSVEorSME, HasNoNEON
2459-
24602424
let Predicates = [HasBF16, HasSVEorSME] in {
24612425
defm BFDOT_ZZZ : sve_float_dot<0b1, 0b0, ZPR32, ZPR16, "bfdot", nxv8bf16, int_aarch64_sve_bfdot>;
24622426
defm BFDOT_ZZI : sve_float_dot_indexed<0b1, 0b00, ZPR16, ZPR3b16, "bfdot", nxv8bf16, int_aarch64_sve_bfdot_lane_v2>;

llvm/test/CodeGen/AArch64/sve-streaming-mode-cvt-fp-int-fp.ll

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ target triple = "aarch64-unknown-linux-gnu"
88
define double @t1(double %x) {
99
; CHECK-LABEL: t1:
1010
; CHECK: // %bb.0: // %entry
11-
; CHECK-NEXT: ptrue p0.d, vl1
11+
; CHECK-NEXT: ptrue p0.d
1212
; CHECK-NEXT: // kill: def $d0 killed $d0 def $z0
1313
; CHECK-NEXT: fcvtzs z0.d, p0/m, z0.d
1414
; CHECK-NEXT: scvtf z0.d, p0/m, z0.d
@@ -29,7 +29,7 @@ entry:
2929
define float @t2(float %x) {
3030
; CHECK-LABEL: t2:
3131
; CHECK: // %bb.0: // %entry
32-
; CHECK-NEXT: ptrue p0.s, vl1
32+
; CHECK-NEXT: ptrue p0.s
3333
; CHECK-NEXT: // kill: def $s0 killed $s0 def $z0
3434
; CHECK-NEXT: fcvtzs z0.s, p0/m, z0.s
3535
; CHECK-NEXT: scvtf z0.s, p0/m, z0.s
@@ -50,10 +50,10 @@ entry:
5050
define half @t3(half %x) {
5151
; CHECK-LABEL: t3:
5252
; CHECK: // %bb.0: // %entry
53-
; CHECK-NEXT: ptrue p0.h, vl1
53+
; CHECK-NEXT: ptrue p0.s
5454
; CHECK-NEXT: // kill: def $h0 killed $h0 def $z0
55-
; CHECK-NEXT: fcvtzs z0.h, p0/m, z0.h
56-
; CHECK-NEXT: scvtf z0.h, p0/m, z0.h
55+
; CHECK-NEXT: fcvtzs z0.s, p0/m, z0.h
56+
; CHECK-NEXT: scvtf z0.h, p0/m, z0.s
5757
; CHECK-NEXT: // kill: def $h0 killed $h0 killed $z0
5858
; CHECK-NEXT: ret
5959
;
@@ -73,7 +73,7 @@ entry:
7373
define double @t4(double %x) {
7474
; CHECK-LABEL: t4:
7575
; CHECK: // %bb.0: // %entry
76-
; CHECK-NEXT: ptrue p0.d, vl1
76+
; CHECK-NEXT: ptrue p0.d
7777
; CHECK-NEXT: // kill: def $d0 killed $d0 def $z0
7878
; CHECK-NEXT: fcvtzu z0.d, p0/m, z0.d
7979
; CHECK-NEXT: ucvtf z0.d, p0/m, z0.d
@@ -94,7 +94,7 @@ entry:
9494
define float @t5(float %x) {
9595
; CHECK-LABEL: t5:
9696
; CHECK: // %bb.0: // %entry
97-
; CHECK-NEXT: ptrue p0.s, vl1
97+
; CHECK-NEXT: ptrue p0.s
9898
; CHECK-NEXT: // kill: def $s0 killed $s0 def $z0
9999
; CHECK-NEXT: fcvtzu z0.s, p0/m, z0.s
100100
; CHECK-NEXT: ucvtf z0.s, p0/m, z0.s
@@ -115,10 +115,10 @@ entry:
115115
define half @t6(half %x) {
116116
; CHECK-LABEL: t6:
117117
; CHECK: // %bb.0: // %entry
118-
; CHECK-NEXT: ptrue p0.h, vl1
118+
; CHECK-NEXT: ptrue p0.s
119119
; CHECK-NEXT: // kill: def $h0 killed $h0 def $z0
120-
; CHECK-NEXT: fcvtzu z0.h, p0/m, z0.h
121-
; CHECK-NEXT: ucvtf z0.h, p0/m, z0.h
120+
; CHECK-NEXT: fcvtzu z0.s, p0/m, z0.h
121+
; CHECK-NEXT: ucvtf z0.h, p0/m, z0.s
122122
; CHECK-NEXT: // kill: def $h0 killed $h0 killed $z0
123123
; CHECK-NEXT: ret
124124
;

0 commit comments

Comments
 (0)