Skip to content

Commit 930e7ff

Browse files
committed
[AArch64] Optimize abs, neg and copysign for fp16/bf16
Because abs, neg and copysign only manipulate the sign bit, we can implement them with bitwise arithmetic, making them considerably faster than legalization via promotion.
1 parent 0597644 commit 930e7ff

File tree

7 files changed

+241
-1307
lines changed

7 files changed

+241
-1307
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 56 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -676,11 +676,13 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
676676
setOperationAction(ISD::FPOW, MVT::f64, Expand);
677677
setOperationAction(ISD::FCOPYSIGN, MVT::f64, Custom);
678678
setOperationAction(ISD::FCOPYSIGN, MVT::f32, Custom);
679-
if (Subtarget->hasFullFP16())
679+
if (Subtarget->hasFullFP16()) {
680680
setOperationAction(ISD::FCOPYSIGN, MVT::f16, Custom);
681-
else
681+
setOperationAction(ISD::FCOPYSIGN, MVT::bf16, Custom);
682+
} else {
682683
setOperationAction(ISD::FCOPYSIGN, MVT::f16, Promote);
683-
setOperationAction(ISD::FCOPYSIGN, MVT::bf16, Promote);
684+
setOperationAction(ISD::FCOPYSIGN, MVT::bf16, Promote);
685+
}
684686

