Skip to content

Commit edc1c3d

Browse files
committed
[AArch64] Make more vector f16 operations legal
v8f16 is a legal type but promoting to v16f16 would result in an illegal type. Let's legalize these by a combination of splitting+promoting resulting in a pair of v4f16. Also, we were being overly cautious with different v4f16 nodes. Mark more of them safe to promote to v4f32.
1 parent 5f935e9 commit edc1c3d

File tree

13 files changed

+864
-3665
lines changed

13 files changed

+864
-3665
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 76 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -701,43 +701,45 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
701701
}
702702

703703
auto LegalizeNarrowFP = [this](MVT ScalarVT) {
704-
for (auto Op : {ISD::SETCC,
705-
ISD::SELECT_CC,
706-
ISD::BR_CC,
707-
ISD::FADD,
708-
ISD::FSUB,
709-
ISD::FMUL,
710-
ISD::FDIV,
711-
ISD::FMA,
712-
ISD::FCEIL,
713-
ISD::FSQRT,
714-
ISD::FFLOOR,
715-
ISD::FNEARBYINT,
716-
ISD::FRINT,
717-
ISD::FROUND,
718-
ISD::FROUNDEVEN,
719-
ISD::FTRUNC,
720-
ISD::FMINNUM,
721-
ISD::FMAXNUM,
722-
ISD::FMINIMUM,
723-
ISD::FMAXIMUM,
724-
ISD::STRICT_FADD,
725-
ISD::STRICT_FSUB,
726-
ISD::STRICT_FMUL,
727-
ISD::STRICT_FDIV,
728-
ISD::STRICT_FMA,
729-
ISD::STRICT_FCEIL,
730-
ISD::STRICT_FFLOOR,
731-
ISD::STRICT_FSQRT,
732-
ISD::STRICT_FRINT,
733-
ISD::STRICT_FNEARBYINT,
734-
ISD::STRICT_FROUND,
735-
ISD::STRICT_FTRUNC,
736-
ISD::STRICT_FROUNDEVEN,
737-
ISD::STRICT_FMINNUM,
738-
ISD::STRICT_FMAXNUM,
739-
ISD::STRICT_FMINIMUM,
740-
ISD::STRICT_FMAXIMUM})
704+
for (auto Op : {
705+
ISD::SETCC,
706+
ISD::SELECT_CC,
707+
ISD::BR_CC,
708+
ISD::FADD,
709+
ISD::FSUB,
710+
ISD::FMUL,
711+
ISD::FDIV,
712+
ISD::FMA,
713+
ISD::FCEIL,
714+
ISD::FSQRT,
715+
ISD::FFLOOR,
716+
ISD::FNEARBYINT,
717+
ISD::FRINT,
718+
ISD::FROUND,
719+
ISD::FROUNDEVEN,
720+
ISD::FTRUNC,
721+
ISD::FMINNUM,
722+
ISD::FMAXNUM,
723+
ISD::FMINIMUM,
724+
ISD::FMAXIMUM,
725+
ISD::STRICT_FADD,
726+
ISD::STRICT_FSUB,
727+
ISD::STRICT_FMUL,
728+
ISD::STRICT_FDIV,
729+
ISD::STRICT_FMA,
730+
ISD::STRICT_FCEIL,
731+
ISD::STRICT_FFLOOR,
732+
ISD::STRICT_FSQRT,
733+
ISD::STRICT_FRINT,
734+
ISD::STRICT_FNEARBYINT,
735+
ISD::STRICT_FROUND,
736+
ISD::STRICT_FTRUNC,
737+
ISD::STRICT_FROUNDEVEN,
738+
ISD::STRICT_FMINNUM,
739+
ISD::STRICT_FMAXNUM,
740+
ISD::STRICT_FMINIMUM,
741+
ISD::STRICT_FMAXIMUM,
742+
})
741743
setOperationAction(Op, ScalarVT, Promote);
742744

743745
for (auto Op : {ISD::FNEG, ISD::FABS})
@@ -752,45 +754,45 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
752754

753755
// promote v4f16 to v4f32 when that is known to be safe.
754756
auto V4Narrow = MVT::getVectorVT(ScalarVT, 4);
755-
setOperationPromotedToType(ISD::FADD, V4Narrow, MVT::v4f32);
756-
setOperationPromotedToType(ISD::FSUB, V4Narrow, MVT::v4f32);
757-
setOperationPromotedToType(ISD::FMUL, V4Narrow, MVT::v4f32);
758-
setOperationPromotedToType(ISD::FDIV, V4Narrow, MVT::v4f32);
759-
760-
setOperationAction(ISD::FABS, V4Narrow, Legal);
761-
setOperationAction(ISD::FNEG, V4Narrow, Legal);
762-
setOperationAction(ISD::FROUND, V4Narrow, Expand);
763-
setOperationAction(ISD::FROUNDEVEN, V4Narrow, Expand);
757+
setOperationPromotedToType(ISD::FADD, V4Narrow, MVT::v4f32);
758+
setOperationPromotedToType(ISD::FSUB, V4Narrow, MVT::v4f32);
759+
setOperationPromotedToType(ISD::FMUL, V4Narrow, MVT::v4f32);
760+
setOperationPromotedToType(ISD::FDIV, V4Narrow, MVT::v4f32);
761+
setOperationPromotedToType(ISD::FCEIL, V4Narrow, MVT::v4f32);
762+
setOperationPromotedToType(ISD::FFLOOR, V4Narrow, MVT::v4f32);
763+
setOperationPromotedToType(ISD::FROUND, V4Narrow, MVT::v4f32);
764+
setOperationPromotedToType(ISD::FTRUNC, V4Narrow, MVT::v4f32);
765+
setOperationPromotedToType(ISD::FROUNDEVEN, V4Narrow, MVT::v4f32);
766+
setOperationPromotedToType(ISD::FRINT, V4Narrow, MVT::v4f32);
767+
setOperationPromotedToType(ISD::FNEARBYINT, V4Narrow, MVT::v4f32);
768+
769+
setOperationAction(ISD::FABS, V4Narrow, Legal);
770+
setOperationAction(ISD::FNEG, V4Narrow, Legal);
764771
setOperationAction(ISD::FMA, V4Narrow, Expand);
765772
setOperationAction(ISD::SETCC, V4Narrow, Custom);
766773
setOperationAction(ISD::BR_CC, V4Narrow, Expand);
767774
setOperationAction(ISD::SELECT, V4Narrow, Expand);
768775
setOperationAction(ISD::SELECT_CC, V4Narrow, Expand);
769-
setOperationAction(ISD::FTRUNC, V4Narrow, Expand);
770-
setOperationAction(ISD::FCOPYSIGN, V4Narrow, Custom);
771-
setOperationAction(ISD::FFLOOR, V4Narrow, Expand);
772-
setOperationAction(ISD::FCEIL, V4Narrow, Expand);
773-
setOperationAction(ISD::FRINT, V4Narrow, Expand);
774-
setOperationAction(ISD::FNEARBYINT, V4Narrow, Expand);
776+
setOperationAction(ISD::FCOPYSIGN, V4Narrow, Custom);
775777
setOperationAction(ISD::FSQRT, V4Narrow, Expand);
776778

777779
auto V8Narrow = MVT::getVectorVT(ScalarVT, 8);
778-
setOperationAction(ISD::FABS, V8Narrow, Legal);
779-
setOperationAction(ISD::FADD, V8Narrow, Expand);
780-
setOperationAction(ISD::FCEIL, V8Narrow, Expand);
781-
setOperationAction(ISD::FCOPYSIGN, V8Narrow, Custom);
782-
setOperationAction(ISD::FDIV, V8Narrow, Expand);
783-
setOperationAction(ISD::FFLOOR, V8Narrow, Expand);
780+
setOperationAction(ISD::FABS, V8Narrow, Legal);
781+
setOperationAction(ISD::FADD, V8Narrow, Legal);
782+
setOperationAction(ISD::FCEIL, V8Narrow, Legal);
783+
setOperationAction(ISD::FCOPYSIGN, V8Narrow, Custom);
784+
setOperationAction(ISD::FDIV, V8Narrow, Legal);
785+
setOperationAction(ISD::FFLOOR, V8Narrow, Legal);
784786
setOperationAction(ISD::FMA, V8Narrow, Expand);
785-
setOperationAction(ISD::FMUL, V8Narrow, Expand);
786-
setOperationAction(ISD::FNEARBYINT, V8Narrow, Expand);
787-
setOperationAction(ISD::FNEG, V8Narrow, Legal);
788-
setOperationAction(ISD::FROUND, V8Narrow, Expand);
789-
setOperationAction(ISD::FROUNDEVEN, V8Narrow, Expand);
790-
setOperationAction(ISD::FRINT, V8Narrow, Expand);
787+
setOperationAction(ISD::FMUL, V8Narrow, Legal);
788+
setOperationAction(ISD::FNEARBYINT, V8Narrow, Legal);
789+
setOperationAction(ISD::FNEG, V8Narrow, Legal);
790+
setOperationAction(ISD::FROUND, V8Narrow, Legal);
791+
setOperationAction(ISD::FROUNDEVEN, V8Narrow, Legal);
792+
setOperationAction(ISD::FRINT, V8Narrow, Legal);
791793
setOperationAction(ISD::FSQRT, V8Narrow, Expand);
792-
setOperationAction(ISD::FSUB, V8Narrow, Expand);
793-
setOperationAction(ISD::FTRUNC, V8Narrow, Expand);
794+
setOperationAction(ISD::FSUB, V8Narrow, Legal);
795+
setOperationAction(ISD::FTRUNC, V8Narrow, Legal);
794796
setOperationAction(ISD::SETCC, V8Narrow, Expand);
795797
setOperationAction(ISD::BR_CC, V8Narrow, Expand);
796798
setOperationAction(ISD::SELECT, V8Narrow, Expand);
@@ -10593,13 +10595,19 @@ static SDValue getEstimate(const AArch64Subtarget *ST, unsigned Opcode,
1059310595
VT == MVT::v4f32)) ||
1059410596
(ST->hasSVE() &&
1059510597
(VT == MVT::nxv8f16 || VT == MVT::nxv4f32 || VT == MVT::nxv2f64))) {
10596-
if (ExtraSteps == TargetLoweringBase::ReciprocalEstimate::Unspecified)
10598+
if (ExtraSteps == TargetLoweringBase::ReciprocalEstimate::Unspecified) {
1059710599
// For the reciprocal estimates, convergence is quadratic, so the number
1059810600
// of digits is doubled after each iteration. In ARMv8, the accuracy of
1059910601
// the initial estimate is 2^-8. Thus the number of extra steps to refine
1060010602
// the result for float (23 mantissa bits) is 2 and for double (52
1060110603
// mantissa bits) is 3.
10602-
ExtraSteps = VT.getScalarType() == MVT::f64 ? 3 : 2;
10604+
constexpr unsigned AccurateBits = 8;
10605+
unsigned DesiredBits =
10606+
APFloat::semanticsPrecision(DAG.EVTToAPFloatSemantics(VT));
10607+
ExtraSteps = DesiredBits <= AccurateBits
10608+
? 0
10609+
: Log2_64_Ceil(DesiredBits) - Log2_64_Ceil(AccurateBits);
10610+
}
1060310611

1060410612
return DAG.getNode(Opcode, SDLoc(Operand), VT, Operand);
1060510613
}

