@@ -676,11 +676,13 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
676
676
setOperationAction(ISD::FPOW, MVT::f64, Expand);
677
677
setOperationAction(ISD::FCOPYSIGN, MVT::f64, Custom);
678
678
setOperationAction(ISD::FCOPYSIGN, MVT::f32, Custom);
679
- if (Subtarget->hasFullFP16())
679
+ if (Subtarget->hasFullFP16()) {
680
680
setOperationAction(ISD::FCOPYSIGN, MVT::f16, Custom);
681
- else
681
+ setOperationAction(ISD::FCOPYSIGN, MVT::bf16, Custom);
682
+ } else {
682
683
setOperationAction(ISD::FCOPYSIGN, MVT::f16, Promote);
683
- setOperationAction(ISD::FCOPYSIGN, MVT::bf16, Promote);
684
+ setOperationAction(ISD::FCOPYSIGN, MVT::bf16, Promote);
685
+ }
684
686
685
687
for (auto Op : {ISD::FREM, ISD::FPOW, ISD::FPOWI,
686
688
ISD::FCOS, ISD::FSIN, ISD::FSINCOS,
@@ -699,23 +701,48 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
699
701
}
700
702
701
703
auto LegalizeNarrowFP = [this](MVT ScalarVT) {
702
- for (auto Op :
703
- {ISD::SETCC, ISD::SELECT_CC,
704
- ISD::BR_CC, ISD::FADD, ISD::FSUB,
705
- ISD::FMUL, ISD::FDIV, ISD::FMA,
706
- ISD::FNEG, ISD::FABS, ISD::FCEIL,
707
- ISD::FSQRT, ISD::FFLOOR, ISD::FNEARBYINT,
708
- ISD::FRINT, ISD::FROUND, ISD::FROUNDEVEN,
709
- ISD::FTRUNC, ISD::FMINNUM, ISD::FMAXNUM,
710
- ISD::FMINIMUM, ISD::FMAXIMUM, ISD::STRICT_FADD,
711
- ISD::STRICT_FSUB, ISD::STRICT_FMUL, ISD::STRICT_FDIV,
712
- ISD::STRICT_FMA, ISD::STRICT_FCEIL, ISD::STRICT_FFLOOR,
713
- ISD::STRICT_FSQRT, ISD::STRICT_FRINT, ISD::STRICT_FNEARBYINT,
714
- ISD::STRICT_FROUND, ISD::STRICT_FTRUNC, ISD::STRICT_FROUNDEVEN,
715
- ISD::STRICT_FMINNUM, ISD::STRICT_FMAXNUM, ISD::STRICT_FMINIMUM,
716
- ISD::STRICT_FMAXIMUM})
704
+ for (auto Op : {ISD::SETCC,
705
+ ISD::SELECT_CC,
706
+ ISD::BR_CC,
707
+ ISD::FADD,
708
+ ISD::FSUB,
709
+ ISD::FMUL,
710
+ ISD::FDIV,
711
+ ISD::FMA,
712
+ ISD::FCEIL,
713
+ ISD::FSQRT,
714
+ ISD::FFLOOR,
715
+ ISD::FNEARBYINT,
716
+ ISD::FRINT,
717
+ ISD::FROUND,
718
+ ISD::FROUNDEVEN,
719
+ ISD::FTRUNC,
720
+ ISD::FMINNUM,
721
+ ISD::FMAXNUM,
722
+ ISD::FMINIMUM,
723
+ ISD::FMAXIMUM,
724
+ ISD::STRICT_FADD,
725
+ ISD::STRICT_FSUB,
726
+ ISD::STRICT_FMUL,
727
+ ISD::STRICT_FDIV,
728
+ ISD::STRICT_FMA,
729
+ ISD::STRICT_FCEIL,
730
+ ISD::STRICT_FFLOOR,
731
+ ISD::STRICT_FSQRT,
732
+ ISD::STRICT_FRINT,
733
+ ISD::STRICT_FNEARBYINT,
734
+ ISD::STRICT_FROUND,
735
+ ISD::STRICT_FTRUNC,
736
+ ISD::STRICT_FROUNDEVEN,
737
+ ISD::STRICT_FMINNUM,
738
+ ISD::STRICT_FMAXNUM,
739
+ ISD::STRICT_FMINIMUM,
740
+ ISD::STRICT_FMAXIMUM})
717
741
setOperationAction(Op, ScalarVT, Promote);
718
742
743
+ for (auto Op : {ISD::FNEG, ISD::FABS})
744
+ setOperationAction(Op, ScalarVT, Legal);
745
+
719
746
// Round-to-integer need custom lowering for fp16, as Promote doesn't work
720
747
// because the result type is integer.
721
748
for (auto Op : {ISD::LROUND, ISD::LLROUND, ISD::LRINT, ISD::LLRINT,
@@ -730,8 +757,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
730
757
setOperationPromotedToType(ISD::FMUL, V4Narrow, MVT::v4f32);
731
758
setOperationPromotedToType(ISD::FDIV, V4Narrow, MVT::v4f32);
732
759
733
- setOperationAction(ISD::FABS, V4Narrow, Expand );
734
- setOperationAction(ISD::FNEG, V4Narrow, Expand );
760
+ setOperationAction(ISD::FABS, V4Narrow, Legal );
761
+ setOperationAction(ISD::FNEG, V4Narrow, Legal );
735
762
setOperationAction(ISD::FROUND, V4Narrow, Expand);
736
763
setOperationAction(ISD::FROUNDEVEN, V4Narrow, Expand);
737
764
setOperationAction(ISD::FMA, V4Narrow, Expand);
@@ -740,24 +767,24 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
740
767
setOperationAction(ISD::SELECT, V4Narrow, Expand);
741
768
setOperationAction(ISD::SELECT_CC, V4Narrow, Expand);
742
769
setOperationAction(ISD::FTRUNC, V4Narrow, Expand);
743
- setOperationAction(ISD::FCOPYSIGN, V4Narrow, Expand );
770
+ setOperationAction(ISD::FCOPYSIGN, V4Narrow, Custom );
744
771
setOperationAction(ISD::FFLOOR, V4Narrow, Expand);
745
772
setOperationAction(ISD::FCEIL, V4Narrow, Expand);
746
773
setOperationAction(ISD::FRINT, V4Narrow, Expand);
747
774
setOperationAction(ISD::FNEARBYINT, V4Narrow, Expand);
748
775
setOperationAction(ISD::FSQRT, V4Narrow, Expand);
749
776
750
777
auto V8Narrow = MVT::getVectorVT(ScalarVT, 8);
751
- setOperationAction(ISD::FABS, V8Narrow, Expand );
778
+ setOperationAction(ISD::FABS, V8Narrow, Legal );
752
779
setOperationAction(ISD::FADD, V8Narrow, Expand);
753
780
setOperationAction(ISD::FCEIL, V8Narrow, Expand);
754
- setOperationAction(ISD::FCOPYSIGN, V8Narrow, Expand );
781
+ setOperationAction(ISD::FCOPYSIGN, V8Narrow, Custom );
755
782
setOperationAction(ISD::FDIV, V8Narrow, Expand);
756
783
setOperationAction(ISD::FFLOOR, V8Narrow, Expand);
757
784
setOperationAction(ISD::FMA, V8Narrow, Expand);
758
785
setOperationAction(ISD::FMUL, V8Narrow, Expand);
759
786
setOperationAction(ISD::FNEARBYINT, V8Narrow, Expand);
760
- setOperationAction(ISD::FNEG, V8Narrow, Expand );
787
+ setOperationAction(ISD::FNEG, V8Narrow, Legal );
761
788
setOperationAction(ISD::FROUND, V8Narrow, Expand);
762
789
setOperationAction(ISD::FROUNDEVEN, V8Narrow, Expand);
763
790
setOperationAction(ISD::FRINT, V8Narrow, Expand);
@@ -1745,7 +1772,9 @@ void AArch64TargetLowering::addTypeForNEON(MVT VT) {
1745
1772
1746
1773
// But we do support custom-lowering for FCOPYSIGN.
1747
1774
if (VT == MVT::v2f32 || VT == MVT::v4f32 || VT == MVT::v2f64 ||
1748
- ((VT == MVT::v4f16 || VT == MVT::v8f16) && Subtarget->hasFullFP16()))
1775
+ ((VT == MVT::v4bf16 || VT == MVT::v8bf16 || VT == MVT::v4f16 ||
1776
+ VT == MVT::v8f16) &&
1777
+ Subtarget->hasFullFP16()))
1749
1778
setOperationAction(ISD::FCOPYSIGN, VT, Custom);
1750
1779
1751
1780
setOperationAction(ISD::EXTRACT_VECTOR_ELT, VT, Custom);
@@ -9208,7 +9237,7 @@ SDValue AArch64TargetLowering::LowerFCOPYSIGN(SDValue Op,
9208
9237
} else if (VT == MVT::f32) {
9209
9238
VecVT = MVT::v4i32;
9210
9239
SetVecVal(AArch64::ssub);
9211
- } else if (VT == MVT::f16) {
9240
+ } else if (VT == MVT::f16 || VT == MVT::bf16 ) {
9212
9241
VecVT = MVT::v8i16;
9213
9242
SetVecVal(AArch64::hsub);
9214
9243
} else {
@@ -9230,7 +9259,7 @@ SDValue AArch64TargetLowering::LowerFCOPYSIGN(SDValue Op,
9230
9259
9231
9260
SDValue BSP =
9232
9261
DAG.getNode(AArch64ISD::BSP, DL, VecVT, SignMaskV, VecVal1, VecVal2);
9233
- if (VT == MVT::f16)
9262
+ if (VT == MVT::f16 || VT == MVT::bf16 )
9234
9263
return DAG.getTargetExtractSubreg(AArch64::hsub, DL, VT, BSP);
9235
9264
if (VT == MVT::f32)
9236
9265
return DAG.getTargetExtractSubreg(AArch64::ssub, DL, VT, BSP);
0 commit comments