685687
for (auto Op : {ISD::FREM, ISD::FPOW, ISD::FPOWI,
686688
ISD::FCOS, ISD::FSIN, ISD::FSINCOS,
@@ -699,23 +701,48 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
699701
}
700702

701703
auto LegalizeNarrowFP = [this](MVT ScalarVT) {
702-
for (auto Op :
703-
{ISD::SETCC, ISD::SELECT_CC,
704-
ISD::BR_CC, ISD::FADD, ISD::FSUB,
705-
ISD::FMUL, ISD::FDIV, ISD::FMA,
706-
ISD::FNEG, ISD::FABS, ISD::FCEIL,
707-
ISD::FSQRT, ISD::FFLOOR, ISD::FNEARBYINT,
708-
ISD::FRINT, ISD::FROUND, ISD::FROUNDEVEN,
709-
ISD::FTRUNC, ISD::FMINNUM, ISD::FMAXNUM,
710-
ISD::FMINIMUM, ISD::FMAXIMUM, ISD::STRICT_FADD,
711-
ISD::STRICT_FSUB, ISD::STRICT_FMUL, ISD::STRICT_FDIV,
712-
ISD::STRICT_FMA, ISD::STRICT_FCEIL, ISD::STRICT_FFLOOR,
713-
ISD::STRICT_FSQRT, ISD::STRICT_FRINT, ISD::STRICT_FNEARBYINT,
714-
ISD::STRICT_FROUND, ISD::STRICT_FTRUNC, ISD::STRICT_FROUNDEVEN,
715-
ISD::STRICT_FMINNUM, ISD::STRICT_FMAXNUM, ISD::STRICT_FMINIMUM,
716-
ISD::STRICT_FMAXIMUM})
704+
for (auto Op : {ISD::SETCC,
705+
ISD::SELECT_CC,
706+
ISD::BR_CC,
707+
ISD::FADD,
708+
ISD::FSUB,
709+
ISD::FMUL,
710+
ISD::FDIV,
711+
ISD::FMA,
712+
ISD::FCEIL,
713+
ISD::FSQRT,
714+
ISD::FFLOOR,
715+
ISD::FNEARBYINT,
716+
ISD::FRINT,
717+
ISD::FROUND,
718+
ISD::FROUNDEVEN,
719+
ISD::FTRUNC,
720+
ISD::FMINNUM,
721+
ISD::FMAXNUM,
722+
ISD::FMINIMUM,
723+
ISD::FMAXIMUM,
724+
ISD::STRICT_FADD,
725+
ISD::STRICT_FSUB,
726+
ISD::STRICT_FMUL,
727+
ISD::STRICT_FDIV,
728+
ISD::STRICT_FMA,
729+
ISD::STRICT_FCEIL,
730+
ISD::STRICT_FFLOOR,
731+
ISD::STRICT_FSQRT,
732+
ISD::STRICT_FRINT,
733+
ISD::STRICT_FNEARBYINT,
734+
ISD::STRICT_FROUND,
735+
ISD::STRICT_FTRUNC,
736+
ISD::STRICT_FROUNDEVEN,
737+
ISD::STRICT_FMINNUM,
738+
ISD::STRICT_FMAXNUM,
739+
ISD::STRICT_FMINIMUM,
740+
ISD::STRICT_FMAXIMUM})
717741
setOperationAction(Op, ScalarVT, Promote);
718742

743+
for (auto Op : {ISD::FNEG, ISD::FABS})
744+
setOperationAction(Op, ScalarVT, Legal);
745+
719746
// Round-to-integer operations need custom lowering for fp16, as Promote doesn't work
720747
// because the result type is integer.
721748
for (auto Op : {ISD::LROUND, ISD::LLROUND, ISD::LRINT, ISD::LLRINT,
@@ -730,8 +757,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
730757
setOperationPromotedToType(ISD::FMUL, V4Narrow, MVT::v4f32);
731758
setOperationPromotedToType(ISD::FDIV, V4Narrow, MVT::v4f32);
732759

733-
setOperationAction(ISD::FABS, V4Narrow, Expand);
734-
setOperationAction(ISD::FNEG, V4Narrow, Expand);
760+
setOperationAction(ISD::FABS, V4Narrow, Legal);
761+
setOperationAction(ISD::FNEG, V4Narrow, Legal);
735762
setOperationAction(ISD::FROUND, V4Narrow, Expand);
736763
setOperationAction(ISD::FROUNDEVEN, V4Narrow, Expand);
737764
setOperationAction(ISD::FMA, V4Narrow, Expand);
@@ -740,24 +767,24 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
740767
setOperationAction(ISD::SELECT, V4Narrow, Expand);
741768
setOperationAction(ISD::SELECT_CC, V4Narrow, Expand);
742769
setOperationAction(ISD::FTRUNC, V4Narrow, Expand);
743-
setOperationAction(ISD::FCOPYSIGN, V4Narrow, Expand);
770+
setOperationAction(ISD::FCOPYSIGN, V4Narrow, Custom);
744771
setOperationAction(ISD::FFLOOR, V4Narrow, Expand);
745772
setOperationAction(ISD::FCEIL, V4Narrow, Expand);
746773
setOperationAction(ISD::FRINT, V4Narrow, Expand);
747774
setOperationAction(ISD::FNEARBYINT, V4Narrow, Expand);
748775
setOperationAction(ISD::FSQRT, V4Narrow, Expand);
749776

750777
auto V8Narrow = MVT::getVectorVT(ScalarVT, 8);
751-
setOperationAction(ISD::FABS, V8Narrow, Expand);
778+
setOperationAction(ISD::FABS, V8Narrow, Legal);
752779
setOperationAction(ISD::FADD, V8Narrow, Expand);
753780
setOperationAction(ISD::FCEIL, V8Narrow, Expand);
754-
setOperationAction(ISD::FCOPYSIGN, V8Narrow, Expand);
781+
setOperationAction(ISD::FCOPYSIGN, V8Narrow, Custom);
755782
setOperationAction(ISD::FDIV, V8Narrow, Expand);
756783
setOperationAction(ISD::FFLOOR, V8Narrow, Expand);
757784
setOperationAction(ISD::FMA, V8Narrow, Expand);
758785
setOperationAction(ISD::FMUL, V8Narrow, Expand);
759786
setOperationAction(ISD::FNEARBYINT, V8Narrow, Expand);
760-
setOperationAction(ISD::FNEG, V8Narrow, Expand);
787+
setOperationAction(ISD::FNEG, V8Narrow, Legal);
761788
setOperationAction(ISD::FROUND, V8Narrow, Expand);
762789
setOperationAction(ISD::FROUNDEVEN, V8Narrow, Expand);
763790
setOperationAction(ISD::FRINT, V8Narrow, Expand);
@@ -1745,7 +1772,9 @@ void AArch64TargetLowering::addTypeForNEON(MVT VT) {
17451772

17461773
// But we do support custom-lowering for FCOPYSIGN.
17471774
if (VT == MVT::v2f32 || VT == MVT::v4f32 || VT == MVT::v2f64 ||
1748-
((VT == MVT::v4f16 || VT == MVT::v8f16) && Subtarget->hasFullFP16()))
1775+
((VT == MVT::v4bf16 || VT == MVT::v8bf16 || VT == MVT::v4f16 ||
1776+
VT == MVT::v8f16) &&
1777+
Subtarget->hasFullFP16()))
17491778
setOperationAction(ISD::FCOPYSIGN, VT, Custom);
17501779

17511780
setOperationAction(ISD::EXTRACT_VECTOR_ELT, VT, Custom);
@@ -9208,7 +9237,7 @@ SDValue AArch64TargetLowering::LowerFCOPYSIGN(SDValue Op,
92089237
} else if (VT == MVT::f32) {
92099238
VecVT = MVT::v4i32;
92109239
SetVecVal(AArch64::ssub);
9211-
} else if (VT == MVT::f16) {
9240+
} else if (VT == MVT::f16 || VT == MVT::bf16) {
92129241
VecVT = MVT::v8i16;
92139242
SetVecVal(AArch64::hsub);
92149243
} else {
@@ -9230,7 +9259,7 @@ SDValue AArch64TargetLowering::LowerFCOPYSIGN(SDValue Op,
92309259

92319260
SDValue BSP =
92329261
DAG.getNode(AArch64ISD::BSP, DL, VecVT, SignMaskV, VecVal1, VecVal2);
9233-
if (VT == MVT::f16)
9262+
if (VT == MVT::f16 || VT == MVT::bf16)
92349263
return DAG.getTargetExtractSubreg(AArch64::hsub, DL, VT, BSP);
92359264
if (VT == MVT::f32)
92369265
return DAG.getTargetExtractSubreg(AArch64::ssub, DL, VT, BSP);

llvm/lib/Target/AArch64/AArch64InstrInfo.td

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5028,9 +5028,6 @@ defm FCVTNU : SIMDTwoVectorFPToInt<1,0,0b11010, "fcvtnu",int_aarch64_neon_fcvtnu
50285028
defm FCVTN : SIMDFPNarrowTwoVector<0, 0, 0b10110, "fcvtn">;
50295029
def : Pat<(v4i16 (int_aarch64_neon_vcvtfp2hf (v4f32 V128:$Rn))),
50305030
(FCVTNv4i16 V128:$Rn)>;
5031-
//def : Pat<(concat_vectors V64:$Rd,
5032-
// (v4bf16 (any_fpround (v4f32 V128:$Rn)))),
5033-
// (FCVTNv8bf16 (INSERT_SUBREG (IMPLICIT_DEF), V64:$Rd, dsub), V128:$Rn)>;
50345031
def : Pat<(concat_vectors V64:$Rd,
50355032
(v4i16 (int_aarch64_neon_vcvtfp2hf (v4f32 V128:$Rn)))),
50365033
(FCVTNv8i16 (INSERT_SUBREG (IMPLICIT_DEF), V64:$Rd, dsub), V128:$Rn)>;
@@ -7813,6 +7810,42 @@ def : InstAlias<"uxtl2 $dst.2d, $src1.4s",
78137810
(USHLLv4i32_shift V128:$dst, V128:$src1, 0)>;
78147811
}
78157812

7813+
// abs_f16: output pattern fragment that computes the absolute value of a
// 16-bit floating-point value held in $Rn using integer bitwise arithmetic.
//
// Steps, innermost first:
//   1. INSERT_SUBREG places $Rn in the hsub subregister of an otherwise
//      undefined f32 value (IMPLICIT_DEF).
//   2. COPY_TO_REGCLASS ... GPR32 moves that value to a 32-bit general register.
//   3. ANDWri with the logical immediate 0x7fff keeps bits 0-14 and clears
//      bit 15 and everything above it.
//   4. COPY_TO_REGCLASS ... FPR32 moves the result back to an FP register and
//      EXTRACT_SUBREG ... hsub takes the low 16-bit subregister as the result.
def abs_f16 :
7814+
OutPatFrag<(ops node:$Rn),
7815+
(EXTRACT_SUBREG (f32 (COPY_TO_REGCLASS
7816+
(i32 (ANDWri
7817+
(i32 (COPY_TO_REGCLASS (INSERT_SUBREG (f32 (IMPLICIT_DEF)),
7818+
node:$Rn, hsub), GPR32)),
7819+
(i32 (logical_imm32_XFORM(i32 0x7fff))))),
7820+
FPR32)), hsub)>;
7821+
7822+
// Select scalar fabs on f16 and bf16 through abs_f16; both types use the same
// fragment and the FPR16 register class.
// NOTE(review): no Predicates guard is visible around these two patterns in
// this hunk — confirm how they interact with any native fp16 fabs pattern.
def : Pat<(f16 (fabs (f16 FPR16:$Rn))), (f16 (abs_f16 (f16 FPR16:$Rn)))>;
7823+
def : Pat<(bf16 (fabs (bf16 FPR16:$Rn))), (bf16 (abs_f16 (bf16 FPR16:$Rn)))>;
7824+
7825+
// neg_f16: output pattern fragment that negates a 16-bit floating-point value
// held in $Rn using integer bitwise arithmetic.
//
// Steps, innermost first:
//   1. INSERT_SUBREG places $Rn in the hsub subregister of an otherwise
//      undefined f32 value (IMPLICIT_DEF).
//   2. COPY_TO_REGCLASS ... GPR32 moves that value to a 32-bit general register.
//   3. EORWri with the logical immediate 0x8000 toggles bit 15 and leaves every
//      other bit unchanged.
//   4. COPY_TO_REGCLASS ... FPR32 moves the result back to an FP register and
//      EXTRACT_SUBREG ... hsub takes the low 16-bit subregister as the result.
def neg_f16 :
7826+
OutPatFrag<(ops node:$Rn),
7827+
(EXTRACT_SUBREG (f32 (COPY_TO_REGCLASS
7828+
(i32 (EORWri
7829+
(i32 (COPY_TO_REGCLASS (INSERT_SUBREG (f32 (IMPLICIT_DEF)),
7830+
node:$Rn, hsub), GPR32)),
7831+
(i32 (logical_imm32_XFORM(i32 0x8000))))),
7832+
FPR32)), hsub)>;
7833+
7834+
// Select scalar fneg on f16 and bf16 through neg_f16; both types use the same
// fragment and the FPR16 register class.
// NOTE(review): no Predicates guard is visible around these two patterns in
// this hunk — confirm how they interact with any native fp16 fneg pattern.
def : Pat<(f16 (fneg (f16 FPR16:$Rn))), (f16 (neg_f16 (f16 FPR16:$Rn)))>;
7835+
def : Pat<(bf16 (fneg (bf16 FPR16:$Rn))), (bf16 (neg_f16 (bf16 FPR16:$Rn)))>;
7836+
7837+
// Vector fabs/fneg selection for 4- and 8-element f16 and bf16 vectors,
// available only when the HasNEON predicate holds.
//
// fabs maps to BICv4i16 (V64 operands) or BICv8i16 (V128 operands) with the
// operand pair (i32 128), (i32 8) — presumably immediate 0x80 shifted left by
// 8, i.e. a 0x8000 mask per 16-bit lane; confirm against the BIC (vector,
// immediate) encoding.
//
// fneg maps to EORv8i8 (V64) or EORv16i8 (V128) of the input with a MOVI
// constant built from the same (128, 8) operand pair.
//
// NOTE(review): the f16 fneg patterns write the MOVI immediate as (i32 128)
// with no explicit result type, while the bf16 ones write (i32 0x80) and wrap
// the MOVI in an explicit v4i16/v8i16 type. The immediates are numerically
// equal; only the spelling differs.
let Predicates = [HasNEON] in {
7838+
def : Pat<(v4f16 (fabs (v4f16 V64:$Rn))), (v4f16 (BICv4i16 (v4f16 V64:$Rn), (i32 128), (i32 8)))>;
7839+
def : Pat<(v4bf16 (fabs (v4bf16 V64:$Rn))), (v4bf16 (BICv4i16 (v4bf16 V64:$Rn), (i32 128), (i32 8)))>;
7840+
def : Pat<(v8f16 (fabs (v8f16 V128:$Rn))), (v8f16 (BICv8i16 (v8f16 V128:$Rn), (i32 128), (i32 8)))>;
7841+
def : Pat<(v8bf16 (fabs (v8bf16 V128:$Rn))), (v8bf16 (BICv8i16 (v8bf16 V128:$Rn), (i32 128), (i32 8)))>;
7842+
7843+
def : Pat<(v4f16 (fneg (v4f16 V64:$Rn))), (v4f16 (EORv8i8 (v4f16 V64:$Rn), (MOVIv4i16 (i32 128), (i32 8))))>;
7844+
def : Pat<(v4bf16 (fneg (v4bf16 V64:$Rn))), (v4bf16 (EORv8i8 (v4bf16 V64:$Rn), (v4i16 (MOVIv4i16 (i32 0x80), (i32 8)))))>;
7845+
def : Pat<(v8f16 (fneg (v8f16 V128:$Rn))), (v8f16 (EORv16i8 (v8f16 V128:$Rn), (MOVIv8i16 (i32 128), (i32 8))))>;
7846+
def : Pat<(v8bf16 (fneg (v8bf16 V128:$Rn))), (v8bf16 (EORv16i8 (v8bf16 V128:$Rn), (v8i16 (MOVIv8i16 (i32 0x80), (i32 8)))))>;
7847+
}
7848+
78167849
// If an integer is about to be converted to a floating point value,
78177850
// just load it on the floating point unit.
78187851
// These patterns are more complex because floating point loads do not

llvm/test/CodeGen/AArch64/f16-instructions.ll

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1027,9 +1027,9 @@ define half @test_fma(half %a, half %b, half %c) #0 {
10271027
}
10281028

10291029
; CHECK-CVT-LABEL: test_fabs:
1030-
; CHECK-CVT-NEXT: fcvt s0, h0
1031-
; CHECK-CVT-NEXT: fabs s0, s0
1032-
; CHECK-CVT-NEXT: fcvt h0, s0
1030+
; CHECK-CVT-NEXT: fmov w8, s0
1031+
; CHECK-CVT-NEXT: and w8, w8, #0x7fff
1032+
; CHECK-CVT-NEXT: fmov s0, w8
10331033
; CHECK-CVT-NEXT: ret
10341034

10351035
; CHECK-FP16-LABEL: test_fabs:
@@ -1338,3 +1338,12 @@ define half @test_fmuladd(half %a, half %b, half %c) #0 {
13381338
}
13391339

13401340
attributes #0 = { nounwind }
1341+
;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
1342+
; CHECK-COMMON: {{.*}}
1343+
; CHECK-CVT: {{.*}}
1344+
; CHECK-FP16: {{.*}}
1345+
; FALLBACK: {{.*}}
1346+
; FALLBACK-FP16: {{.*}}
1347+
; GISEL: {{.*}}
1348+
; GISEL-CVT: {{.*}}
1349+
; GISEL-FP16: {{.*}}

0 commit comments

Comments
 (0)