llvm/lib/Target/AArch64/AArch64InstrInfo.td

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@ def HasRDM : Predicate<"Subtarget->hasRDM()">,
128128
AssemblerPredicateWithAll<(all_of FeatureRDM), "rdm">;
129129
def HasFullFP16 : Predicate<"Subtarget->hasFullFP16()">,
130130
AssemblerPredicateWithAll<(all_of FeatureFullFP16), "fullfp16">;
131+
def HasNoFullFP16 : Predicate<"!Subtarget->hasFullFP16()">;
131132
def HasFP16FML : Predicate<"Subtarget->hasFP16FML()">,
132133
AssemblerPredicateWithAll<(all_of FeatureFP16FML), "fp16fml">;
133134
def HasSPE : Predicate<"Subtarget->hasSPE()">,
@@ -254,6 +255,7 @@ def HasTRBE : Predicate<"Subtarget->hasTRBE()">,
254255
AssemblerPredicateWithAll<(all_of FeatureTRBE), "trbe">;
255256
def HasBF16 : Predicate<"Subtarget->hasBF16()">,
256257
AssemblerPredicateWithAll<(all_of FeatureBF16), "bf16">;
258+
def HasNoBF16 : Predicate<"!Subtarget->hasBF16()">;
257259
def HasMatMulInt8 : Predicate<"Subtarget->hasMatMulInt8()">,
258260
AssemblerPredicateWithAll<(all_of FeatureMatMulInt8), "i8mm">;
259261
def HasMatMulFP32 : Predicate<"Subtarget->hasMatMulFP32()">,
@@ -764,6 +766,8 @@ def AArch64fcvtxnv: PatFrags<(ops node:$Rn),
764766
[(int_aarch64_neon_fcvtxn node:$Rn),
765767
(AArch64fcvtxn_n node:$Rn)]>;
766768

