Skip to content

Commit aea6409

Browse files
committed
Lower scalar FP converts to SVE
1 parent 974c5ae commit aea6409

7 files changed

+880
-294
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 93 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1454,8 +1454,12 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
14541454
setOperationAction(ISD::INSERT_SUBVECTOR, VT, Custom);
14551455
setOperationAction(ISD::UINT_TO_FP, VT, Custom);
14561456
setOperationAction(ISD::SINT_TO_FP, VT, Custom);
1457+
setOperationAction(ISD::STRICT_UINT_TO_FP, VT, Custom);
1458+
setOperationAction(ISD::STRICT_SINT_TO_FP, VT, Custom);
14571459
setOperationAction(ISD::FP_TO_UINT, VT, Custom);
14581460
setOperationAction(ISD::FP_TO_SINT, VT, Custom);
1461+
setOperationAction(ISD::STRICT_FP_TO_UINT, VT, Custom);
1462+
setOperationAction(ISD::STRICT_FP_TO_SINT, VT, Custom);
14591463
setOperationAction(ISD::MLOAD, VT, Custom);
14601464
setOperationAction(ISD::MUL, VT, Custom);
14611465
setOperationAction(ISD::MULHS, VT, Custom);
@@ -2138,6 +2142,8 @@ void AArch64TargetLowering::addTypeForFixedLengthSVE(MVT VT) {
21382142
setOperationAction(ISD::FP_ROUND, VT, Default);
21392143
setOperationAction(ISD::FP_TO_SINT, VT, Default);
21402144
setOperationAction(ISD::FP_TO_UINT, VT, Default);
2145+
setOperationAction(ISD::STRICT_FP_TO_SINT, VT, Default);
2146+
setOperationAction(ISD::STRICT_FP_TO_UINT, VT, Default);
21412147
setOperationAction(ISD::FRINT, VT, Default);
21422148
setOperationAction(ISD::LRINT, VT, Default);
21432149
setOperationAction(ISD::LLRINT, VT, Default);
@@ -2164,6 +2170,7 @@ void AArch64TargetLowering::addTypeForFixedLengthSVE(MVT VT) {
21642170
setOperationAction(ISD::SIGN_EXTEND, VT, Default);
21652171
setOperationAction(ISD::SIGN_EXTEND_INREG, VT, Default);
21662172
setOperationAction(ISD::SINT_TO_FP, VT, Default);
2173+
setOperationAction(ISD::STRICT_SINT_TO_FP, VT, Default);
21672174
setOperationAction(ISD::SMAX, VT, Default);
21682175
setOperationAction(ISD::SMIN, VT, Default);
21692176
setOperationAction(ISD::SPLAT_VECTOR, VT, Default);
@@ -2174,6 +2181,7 @@ void AArch64TargetLowering::addTypeForFixedLengthSVE(MVT VT) {
21742181
setOperationAction(ISD::TRUNCATE, VT, Default);
21752182
setOperationAction(ISD::UDIV, VT, Default);
21762183
setOperationAction(ISD::UINT_TO_FP, VT, Default);
2184+
setOperationAction(ISD::STRICT_UINT_TO_FP, VT, Default);
21772185
setOperationAction(ISD::UMAX, VT, Default);
21782186
setOperationAction(ISD::UMIN, VT, Default);
21792187
setOperationAction(ISD::VECREDUCE_ADD, VT, Default);
@@ -4550,9 +4558,10 @@ SDValue AArch64TargetLowering::LowerVectorFP_TO_INT(SDValue Op,
45504558
EVT VT = Op.getValueType();
45514559

45524560
if (VT.isScalableVector()) {
4553-
unsigned Opcode = Op.getOpcode() == ISD::FP_TO_UINT
4554-
? AArch64ISD::FCVTZU_MERGE_PASSTHRU
4555-
: AArch64ISD::FCVTZS_MERGE_PASSTHRU;
4561+
unsigned Opc = Op.getOpcode();
4562+
bool IsSigned = Opc == ISD::FP_TO_SINT || Opc == ISD::STRICT_FP_TO_SINT;
4563+
unsigned Opcode = IsSigned ? AArch64ISD::FCVTZS_MERGE_PASSTHRU
4564+
: AArch64ISD::FCVTZU_MERGE_PASSTHRU;
45564565
return LowerToPredicatedOp(Op, DAG, Opcode);
45574566
}
45584567

@@ -4628,6 +4637,51 @@ SDValue AArch64TargetLowering::LowerVectorFP_TO_INT(SDValue Op,
46284637
return Op;
46294638
}
46304639

4640+
static bool CanLowerToScalarSVEFPIntConversion(EVT VT) {
4641+
if (!VT.isSimple())
4642+
return false;
4643+
// There are SVE instructions that can convert to/from all pairs of these int
4644+
// and float types. Note: We don't bother with i8 or i16 as those are illegal
4645+
// types for scalars.
4646+
return is_contained({MVT::i32, MVT::i64, MVT::f16, MVT::f32, MVT::f64},
4647+
VT.getSimpleVT().SimpleTy);
4648+
}
4649+
4650+
/// Lowers a scalar FP conversion (to/from) int to SVE.
4651+
static SDValue LowerScalarFPConversionToSVE(SDValue Op, SelectionDAG &DAG) {
4652+
bool IsStrict = Op->isStrictFPOpcode();
4653+
SDValue SrcVal = Op.getOperand(IsStrict ? 1 : 0);
4654+
EVT SrcTy = SrcVal.getValueType();
4655+
EVT DestTy = Op.getValueType();
4656+
EVT SrcVecTy;
4657+
EVT DestVecTy;
4658+
// Use a packed vector for the larger type.
4659+
// Note: For conversions such as FCVTZS_ZPmZ_DtoS, and UCVTF_ZPmZ_StoD that
4660+
// notionally take or return a nxv2i32 type we must instead use a nxv4i32, as
4661+
// (unlike floats) nxv2i32 is an illegal unpacked type.
4662+
if (DestTy.bitsGT(SrcTy)) {
4663+
DestVecTy = getPackedSVEVectorVT(DestTy);
4664+
SrcVecTy = SrcTy == MVT::i32 ? getPackedSVEVectorVT(SrcTy)
4665+
: DestVecTy.changeVectorElementType(SrcTy);
4666+
} else {
4667+
SrcVecTy = getPackedSVEVectorVT(SrcTy);
4668+
DestVecTy = DestTy == MVT::i32 ? getPackedSVEVectorVT(DestTy)
4669+
: SrcVecTy.changeVectorElementType(DestTy);
4670+
}
4671+
SDLoc dl(Op);
4672+
SDValue ZeroIdx = DAG.getVectorIdxConstant(0, dl);
4673+
SDValue Vec = DAG.getNode(ISD::INSERT_VECTOR_ELT, dl, SrcVecTy,
4674+
DAG.getUNDEF(SrcVecTy), SrcVal, ZeroIdx);
4675+
Vec = IsStrict ? DAG.getNode(Op.getOpcode(), dl, {DestVecTy, MVT::Other},
4676+
{Op.getOperand(0), Vec})
4677+
: DAG.getNode(Op.getOpcode(), dl, DestVecTy, Vec);
4678+
SDValue Scalar =
4679+
DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, Op.getValueType(), Vec, ZeroIdx);
4680+
if (IsStrict)
4681+
return DAG.getMergeValues({Scalar, Vec.getValue(1)}, dl);
4682+
return Scalar;
4683+
}
4684+
46314685
SDValue AArch64TargetLowering::LowerFP_TO_INT(SDValue Op,
46324686
SelectionDAG &DAG) const {
46334687
bool IsStrict = Op->isStrictFPOpcode();
@@ -4636,6 +4690,12 @@ SDValue AArch64TargetLowering::LowerFP_TO_INT(SDValue Op,
46364690
if (SrcVal.getValueType().isVector())
46374691
return LowerVectorFP_TO_INT(Op, DAG);
46384692

4693+
if (!Subtarget->isNeonAvailable() &&
4694+
Subtarget->isSVEorStreamingSVEAvailable() &&
4695+
CanLowerToScalarSVEFPIntConversion(SrcVal.getValueType()) &&
4696+
CanLowerToScalarSVEFPIntConversion(Op.getValueType()))
4697+
return LowerScalarFPConversionToSVE(Op, DAG);
4698+
46394699
// f16 conversions are promoted to f32 when full fp16 is not supported.
46404700
if ((SrcVal.getValueType() == MVT::f16 && !Subtarget->hasFullFP16()) ||
46414701
SrcVal.getValueType() == MVT::bf16) {
@@ -4939,6 +4999,12 @@ SDValue AArch64TargetLowering::LowerINT_TO_FP(SDValue Op,
49394999
bool IsStrict = Op->isStrictFPOpcode();
49405000
SDValue SrcVal = Op.getOperand(IsStrict ? 1 : 0);
49415001

5002+
if (!Subtarget->isNeonAvailable() &&
5003+
Subtarget->isSVEorStreamingSVEAvailable() &&
5004+
CanLowerToScalarSVEFPIntConversion(SrcVal.getValueType()) &&
5005+
CanLowerToScalarSVEFPIntConversion(Op.getValueType()))
5006+
return LowerScalarFPConversionToSVE(Op, DAG);
5007+
49425008
bool IsSigned = Op->getOpcode() == ISD::STRICT_SINT_TO_FP ||
49435009
Op->getOpcode() == ISD::SINT_TO_FP;
49445010

@@ -28293,7 +28359,21 @@ SDValue AArch64TargetLowering::LowerToPredicatedOp(SDValue Op,
2829328359
unsigned NewOp) const {
2829428360
EVT VT = Op.getValueType();
2829528361
SDLoc DL(Op);
28296-
auto Pg = getPredicateForVector(DAG, DL, VT);
28362+
SDValue Pg;
28363+
28364+
// FCVTZS_ZPmZ_DtoS and FCVTZU_ZPmZ_DtoS are special cases. These operations
28365+
// return nxv4i32 rather than the correct nxv2i32, as nxv2i32 is an illegal
28366+
// unpacked type. So, in this case, we take the predicate size from the
28367+
// operand.
28368+
SDValue LastOp{};
28369+
if ((NewOp == AArch64ISD::FCVTZU_MERGE_PASSTHRU ||
28370+
NewOp == AArch64ISD::FCVTZS_MERGE_PASSTHRU) &&
28371+
VT == MVT::nxv4i32 &&
28372+
(LastOp = Op->ops().back().get()).getValueType() == MVT::nxv2f64) {
28373+
Pg = getPredicateForVector(DAG, DL, LastOp.getValueType());
28374+
} else {
28375+
Pg = getPredicateForVector(DAG, DL, VT);
28376+
}
2829728377

2829828378
if (VT.isFixedLengthVector()) {
2829928379
assert(isTypeLegal(VT) && "Expected only legal fixed-width types");
@@ -28329,7 +28409,12 @@ SDValue AArch64TargetLowering::LowerToPredicatedOp(SDValue Op,
2832928409
assert(VT.isScalableVector() && "Only expect to lower scalable vector op!");
2833028410

2833128411
SmallVector<SDValue, 4> Operands = {Pg};
28412+
SDValue Chain{};
2833228413
for (const SDValue &V : Op->op_values()) {
28414+
if (!isa<CondCodeSDNode>(V) && V.getValueType() == MVT::Other) {
28415+
Chain = V;
28416+
continue;
28417+
}
2833328418
assert((!V.getValueType().isVector() ||
2833428419
V.getValueType().isScalableVector()) &&
2833528420
"Only scalable vectors are supported!");
@@ -28339,7 +28424,10 @@ SDValue AArch64TargetLowering::LowerToPredicatedOp(SDValue Op,
2833928424
if (isMergePassthruOpcode(NewOp))
2834028425
Operands.push_back(DAG.getUNDEF(VT));
2834128426

28342-
return DAG.getNode(NewOp, DL, VT, Operands, Op->getFlags());
28427+
auto NewNode = DAG.getNode(NewOp, DL, VT, Operands, Op->getFlags());
28428+
if (Chain)
28429+
return DAG.getMergeValues({NewNode, Chain}, DL);
28430+
return NewNode;
2834328431
}
2834428432

2834528433
// If a fixed length vector operation has no side effects when applied to

llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td

Lines changed: 2 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -2338,8 +2338,8 @@ let Predicates = [HasSVEorSME] in {
23382338
defm UCVTF_ZPmZ_DtoH : sve_fp_2op_p_zd< 0b0110111, "ucvtf", ZPR64, ZPR16, int_aarch64_sve_ucvtf_f16i64, AArch64ucvtf_mt, nxv2f16, nxv2i1, nxv2i64, ElementSizeD>;
23392339
defm SCVTF_ZPmZ_DtoD : sve_fp_2op_p_zd< 0b1110110, "scvtf", ZPR64, ZPR64, null_frag, AArch64scvtf_mt, nxv2f64, nxv2i1, nxv2i64, ElementSizeD>;
23402340
defm UCVTF_ZPmZ_DtoD : sve_fp_2op_p_zd< 0b1110111, "ucvtf", ZPR64, ZPR64, null_frag, AArch64ucvtf_mt, nxv2f64, nxv2i1, nxv2i64, ElementSizeD>;
2341-
defm FCVTZS_ZPmZ_DtoS : sve_fp_2op_p_zd< 0b1111000, "fcvtzs", ZPR64, ZPR32, int_aarch64_sve_fcvtzs_i32f64, null_frag, nxv4i32, nxv2i1, nxv2f64, ElementSizeD>;
2342-
defm FCVTZU_ZPmZ_DtoS : sve_fp_2op_p_zd< 0b1111001, "fcvtzu", ZPR64, ZPR32, int_aarch64_sve_fcvtzu_i32f64, null_frag, nxv4i32, nxv2i1, nxv2f64, ElementSizeD>;
2341+
defm FCVTZS_ZPmZ_DtoS : sve_fp_2op_p_zd< 0b1111000, "fcvtzs", ZPR64, ZPR32, int_aarch64_sve_fcvtzs_i32f64, AArch64fcvtzs_mt, nxv4i32, nxv2i1, nxv2f64, ElementSizeD>;
2342+
defm FCVTZU_ZPmZ_DtoS : sve_fp_2op_p_zd< 0b1111001, "fcvtzu", ZPR64, ZPR32, int_aarch64_sve_fcvtzu_i32f64, AArch64fcvtzu_mt, nxv4i32, nxv2i1, nxv2f64, ElementSizeD>;
23432343
defm FCVTZS_ZPmZ_StoD : sve_fp_2op_p_zd< 0b1111100, "fcvtzs", ZPR32, ZPR64, int_aarch64_sve_fcvtzs_i64f32, AArch64fcvtzs_mt, nxv2i64, nxv2i1, nxv2f32, ElementSizeD>;
23442344
defm FCVTZS_ZPmZ_HtoS : sve_fp_2op_p_zd< 0b0111100, "fcvtzs", ZPR16, ZPR32, int_aarch64_sve_fcvtzs_i32f16, AArch64fcvtzs_mt, nxv4i32, nxv4i1, nxv4f16, ElementSizeS>;
23452345
defm FCVTZS_ZPmZ_HtoD : sve_fp_2op_p_zd< 0b0111110, "fcvtzs", ZPR16, ZPR64, int_aarch64_sve_fcvtzs_i64f16, AArch64fcvtzs_mt, nxv2i64, nxv2i1, nxv2f16, ElementSizeD>;
@@ -2421,42 +2421,6 @@ let Predicates = [HasSVEorSME] in {
24212421
defm FSQRT_ZPmZ : sve_fp_2op_p_zd_HSD<0b01101, "fsqrt", AArch64fsqrt_mt>;
24222422
} // End HasSVEorSME
24232423

2424-
// Helper for creating scalar fp -> int -> fp conversions using SVE.
2425-
class sve_scalar_fp_int_fp_cvt
2426-
<Instruction PTRUE, Instruction FROM_INT, Instruction TO_INT, SubRegIndex sub>
2427-
: OutPatFrag<(ops node: $Rn),
2428-
(EXTRACT_SUBREG
2429-
(FROM_INT (IMPLICIT_DEF), (PTRUE 1),
2430-
(TO_INT (IMPLICIT_DEF), (PTRUE 1),
2431-
(INSERT_SUBREG (IMPLICIT_DEF), $Rn, sub))), sub)>;
2432-
2433-
// Some scalar float -> int -> float conversion patterns where we want to keep
2434-
// the int values in FP registers to avoid costly GPR <-> FPR register
2435-
// transfers using SVE instructions. Only used when NEON is not available (e.g.
2436-
// in streaming functions).
2437-
// TODO: When +sme2p2 is available Neon single-element vectors should be preferred.
2438-
def HasNoNEON : Predicate<"!Subtarget->isNeonAvailable()">;
2439-
let Predicates = [HasSVEorSME, HasNoNEON] in {
2440-
def : Pat<
2441-
(f64 (sint_to_fp (i64 (fp_to_sint f64:$Rn)))),
2442-
(sve_scalar_fp_int_fp_cvt<PTRUE_D, SCVTF_ZPmZ_DtoD, FCVTZS_ZPmZ_DtoD, dsub> $Rn)>;
2443-
def : Pat<
2444-
(f64 (uint_to_fp (i64 (fp_to_uint f64:$Rn)))),
2445-
(sve_scalar_fp_int_fp_cvt<PTRUE_D, UCVTF_ZPmZ_DtoD, FCVTZU_ZPmZ_DtoD, dsub> $Rn)>;
2446-
def : Pat<
2447-
(f32 (sint_to_fp (i32 (fp_to_sint f32:$Rn)))),
2448-
(sve_scalar_fp_int_fp_cvt<PTRUE_S, SCVTF_ZPmZ_StoS, FCVTZS_ZPmZ_StoS, ssub> $Rn)>;
2449-
def : Pat<
2450-
(f32 (uint_to_fp (i32 (fp_to_uint f32:$Rn)))),
2451-
(sve_scalar_fp_int_fp_cvt<PTRUE_S, UCVTF_ZPmZ_StoS, FCVTZU_ZPmZ_StoS, ssub> $Rn)>;
2452-
def : Pat<
2453-
(f16 (sint_to_fp (i32 (fp_to_sint f16:$Rn)))),
2454-
(sve_scalar_fp_int_fp_cvt<PTRUE_H, SCVTF_ZPmZ_HtoH, FCVTZS_ZPmZ_HtoH, hsub> $Rn)>;
2455-
def : Pat<
2456-
(f16 (uint_to_fp (i32 (fp_to_uint f16:$Rn)))),
2457-
(sve_scalar_fp_int_fp_cvt<PTRUE_H, UCVTF_ZPmZ_HtoH, FCVTZU_ZPmZ_HtoH, hsub> $Rn)>;
2458-
} // End HasSVEorSME, HasNoNEON
2459-
24602424
let Predicates = [HasBF16, HasSVEorSME] in {
24612425
defm BFDOT_ZZZ : sve_float_dot<0b1, 0b0, ZPR32, ZPR16, "bfdot", nxv8bf16, int_aarch64_sve_bfdot>;
24622426
defm BFDOT_ZZI : sve_float_dot_indexed<0b1, 0b00, ZPR16, ZPR3b16, "bfdot", nxv8bf16, int_aarch64_sve_bfdot_lane_v2>;

llvm/test/CodeGen/AArch64/sve-streaming-mode-cvt-fp-int-fp.ll

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ target triple = "aarch64-unknown-linux-gnu"
88
define double @t1(double %x) {
99
; CHECK-LABEL: t1:
1010
; CHECK: // %bb.0: // %entry
11-
; CHECK-NEXT: ptrue p0.d, vl1
11+
; CHECK-NEXT: ptrue p0.d
1212
; CHECK-NEXT: // kill: def $d0 killed $d0 def $z0
1313
; CHECK-NEXT: fcvtzs z0.d, p0/m, z0.d
1414
; CHECK-NEXT: scvtf z0.d, p0/m, z0.d
@@ -35,7 +35,7 @@ entry:
3535
define float @t2(float %x) {
3636
; CHECK-LABEL: t2:
3737
; CHECK: // %bb.0: // %entry
38-
; CHECK-NEXT: ptrue p0.s, vl1
38+
; CHECK-NEXT: ptrue p0.s
3939
; CHECK-NEXT: // kill: def $s0 killed $s0 def $z0
4040
; CHECK-NEXT: fcvtzs z0.s, p0/m, z0.s
4141
; CHECK-NEXT: scvtf z0.s, p0/m, z0.s
@@ -62,10 +62,10 @@ entry:
6262
define half @t3(half %x) {
6363
; CHECK-LABEL: t3:
6464
; CHECK: // %bb.0: // %entry
65-
; CHECK-NEXT: ptrue p0.h, vl1
65+
; CHECK-NEXT: ptrue p0.s
6666
; CHECK-NEXT: // kill: def $h0 killed $h0 def $z0
67-
; CHECK-NEXT: fcvtzs z0.h, p0/m, z0.h
68-
; CHECK-NEXT: scvtf z0.h, p0/m, z0.h
67+
; CHECK-NEXT: fcvtzs z0.s, p0/m, z0.h
68+
; CHECK-NEXT: scvtf z0.h, p0/m, z0.s
6969
; CHECK-NEXT: // kill: def $h0 killed $h0 killed $z0
7070
; CHECK-NEXT: ret
7171
;
@@ -93,7 +93,7 @@ entry:
9393
define double @t4(double %x) {
9494
; CHECK-LABEL: t4:
9595
; CHECK: // %bb.0: // %entry
96-
; CHECK-NEXT: ptrue p0.d, vl1
96+
; CHECK-NEXT: ptrue p0.d
9797
; CHECK-NEXT: // kill: def $d0 killed $d0 def $z0
9898
; CHECK-NEXT: fcvtzu z0.d, p0/m, z0.d
9999
; CHECK-NEXT: ucvtf z0.d, p0/m, z0.d
@@ -120,7 +120,7 @@ entry:
120120
define float @t5(float %x) {
121121
; CHECK-LABEL: t5:
122122
; CHECK: // %bb.0: // %entry
123-
; CHECK-NEXT: ptrue p0.s, vl1
123+
; CHECK-NEXT: ptrue p0.s
124124
; CHECK-NEXT: // kill: def $s0 killed $s0 def $z0
125125
; CHECK-NEXT: fcvtzu z0.s, p0/m, z0.s
126126
; CHECK-NEXT: ucvtf z0.s, p0/m, z0.s
@@ -147,10 +147,10 @@ entry:
147147
define half @t6(half %x) {
148148
; CHECK-LABEL: t6:
149149
; CHECK: // %bb.0: // %entry
150-
; CHECK-NEXT: ptrue p0.h, vl1
150+
; CHECK-NEXT: ptrue p0.s
151151
; CHECK-NEXT: // kill: def $h0 killed $h0 def $z0
152-
; CHECK-NEXT: fcvtzu z0.h, p0/m, z0.h
153-
; CHECK-NEXT: ucvtf z0.h, p0/m, z0.h
152+
; CHECK-NEXT: fcvtzu z0.s, p0/m, z0.h
153+
; CHECK-NEXT: ucvtf z0.h, p0/m, z0.s
154154
; CHECK-NEXT: // kill: def $h0 killed $h0 killed $z0
155155
; CHECK-NEXT: ret
156156
;

0 commit comments

Comments
 (0)