769+
//def Aarch64softf32tobf16v8: SDNode<"AArch64ISD::", SDTFPRoundOp>;
770+
767771
def AArch64bici: SDNode<"AArch64ISD::BICi", SDT_AArch64vecimm>;
768772
def AArch64orri: SDNode<"AArch64ISD::ORRi", SDT_AArch64vecimm>;
769773

@@ -9739,6 +9743,93 @@ let Predicates = [HasCPA] in {
97399743
def MSUBPT : MulAccumCPA<1, "msubpt">;
97409744
}
97419745

9746+
def round_v4fp32_to_v4bf16 :
9747+
OutPatFrag<(ops node:$Rn),
9748+
// NaN? Round : Quiet(NaN)
9749+
(BSPv16i8 (FCMEQv4f32 $Rn, $Rn),
9750+
(ADDv4i32
9751+
(ADDv4i32 $Rn,
9752+
// Extract the LSB of the fp32 *truncated* to bf16.
9753+
(ANDv16i8 (USHRv4i32_shift V128:$Rn, (i32 16)),
9754+
(MOVIv4i32 (i32 1), (i32 0)))),
9755+
// Bias which will help us break ties correctly.
9756+
(MOVIv4s_msl (i32 127), (i32 264))),
9757+
// Set the quiet bit in the NaN.
9758+
(ORRv4i32 $Rn, (i32 64), (i32 16)))>;
9759+
9760+
multiclass PromoteUnaryv8f16Tov4f32<SDPatternOperator InOp, Instruction OutInst> {
9761+
let Predicates = [HasNoFullFP16] in
9762+
def : Pat<(InOp (v8f16 V128:$Rn)),
9763+
(v8f16 (FCVTNv8i16
9764+
(INSERT_SUBREG (IMPLICIT_DEF),
9765+
(v4f16 (FCVTNv4i16
9766+
(v4f32 (OutInst
9767+
(v4f32 (FCVTLv4i16 (v4i16 (EXTRACT_SUBREG V128:$Rn, dsub)))))))),
9768+
dsub),
9769+
(v4f32 (OutInst (v4f32 (FCVTLv8i16 V128:$Rn))))))>;
9770+
9771+
let Predicates = [HasBF16] in
9772+
def : Pat<(InOp (v8bf16 V128:$Rn)),
9773+
(v8bf16 (BFCVTN2
9774+
(v8bf16 (BFCVTN
9775+
(v4f32 (OutInst
9776+
(v4f32 (SHLLv4i16 (v4i16 (EXTRACT_SUBREG V128:$Rn, dsub)))))))),
9777+
(v4f32 (OutInst (v4f32 (SHLLv8i16 V128:$Rn))))))>;
9778+
9779+
let Predicates = [HasNoBF16] in
9780+
def : Pat<(InOp (v8bf16 V128:$Rn)),
9781+
(UZP2v8i16
9782+
(round_v4fp32_to_v4bf16 (v4f32 (OutInst
9783+
(v4f32 (SHLLv4i16 (v4i16 (EXTRACT_SUBREG V128:$Rn, dsub))))))),
9784+
(round_v4fp32_to_v4bf16 (v4f32 (OutInst
9785+
(v4f32 (SHLLv8i16 V128:$Rn))))))>;
9786+
}
9787+
defm : PromoteUnaryv8f16Tov4f32<any_fceil, FRINTPv4f32>;
9788+
defm : PromoteUnaryv8f16Tov4f32<any_ffloor, FRINTMv4f32>;
9789+
defm : PromoteUnaryv8f16Tov4f32<any_fnearbyint, FRINTIv4f32>;
9790+
defm : PromoteUnaryv8f16Tov4f32<any_fround, FRINTAv4f32>;
9791+
defm : PromoteUnaryv8f16Tov4f32<any_froundeven, FRINTNv4f32>;
9792+
defm : PromoteUnaryv8f16Tov4f32<any_frint, FRINTXv4f32>;
9793+
defm : PromoteUnaryv8f16Tov4f32<any_ftrunc, FRINTZv4f32>;
9794+
9795+
multiclass PromoteBinaryv8f16Tov4f32<SDPatternOperator InOp, Instruction OutInst> {
9796+
let Predicates = [HasNoFullFP16] in
9797+
def : Pat<(InOp (v8f16 V128:$Rn), (v8f16 V128:$Rm)),
9798+
(v8f16 (FCVTNv8i16
9799+
(INSERT_SUBREG (IMPLICIT_DEF),
9800+
(v4f16 (FCVTNv4i16
9801+
(v4f32 (OutInst
9802+
(v4f32 (FCVTLv4i16 (v4i16 (EXTRACT_SUBREG V128:$Rn, dsub)))),
9803+
(v4f32 (FCVTLv4i16 (v4i16 (EXTRACT_SUBREG V128:$Rm, dsub)))))))),
9804+
dsub),
9805+
(v4f32 (OutInst (v4f32 (FCVTLv8i16 V128:$Rn)),
9806+
(v4f32 (FCVTLv8i16 V128:$Rm))))))>;
9807+
9808+
let Predicates = [HasBF16] in
9809+
def : Pat<(InOp (v8bf16 V128:$Rn), (v8bf16 V128:$Rm)),
9810+
(v8bf16 (BFCVTN2
9811+
(v8bf16 (BFCVTN
9812+
(v4f32 (OutInst
9813+
(v4f32 (SHLLv4i16 (v4i16 (EXTRACT_SUBREG V128:$Rn, dsub)))),
9814+
(v4f32 (SHLLv4i16 (v4i16 (EXTRACT_SUBREG V128:$Rm, dsub)))))))),
9815+
(v4f32 (OutInst (v4f32 (SHLLv8i16 V128:$Rn)),
9816+
(v4f32 (SHLLv8i16 V128:$Rm))))))>;
9817+
9818+
let Predicates = [HasNoBF16] in
9819+
def : Pat<(InOp (v8bf16 V128:$Rn), (v8bf16 V128:$Rm)),
9820+
(UZP2v8i16
9821+
(round_v4fp32_to_v4bf16 (v4f32 (OutInst
9822+
(v4f32 (SHLLv4i16 (v4i16 (EXTRACT_SUBREG V128:$Rn, dsub)))),
9823+
(v4f32 (SHLLv4i16 (v4i16 (EXTRACT_SUBREG V128:$Rm, dsub))))))),
9824+
(round_v4fp32_to_v4bf16 (v4f32 (OutInst
9825+
(v4f32 (SHLLv8i16 V128:$Rn)),
9826+
(v4f32 (SHLLv8i16 V128:$Rm))))))>;
9827+
}
9828+
defm : PromoteBinaryv8f16Tov4f32<any_fadd, FADDv4f32>;
9829+
defm : PromoteBinaryv8f16Tov4f32<any_fdiv, FDIVv4f32>;
9830+
defm : PromoteBinaryv8f16Tov4f32<any_fmul, FMULv4f32>;
9831+
defm : PromoteBinaryv8f16Tov4f32<any_fsub, FSUBv4f32>;
9832+
97429833
include "AArch64InstrAtomics.td"
97439834
include "AArch64SVEInstrInfo.td"
97449835
include "AArch64SMEInstrInfo.td"

0 commit comments

Comments
 (0)