[RISCV] Custom legalize f16/bf16 FNEG/FABS with Zfhmin/Zfbfmin. #106886

Merged
merged 1 commit into from
Sep 1, 2024

Conversation

topperc
Collaborator

@topperc topperc commented Sep 1, 2024

The LegalizeDAG expansion will go through memory since i16 isn't a legal type. Avoid this by using FMV nodes.
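
To make the mask arithmetic concrete, here is a minimal C++ sketch that models the new lowering on plain integers. It is not LLVM code: `sextMask16`, `fnegBits`, and `fabsBits` are hypothetical names. The model assumes the integer register holds the 16-bit pattern in its low half with unspecified upper bits, and that only the low 16 bits are moved back to the FP register.

```cpp
#include <cassert>
#include <cstdint>

// Sign-extend a 16-bit mask to XLEN bits, in the spirit of Mask.sext(XLen).
// (Hypothetical helper; requires C++14 or later.)
constexpr std::uint64_t sextMask16(std::uint16_t mask, unsigned xlen) {
  assert(xlen == 32 || xlen == 64);
  std::uint64_t wide =
      (mask & 0x8000u) ? (0xFFFFFFFFFFFF0000ull | mask) : std::uint64_t{mask};
  return xlen == 64 ? wide : (wide & 0xFFFFFFFFull);
}

// FNEG: XOR with the sign mask, then keep the low 16 bits.
constexpr std::uint16_t fnegBits(std::uint64_t gpr, unsigned xlen) {
  return static_cast<std::uint16_t>(gpr ^ sextMask16(0x8000u, xlen));
}

// FABS: AND with the signed-max mask, then keep the low 16 bits.
constexpr std::uint16_t fabsBits(std::uint64_t gpr, unsigned xlen) {
  return static_cast<std::uint16_t>(gpr & sextMask16(0x7FFFu, xlen));
}

// The sign-extended FNEG mask matches the `lui a1, 1048568` in the new
// CHECK lines: 1048568 == 0xFFFF8, and lui shifts it left by 12.
static_assert(sextMask16(0x8000u, 32) == (1048568ull << 12), "RV32 mask");
static_assert(sextMask16(0x8000u, 64) == 0xFFFFFFFFFFFF8000ull, "RV64 mask");
static_assert(sextMask16(0x7FFFu, 64) == 0x7FFFull, "FABS mask");

// half 1.0 is 0x3C00, half -1.0 is 0xBC00; bf16 1.0 is 0x3F80.
static_assert(fnegBits(0x3C00u, 64) == 0xBC00u, "fneg half 1.0");
static_assert(fabsBits(0xFFFFBF80u, 32) == 0x3F80u, "fabs bf16 -1.0");
```

Because the mask is sign-extended, the XOR constant fits a single `lui`. The `slli`/`srli` pairs in the updated `fabs_s` checks clear the same upper bits that the AND with `0x7FFF` clears in this model.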

@llvmbot
Member

llvmbot commented Sep 1, 2024

@llvm/pr-subscribers-backend-risc-v

Author: Craig Topper (topperc)

Changes

The LegalizeDAG expansion will go through memory since i16 isn't a legal type. Avoid this by using FMV nodes.


Patch is 146.82 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/106886.diff

8 Files Affected:

  • (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+32-6)
  • (modified) llvm/test/CodeGen/RISCV/bfloat-arith.ll (+225-434)
  • (modified) llvm/test/CodeGen/RISCV/copysign-casts.ll (+70-55)
  • (modified) llvm/test/CodeGen/RISCV/half-arith-strict.ll (+175-449)
  • (modified) llvm/test/CodeGen/RISCV/half-arith.ll (+431-835)
  • (modified) llvm/test/CodeGen/RISCV/half-bitmanip-dagcombines.ll (+46-50)
  • (modified) llvm/test/CodeGen/RISCV/half-intrinsics.ll (+54-44)
  • (modified) llvm/test/CodeGen/RISCV/rvv/fixed-vectors-fp.ll (+31-158)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 47b43201105234..992d9a85c5259f 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -459,8 +459,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
     setOperationAction(ISD::BR_CC, MVT::bf16, Expand);
     setOperationAction(ZfhminZfbfminPromoteOps, MVT::bf16, Promote);
     setOperationAction(ISD::FREM, MVT::bf16, Promote);
-    setOperationAction(ISD::FABS, MVT::bf16, Expand);
-    setOperationAction(ISD::FNEG, MVT::bf16, Expand);
+    setOperationAction(ISD::FABS, MVT::bf16, Custom);
+    setOperationAction(ISD::FNEG, MVT::bf16, Custom);
     setOperationAction(ISD::FCOPYSIGN, MVT::bf16, Expand);
   }
 
@@ -476,8 +476,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
       setOperationAction({ISD::STRICT_LRINT, ISD::STRICT_LLRINT,
                           ISD::STRICT_LROUND, ISD::STRICT_LLROUND},
                          MVT::f16, Legal);
-      setOperationAction(ISD::FABS, MVT::f16, Expand);
-      setOperationAction(ISD::FNEG, MVT::f16, Expand);
+      setOperationAction(ISD::FABS, MVT::f16, Custom);
+      setOperationAction(ISD::FNEG, MVT::f16, Custom);
       setOperationAction(ISD::FCOPYSIGN, MVT::f16, Expand);
     }
 
@@ -5942,6 +5942,29 @@ static SDValue lowerFMAXIMUM_FMINIMUM(SDValue Op, SelectionDAG &DAG,
   return Res;
 }
 
+static SDValue lowerFABSorFNEG(SDValue Op, SelectionDAG &DAG,
+                               const RISCVSubtarget &Subtarget) {
+  bool IsFABS = Op.getOpcode() == ISD::FABS;
+  assert((IsFABS || Op.getOpcode() == ISD::FNEG) &&
+         "Wrong opcode for lowering FABS or FNEG.");
+
+  MVT XLenVT = Subtarget.getXLenVT();
+  MVT VT = Op.getSimpleValueType();
+  assert((VT == MVT::f16 || VT == MVT::bf16) && "Unexpected type");
+
+  SDLoc DL(Op);
+  SDValue Fmv =
+      DAG.getNode(RISCVISD::FMV_X_ANYEXTH, DL, XLenVT, Op.getOperand(0));
+
+  APInt Mask = IsFABS ? APInt::getSignedMaxValue(16) : APInt::getSignMask(16);
+  Mask = Mask.sext(Subtarget.getXLen());
+
+  unsigned LogicOpc = IsFABS ? ISD::AND : ISD::XOR;
+  SDValue Logic =
+      DAG.getNode(LogicOpc, DL, XLenVT, Fmv, DAG.getConstant(Mask, DL, XLenVT));
+  return DAG.getNode(RISCVISD::FMV_H_X, DL, VT, Logic);
+}
+
 /// Get a RISC-V target specified VL op for a given SDNode.
 static unsigned getRISCVVLOp(SDValue Op) {
 #define OP_CASE(NODE)                                                          \
@@ -7071,12 +7094,15 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
     assert(Op.getOperand(1).getValueType() == MVT::i32 && Subtarget.is64Bit() &&
            "Unexpected custom legalisation");
     return SDValue();
+  case ISD::FABS:
+  case ISD::FNEG:
+    if (Op.getValueType() == MVT::f16 || Op.getValueType() == MVT::bf16)
+      return lowerFABSorFNEG(Op, DAG, Subtarget);
+    [[fallthrough]];
   case ISD::FADD:
   case ISD::FSUB:
   case ISD::FMUL:
   case ISD::FDIV:
-  case ISD::FNEG:
-  case ISD::FABS:
   case ISD::FSQRT:
   case ISD::FMA:
   case ISD::FMINNUM:
diff --git a/llvm/test/CodeGen/RISCV/bfloat-arith.ll b/llvm/test/CodeGen/RISCV/bfloat-arith.ll
index 56a30dd0f6ffee..20150b5994b78d 100644
--- a/llvm/test/CodeGen/RISCV/bfloat-arith.ll
+++ b/llvm/test/CodeGen/RISCV/bfloat-arith.ll
@@ -75,14 +75,19 @@ define bfloat @fsgnj_s(bfloat %a, bfloat %b) nounwind {
 ; RV32IZFBFMIN:       # %bb.0:
 ; RV32IZFBFMIN-NEXT:    addi sp, sp, -16
 ; RV32IZFBFMIN-NEXT:    fsh fa1, 12(sp)
-; RV32IZFBFMIN-NEXT:    fsh fa0, 8(sp)
 ; RV32IZFBFMIN-NEXT:    lbu a0, 13(sp)
-; RV32IZFBFMIN-NEXT:    lbu a1, 9(sp)
-; RV32IZFBFMIN-NEXT:    andi a0, a0, 128
-; RV32IZFBFMIN-NEXT:    andi a1, a1, 127
-; RV32IZFBFMIN-NEXT:    or a0, a1, a0
-; RV32IZFBFMIN-NEXT:    sb a0, 9(sp)
-; RV32IZFBFMIN-NEXT:    flh fa0, 8(sp)
+; RV32IZFBFMIN-NEXT:    fmv.x.h a1, fa0
+; RV32IZFBFMIN-NEXT:    slli a1, a1, 17
+; RV32IZFBFMIN-NEXT:    andi a2, a0, 128
+; RV32IZFBFMIN-NEXT:    srli a0, a1, 17
+; RV32IZFBFMIN-NEXT:    beqz a2, .LBB5_2
+; RV32IZFBFMIN-NEXT:  # %bb.1:
+; RV32IZFBFMIN-NEXT:    lui a1, 1048568
+; RV32IZFBFMIN-NEXT:    or a0, a0, a1
+; RV32IZFBFMIN-NEXT:  .LBB5_2:
+; RV32IZFBFMIN-NEXT:    fmv.h.x fa5, a0
+; RV32IZFBFMIN-NEXT:    fcvt.s.bf16 fa5, fa5
+; RV32IZFBFMIN-NEXT:    fcvt.bf16.s fa0, fa5
 ; RV32IZFBFMIN-NEXT:    addi sp, sp, 16
 ; RV32IZFBFMIN-NEXT:    ret
 ;
@@ -90,14 +95,19 @@ define bfloat @fsgnj_s(bfloat %a, bfloat %b) nounwind {
 ; RV64IZFBFMIN:       # %bb.0:
 ; RV64IZFBFMIN-NEXT:    addi sp, sp, -16
 ; RV64IZFBFMIN-NEXT:    fsh fa1, 8(sp)
-; RV64IZFBFMIN-NEXT:    fsh fa0, 0(sp)
 ; RV64IZFBFMIN-NEXT:    lbu a0, 9(sp)
-; RV64IZFBFMIN-NEXT:    lbu a1, 1(sp)
-; RV64IZFBFMIN-NEXT:    andi a0, a0, 128
-; RV64IZFBFMIN-NEXT:    andi a1, a1, 127
-; RV64IZFBFMIN-NEXT:    or a0, a1, a0
-; RV64IZFBFMIN-NEXT:    sb a0, 1(sp)
-; RV64IZFBFMIN-NEXT:    flh fa0, 0(sp)
+; RV64IZFBFMIN-NEXT:    fmv.x.h a1, fa0
+; RV64IZFBFMIN-NEXT:    slli a1, a1, 49
+; RV64IZFBFMIN-NEXT:    andi a2, a0, 128
+; RV64IZFBFMIN-NEXT:    srli a0, a1, 49
+; RV64IZFBFMIN-NEXT:    beqz a2, .LBB5_2
+; RV64IZFBFMIN-NEXT:  # %bb.1:
+; RV64IZFBFMIN-NEXT:    lui a1, 1048568
+; RV64IZFBFMIN-NEXT:    or a0, a0, a1
+; RV64IZFBFMIN-NEXT:  .LBB5_2:
+; RV64IZFBFMIN-NEXT:    fmv.h.x fa5, a0
+; RV64IZFBFMIN-NEXT:    fcvt.s.bf16 fa5, fa5
+; RV64IZFBFMIN-NEXT:    fcvt.bf16.s fa0, fa5
 ; RV64IZFBFMIN-NEXT:    addi sp, sp, 16
 ; RV64IZFBFMIN-NEXT:    ret
   %1 = call bfloat @llvm.copysign.bf16(bfloat %a, bfloat %b)
@@ -105,39 +115,19 @@ define bfloat @fsgnj_s(bfloat %a, bfloat %b) nounwind {
 }
 
 define i32 @fneg_s(bfloat %a, bfloat %b) nounwind {
-; RV32IZFBFMIN-LABEL: fneg_s:
-; RV32IZFBFMIN:       # %bb.0:
-; RV32IZFBFMIN-NEXT:    addi sp, sp, -16
-; RV32IZFBFMIN-NEXT:    fcvt.s.bf16 fa5, fa0
-; RV32IZFBFMIN-NEXT:    fadd.s fa5, fa5, fa5
-; RV32IZFBFMIN-NEXT:    fcvt.bf16.s fa5, fa5
-; RV32IZFBFMIN-NEXT:    fsh fa5, 12(sp)
-; RV32IZFBFMIN-NEXT:    lbu a0, 13(sp)
-; RV32IZFBFMIN-NEXT:    xori a0, a0, 128
-; RV32IZFBFMIN-NEXT:    sb a0, 13(sp)
-; RV32IZFBFMIN-NEXT:    flh fa4, 12(sp)
-; RV32IZFBFMIN-NEXT:    fcvt.s.bf16 fa5, fa5
-; RV32IZFBFMIN-NEXT:    fcvt.s.bf16 fa4, fa4
-; RV32IZFBFMIN-NEXT:    feq.s a0, fa5, fa4
-; RV32IZFBFMIN-NEXT:    addi sp, sp, 16
-; RV32IZFBFMIN-NEXT:    ret
-;
-; RV64IZFBFMIN-LABEL: fneg_s:
-; RV64IZFBFMIN:       # %bb.0:
-; RV64IZFBFMIN-NEXT:    addi sp, sp, -16
-; RV64IZFBFMIN-NEXT:    fcvt.s.bf16 fa5, fa0
-; RV64IZFBFMIN-NEXT:    fadd.s fa5, fa5, fa5
-; RV64IZFBFMIN-NEXT:    fcvt.bf16.s fa5, fa5
-; RV64IZFBFMIN-NEXT:    fsh fa5, 8(sp)
-; RV64IZFBFMIN-NEXT:    lbu a0, 9(sp)
-; RV64IZFBFMIN-NEXT:    xori a0, a0, 128
-; RV64IZFBFMIN-NEXT:    sb a0, 9(sp)
-; RV64IZFBFMIN-NEXT:    flh fa4, 8(sp)
-; RV64IZFBFMIN-NEXT:    fcvt.s.bf16 fa5, fa5
-; RV64IZFBFMIN-NEXT:    fcvt.s.bf16 fa4, fa4
-; RV64IZFBFMIN-NEXT:    feq.s a0, fa5, fa4
-; RV64IZFBFMIN-NEXT:    addi sp, sp, 16
-; RV64IZFBFMIN-NEXT:    ret
+; CHECK-LABEL: fneg_s:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    fcvt.s.bf16 fa5, fa0
+; CHECK-NEXT:    fadd.s fa5, fa5, fa5
+; CHECK-NEXT:    fcvt.bf16.s fa5, fa5
+; CHECK-NEXT:    fmv.x.h a0, fa5
+; CHECK-NEXT:    lui a1, 1048568
+; CHECK-NEXT:    xor a0, a0, a1
+; CHECK-NEXT:    fmv.h.x fa4, a0
+; CHECK-NEXT:    fcvt.s.bf16 fa4, fa4
+; CHECK-NEXT:    fcvt.s.bf16 fa5, fa5
+; CHECK-NEXT:    feq.s a0, fa5, fa4
+; CHECK-NEXT:    ret
   %1 = fadd bfloat %a, %a
   %2 = fneg bfloat %1
   %3 = fcmp oeq bfloat %1, %2
@@ -153,45 +143,57 @@ define bfloat @fsgnjn_s(bfloat %a, bfloat %b) nounwind {
 ; RV32IZFBFMIN-NEXT:    fcvt.s.bf16 fa4, fa0
 ; RV32IZFBFMIN-NEXT:    fadd.s fa5, fa4, fa5
 ; RV32IZFBFMIN-NEXT:    fcvt.bf16.s fa5, fa5
-; RV32IZFBFMIN-NEXT:    fsh fa5, 4(sp)
-; RV32IZFBFMIN-NEXT:    lbu a0, 5(sp)
-; RV32IZFBFMIN-NEXT:    xori a0, a0, 128
-; RV32IZFBFMIN-NEXT:    sb a0, 5(sp)
-; RV32IZFBFMIN-NEXT:    flh fa5, 4(sp)
-; RV32IZFBFMIN-NEXT:    fsh fa0, 8(sp)
+; RV32IZFBFMIN-NEXT:    fmv.x.h a1, fa5
+; RV32IZFBFMIN-NEXT:    lui a0, 1048568
+; RV32IZFBFMIN-NEXT:    xor a1, a1, a0
+; RV32IZFBFMIN-NEXT:    fmv.h.x fa5, a1
 ; RV32IZFBFMIN-NEXT:    fsh fa5, 12(sp)
-; RV32IZFBFMIN-NEXT:    lbu a0, 9(sp)
 ; RV32IZFBFMIN-NEXT:    lbu a1, 13(sp)
-; RV32IZFBFMIN-NEXT:    andi a0, a0, 127
-; RV32IZFBFMIN-NEXT:    andi a1, a1, 128
-; RV32IZFBFMIN-NEXT:    or a0, a0, a1
-; RV32IZFBFMIN-NEXT:    sb a0, 9(sp)
-; RV32IZFBFMIN-NEXT:    flh fa0, 8(sp)
+; RV32IZFBFMIN-NEXT:    fmv.x.h a2, fa0
+; RV32IZFBFMIN-NEXT:    slli a2, a2, 17
+; RV32IZFBFMIN-NEXT:    andi a3, a1, 128
+; RV32IZFBFMIN-NEXT:    srli a1, a2, 17
+; RV32IZFBFMIN-NEXT:    bnez a3, .LBB7_2
+; RV32IZFBFMIN-NEXT:  # %bb.1:
+; RV32IZFBFMIN-NEXT:    fmv.h.x fa5, a1
+; RV32IZFBFMIN-NEXT:    j .LBB7_3
+; RV32IZFBFMIN-NEXT:  .LBB7_2:
+; RV32IZFBFMIN-NEXT:    or a0, a1, a0
+; RV32IZFBFMIN-NEXT:    fmv.h.x fa5, a0
+; RV32IZFBFMIN-NEXT:  .LBB7_3:
+; RV32IZFBFMIN-NEXT:    fcvt.s.bf16 fa5, fa5
+; RV32IZFBFMIN-NEXT:    fcvt.bf16.s fa0, fa5
 ; RV32IZFBFMIN-NEXT:    addi sp, sp, 16
 ; RV32IZFBFMIN-NEXT:    ret
 ;
 ; RV64IZFBFMIN-LABEL: fsgnjn_s:
 ; RV64IZFBFMIN:       # %bb.0:
-; RV64IZFBFMIN-NEXT:    addi sp, sp, -32
+; RV64IZFBFMIN-NEXT:    addi sp, sp, -16
 ; RV64IZFBFMIN-NEXT:    fcvt.s.bf16 fa5, fa1
 ; RV64IZFBFMIN-NEXT:    fcvt.s.bf16 fa4, fa0
 ; RV64IZFBFMIN-NEXT:    fadd.s fa5, fa4, fa5
 ; RV64IZFBFMIN-NEXT:    fcvt.bf16.s fa5, fa5
+; RV64IZFBFMIN-NEXT:    fmv.x.h a1, fa5
+; RV64IZFBFMIN-NEXT:    lui a0, 1048568
+; RV64IZFBFMIN-NEXT:    xor a1, a1, a0
+; RV64IZFBFMIN-NEXT:    fmv.h.x fa5, a1
 ; RV64IZFBFMIN-NEXT:    fsh fa5, 8(sp)
-; RV64IZFBFMIN-NEXT:    lbu a0, 9(sp)
-; RV64IZFBFMIN-NEXT:    xori a0, a0, 128
-; RV64IZFBFMIN-NEXT:    sb a0, 9(sp)
-; RV64IZFBFMIN-NEXT:    flh fa5, 8(sp)
-; RV64IZFBFMIN-NEXT:    fsh fa0, 16(sp)
-; RV64IZFBFMIN-NEXT:    fsh fa5, 24(sp)
-; RV64IZFBFMIN-NEXT:    lbu a0, 17(sp)
-; RV64IZFBFMIN-NEXT:    lbu a1, 25(sp)
-; RV64IZFBFMIN-NEXT:    andi a0, a0, 127
-; RV64IZFBFMIN-NEXT:    andi a1, a1, 128
-; RV64IZFBFMIN-NEXT:    or a0, a0, a1
-; RV64IZFBFMIN-NEXT:    sb a0, 17(sp)
-; RV64IZFBFMIN-NEXT:    flh fa0, 16(sp)
-; RV64IZFBFMIN-NEXT:    addi sp, sp, 32
+; RV64IZFBFMIN-NEXT:    lbu a1, 9(sp)
+; RV64IZFBFMIN-NEXT:    fmv.x.h a2, fa0
+; RV64IZFBFMIN-NEXT:    slli a2, a2, 49
+; RV64IZFBFMIN-NEXT:    andi a3, a1, 128
+; RV64IZFBFMIN-NEXT:    srli a1, a2, 49
+; RV64IZFBFMIN-NEXT:    bnez a3, .LBB7_2
+; RV64IZFBFMIN-NEXT:  # %bb.1:
+; RV64IZFBFMIN-NEXT:    fmv.h.x fa5, a1
+; RV64IZFBFMIN-NEXT:    j .LBB7_3
+; RV64IZFBFMIN-NEXT:  .LBB7_2:
+; RV64IZFBFMIN-NEXT:    or a0, a1, a0
+; RV64IZFBFMIN-NEXT:    fmv.h.x fa5, a0
+; RV64IZFBFMIN-NEXT:  .LBB7_3:
+; RV64IZFBFMIN-NEXT:    fcvt.s.bf16 fa5, fa5
+; RV64IZFBFMIN-NEXT:    fcvt.bf16.s fa0, fa5
+; RV64IZFBFMIN-NEXT:    addi sp, sp, 16
 ; RV64IZFBFMIN-NEXT:    ret
   %1 = fadd bfloat %a, %b
   %2 = fneg bfloat %1
@@ -204,40 +206,34 @@ declare bfloat @llvm.fabs.bf16(bfloat)
 define bfloat @fabs_s(bfloat %a, bfloat %b) nounwind {
 ; RV32IZFBFMIN-LABEL: fabs_s:
 ; RV32IZFBFMIN:       # %bb.0:
-; RV32IZFBFMIN-NEXT:    addi sp, sp, -16
 ; RV32IZFBFMIN-NEXT:    fcvt.s.bf16 fa5, fa1
 ; RV32IZFBFMIN-NEXT:    fcvt.s.bf16 fa4, fa0
 ; RV32IZFBFMIN-NEXT:    fadd.s fa5, fa4, fa5
 ; RV32IZFBFMIN-NEXT:    fcvt.bf16.s fa5, fa5
-; RV32IZFBFMIN-NEXT:    fsh fa5, 12(sp)
-; RV32IZFBFMIN-NEXT:    lbu a0, 13(sp)
-; RV32IZFBFMIN-NEXT:    andi a0, a0, 127
-; RV32IZFBFMIN-NEXT:    sb a0, 13(sp)
-; RV32IZFBFMIN-NEXT:    flh fa4, 12(sp)
+; RV32IZFBFMIN-NEXT:    fmv.x.h a0, fa5
+; RV32IZFBFMIN-NEXT:    slli a0, a0, 17
+; RV32IZFBFMIN-NEXT:    srli a0, a0, 17
+; RV32IZFBFMIN-NEXT:    fmv.h.x fa4, a0
 ; RV32IZFBFMIN-NEXT:    fcvt.s.bf16 fa5, fa5
 ; RV32IZFBFMIN-NEXT:    fcvt.s.bf16 fa4, fa4
 ; RV32IZFBFMIN-NEXT:    fadd.s fa5, fa4, fa5
 ; RV32IZFBFMIN-NEXT:    fcvt.bf16.s fa0, fa5
-; RV32IZFBFMIN-NEXT:    addi sp, sp, 16
 ; RV32IZFBFMIN-NEXT:    ret
 ;
 ; RV64IZFBFMIN-LABEL: fabs_s:
 ; RV64IZFBFMIN:       # %bb.0:
-; RV64IZFBFMIN-NEXT:    addi sp, sp, -16
 ; RV64IZFBFMIN-NEXT:    fcvt.s.bf16 fa5, fa1
 ; RV64IZFBFMIN-NEXT:    fcvt.s.bf16 fa4, fa0
 ; RV64IZFBFMIN-NEXT:    fadd.s fa5, fa4, fa5
 ; RV64IZFBFMIN-NEXT:    fcvt.bf16.s fa5, fa5
-; RV64IZFBFMIN-NEXT:    fsh fa5, 8(sp)
-; RV64IZFBFMIN-NEXT:    lbu a0, 9(sp)
-; RV64IZFBFMIN-NEXT:    andi a0, a0, 127
-; RV64IZFBFMIN-NEXT:    sb a0, 9(sp)
-; RV64IZFBFMIN-NEXT:    flh fa4, 8(sp)
+; RV64IZFBFMIN-NEXT:    fmv.x.h a0, fa5
+; RV64IZFBFMIN-NEXT:    slli a0, a0, 49
+; RV64IZFBFMIN-NEXT:    srli a0, a0, 49
+; RV64IZFBFMIN-NEXT:    fmv.h.x fa4, a0
 ; RV64IZFBFMIN-NEXT:    fcvt.s.bf16 fa5, fa5
 ; RV64IZFBFMIN-NEXT:    fcvt.s.bf16 fa4, fa4
 ; RV64IZFBFMIN-NEXT:    fadd.s fa5, fa4, fa5
 ; RV64IZFBFMIN-NEXT:    fcvt.bf16.s fa0, fa5
-; RV64IZFBFMIN-NEXT:    addi sp, sp, 16
 ; RV64IZFBFMIN-NEXT:    ret
   %1 = fadd bfloat %a, %b
   %2 = call bfloat @llvm.fabs.bf16(bfloat %1)
@@ -289,45 +285,22 @@ define bfloat @fmadd_s(bfloat %a, bfloat %b, bfloat %c) nounwind {
 }
 
 define bfloat @fmsub_s(bfloat %a, bfloat %b, bfloat %c) nounwind {
-; RV32IZFBFMIN-LABEL: fmsub_s:
-; RV32IZFBFMIN:       # %bb.0:
-; RV32IZFBFMIN-NEXT:    addi sp, sp, -16
-; RV32IZFBFMIN-NEXT:    fcvt.s.bf16 fa5, fa2
-; RV32IZFBFMIN-NEXT:    fmv.w.x fa4, zero
-; RV32IZFBFMIN-NEXT:    fadd.s fa5, fa5, fa4
-; RV32IZFBFMIN-NEXT:    fcvt.bf16.s fa5, fa5
-; RV32IZFBFMIN-NEXT:    fsh fa5, 12(sp)
-; RV32IZFBFMIN-NEXT:    lbu a0, 13(sp)
-; RV32IZFBFMIN-NEXT:    xori a0, a0, 128
-; RV32IZFBFMIN-NEXT:    sb a0, 13(sp)
-; RV32IZFBFMIN-NEXT:    flh fa5, 12(sp)
-; RV32IZFBFMIN-NEXT:    fcvt.s.bf16 fa4, fa1
-; RV32IZFBFMIN-NEXT:    fcvt.s.bf16 fa3, fa0
-; RV32IZFBFMIN-NEXT:    fcvt.s.bf16 fa5, fa5
-; RV32IZFBFMIN-NEXT:    fmadd.s fa5, fa3, fa4, fa5
-; RV32IZFBFMIN-NEXT:    fcvt.bf16.s fa0, fa5
-; RV32IZFBFMIN-NEXT:    addi sp, sp, 16
-; RV32IZFBFMIN-NEXT:    ret
-;
-; RV64IZFBFMIN-LABEL: fmsub_s:
-; RV64IZFBFMIN:       # %bb.0:
-; RV64IZFBFMIN-NEXT:    addi sp, sp, -16
-; RV64IZFBFMIN-NEXT:    fcvt.s.bf16 fa5, fa2
-; RV64IZFBFMIN-NEXT:    fmv.w.x fa4, zero
-; RV64IZFBFMIN-NEXT:    fadd.s fa5, fa5, fa4
-; RV64IZFBFMIN-NEXT:    fcvt.bf16.s fa5, fa5
-; RV64IZFBFMIN-NEXT:    fsh fa5, 8(sp)
-; RV64IZFBFMIN-NEXT:    lbu a0, 9(sp)
-; RV64IZFBFMIN-NEXT:    xori a0, a0, 128
-; RV64IZFBFMIN-NEXT:    sb a0, 9(sp)
-; RV64IZFBFMIN-NEXT:    flh fa5, 8(sp)
-; RV64IZFBFMIN-NEXT:    fcvt.s.bf16 fa4, fa1
-; RV64IZFBFMIN-NEXT:    fcvt.s.bf16 fa3, fa0
-; RV64IZFBFMIN-NEXT:    fcvt.s.bf16 fa5, fa5
-; RV64IZFBFMIN-NEXT:    fmadd.s fa5, fa3, fa4, fa5
-; RV64IZFBFMIN-NEXT:    fcvt.bf16.s fa0, fa5
-; RV64IZFBFMIN-NEXT:    addi sp, sp, 16
-; RV64IZFBFMIN-NEXT:    ret
+; CHECK-LABEL: fmsub_s:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    fcvt.s.bf16 fa5, fa2
+; CHECK-NEXT:    fmv.w.x fa4, zero
+; CHECK-NEXT:    fadd.s fa5, fa5, fa4
+; CHECK-NEXT:    fcvt.bf16.s fa5, fa5
+; CHECK-NEXT:    fmv.x.h a0, fa5
+; CHECK-NEXT:    lui a1, 1048568
+; CHECK-NEXT:    xor a0, a0, a1
+; CHECK-NEXT:    fmv.h.x fa5, a0
+; CHECK-NEXT:    fcvt.s.bf16 fa5, fa5
+; CHECK-NEXT:    fcvt.s.bf16 fa4, fa1
+; CHECK-NEXT:    fcvt.s.bf16 fa3, fa0
+; CHECK-NEXT:    fmadd.s fa5, fa3, fa4, fa5
+; CHECK-NEXT:    fcvt.bf16.s fa0, fa5
+; CHECK-NEXT:    ret
   %c_ = fadd bfloat 0.0, %c ; avoid negation using xor
   %negc = fsub bfloat -0.0, %c_
   %1 = call bfloat @llvm.fma.bf16(bfloat %a, bfloat %b, bfloat %negc)
@@ -335,61 +308,28 @@ define bfloat @fmsub_s(bfloat %a, bfloat %b, bfloat %c) nounwind {
 }
 
 define bfloat @fnmadd_s(bfloat %a, bfloat %b, bfloat %c) nounwind {
-; RV32IZFBFMIN-LABEL: fnmadd_s:
-; RV32IZFBFMIN:       # %bb.0:
-; RV32IZFBFMIN-NEXT:    addi sp, sp, -16
-; RV32IZFBFMIN-NEXT:    fcvt.s.bf16 fa5, fa0
-; RV32IZFBFMIN-NEXT:    fmv.w.x fa4, zero
-; RV32IZFBFMIN-NEXT:    fadd.s fa5, fa5, fa4
-; RV32IZFBFMIN-NEXT:    fcvt.bf16.s fa5, fa5
-; RV32IZFBFMIN-NEXT:    fsh fa5, 8(sp)
-; RV32IZFBFMIN-NEXT:    lbu a0, 9(sp)
-; RV32IZFBFMIN-NEXT:    fcvt.s.bf16 fa5, fa2
-; RV32IZFBFMIN-NEXT:    fadd.s fa5, fa5, fa4
-; RV32IZFBFMIN-NEXT:    fcvt.bf16.s fa5, fa5
-; RV32IZFBFMIN-NEXT:    xori a0, a0, 128
-; RV32IZFBFMIN-NEXT:    sb a0, 9(sp)
-; RV32IZFBFMIN-NEXT:    flh fa4, 8(sp)
-; RV32IZFBFMIN-NEXT:    fsh fa5, 12(sp)
-; RV32IZFBFMIN-NEXT:    lbu a0, 13(sp)
-; RV32IZFBFMIN-NEXT:    xori a0, a0, 128
-; RV32IZFBFMIN-NEXT:    sb a0, 13(sp)
-; RV32IZFBFMIN-NEXT:    flh fa5, 12(sp)
-; RV32IZFBFMIN-NEXT:    fcvt.s.bf16 fa3, fa1
-; RV32IZFBFMIN-NEXT:    fcvt.s.bf16 fa5, fa5
-; RV32IZFBFMIN-NEXT:    fcvt.s.bf16 fa4, fa4
-; RV32IZFBFMIN-NEXT:    fmadd.s fa5, fa4, fa3, fa5
-; RV32IZFBFMIN-NEXT:    fcvt.bf16.s fa0, fa5
-; RV32IZFBFMIN-NEXT:    addi sp, sp, 16
-; RV32IZFBFMIN-NEXT:    ret
-;
-; RV64IZFBFMIN-LABEL: fnmadd_s:
-; RV64IZFBFMIN:       # %bb.0:
-; RV64IZFBFMIN-NEXT:    addi sp, sp, -16
-; RV64IZFBFMIN-NEXT:    fcvt.s.bf16 fa5, fa0
-; RV64IZFBFMIN-NEXT:    fmv.w.x fa4, zero
-; RV64IZFBFMIN-NEXT:    fadd.s fa5, fa5, fa4
-; RV64IZFBFMIN-NEXT:    fcvt.bf16.s fa5, fa5
-; RV64IZFBFMIN-NEXT:    fsh fa5, 0(sp)
-; RV64IZFBFMIN-NEXT:    lbu a0, 1(sp)
-; RV64IZFBFMIN-NEXT:    fcvt.s.bf16 fa5, fa2
-; RV64IZFBFMIN-NEXT:    fadd.s fa5, fa5, fa4
-; RV64IZFBFMIN-NEXT:    fcvt.bf16.s fa5, fa5
-; RV64IZFBFMIN-NEXT:    xori a0, a0, 128
-; RV64IZFBFMIN-NEXT:    sb a0, 1(sp)
-; RV64IZFBFMIN-NEXT:    flh fa4, 0(sp)
-; RV64IZFBFMIN-NEXT:    fsh fa5, 8(sp)
-; RV64IZFBFMIN-NEXT:    lbu a0, 9(sp)
-; RV64IZFBFMIN-NEXT:    xori a0, a0, 128
-; RV64IZFBFMIN-NEXT:    sb a0, 9(sp)
-; RV64IZFBFMIN-NEXT:    flh fa5, 8(sp)
-; RV64IZFBFMIN-NEXT:    fcvt.s.bf16 fa3, fa1
-; RV64IZFBFMIN-NEXT:    fcvt.s.bf16 fa5, fa5
-; RV64IZFBFMIN-NEXT:    fcvt.s.bf16 fa4, fa4
-; RV64IZFBFMIN-NEXT:    fmadd.s fa5, fa4, fa3, fa5
-; RV64IZFBFMIN-NEXT:    fcvt.bf16.s fa0, fa5
-; RV64IZFBFMIN-NEXT:    addi sp, sp, 16
-; RV64IZFBFMIN-NEXT:    ret
+; CHECK-LABEL: fnmadd_s:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    fcvt.s.bf16 fa5, fa0
+; CHECK-NEXT:    fmv.w.x fa4, zero
+; CHECK-NEXT:    fadd.s fa5, fa5, fa4
+; CHECK-NEXT:    fcvt.bf16.s fa5, fa5
+; CHECK-NEXT:    fcvt.s.bf16 fa3, fa2
+; CHECK-NEXT:    fadd.s fa4, fa3, fa4
+; CHECK-NEXT:    fcvt.bf16.s fa4, fa4
+; CHECK-NEXT:    fmv.x.h a0, fa5
+; CHECK-NEXT:    lui a1, 1048568
+; CHECK-NEXT:    xor a0, a0, a1
+; CHECK-NEXT:    fmv.h.x fa5, a0
+; CHECK-NEXT:    fmv.x.h a0, fa4
+; CHECK-NEXT:    xor a0, a0, a1
+; CHECK-NEXT:    fmv.h.x fa4, a0
+; CHECK-NEXT:    fcvt.s.bf16 fa4, fa4
+; CHECK-NEXT:    fcvt.s.bf16 fa5, fa5
+; CHECK-NEXT:    fcvt.s.bf16 fa3, fa1
+; CHECK-NEXT:    fmadd.s fa5, fa5, fa3, fa4
+; CHECK-NEXT:    fcvt.bf16.s fa0, fa5
+; CHECK-NEXT:    ret
   %a_ = fadd bfloat 0.0, %a
   %c_ = fadd bfloat 0.0, %c
   %nega = fsub bfloat -0.0, %a_
@@ -399,61 +339,28 @@ define bfloat @fnmadd_s(bfloat %a, bfloat %b, bfloat %c) nounwind {
 }
 
 define bfloat @fnmadd_s_2(bfloat %a, bfloat %b, bfloat %c) nounwind {
-; RV32IZFBFMIN-LABEL: fnmadd_s_2:
-; RV32IZFBFMIN:       # %bb.0:
-; RV32IZFBFMIN-NEXT:    addi sp, sp, -16
-; RV32IZFBFMIN-NEXT:    fcvt.s.bf16 fa5, fa1
-; RV32IZFBFMIN-NEXT:    fmv.w.x fa4, zero
-; RV32IZFBFMIN-NEXT:    fadd.s fa5, fa5, fa4
-; RV32IZFBFMIN-NEXT:    fcvt.bf16.s fa5, fa5
-; RV32IZFBFMIN-NEXT:    fsh fa5, 8(sp)
-; RV32IZFBFMIN-NEXT:    lbu a0, 9(sp)
-; RV32IZFBFMIN-NEXT:    fcvt.s.bf16 fa5, fa2
-; RV32IZFBFMIN-NEXT:    fadd.s fa5, fa5, fa4
-; RV32IZFBFMIN-NEXT:    fcvt.bf16.s fa5, fa5
-; RV32IZFBFMIN-NEXT:    xori a0, a0, 128
-; RV32IZFBFMIN-NEXT:    sb a0, 9(sp)
-; RV32IZFBFMIN-NEXT:    flh fa4, 8(sp)
-; RV32IZFBFMIN-NEXT:    fsh fa5, 12(sp)
-; RV32IZFBFMIN-NEXT:    lbu a0, 13(sp)
-; RV32IZFBFMIN-NEXT:    xori a0, a0, 128
-; RV32IZFBFMIN-NEXT:    sb a0, 13(sp)
-; RV32IZFBFMIN-NEXT:    flh fa5, 12(sp)
-; RV32IZFBFMIN-NEXT:    fcvt.s.bf16 fa3, fa0
-; RV32IZFBFMIN-NEXT:    fcvt.s.bf16 fa5, fa5
-; RV32IZFBFMIN-NEXT:    fcvt.s.bf16 fa4, fa4
-; RV32IZFBFMIN-NEXT:    fmadd.s fa5, fa3, fa4, fa5
-; RV32IZFBFMIN-NEXT:    fcvt.bf16.s fa0, fa5
-; RV32IZFBFMIN-NEXT:    addi sp, sp, 16
-; RV32IZFBFMIN-NEXT:    ret
-;
-; RV64IZFBFMIN-LABEL: fnmadd_s_2:
-; RV64IZFBFMIN:       # %bb.0:
-; RV64IZFBFMIN-NEXT:    addi sp, sp, -16
-; RV64IZFBFMIN-NEXT:    fcvt.s.bf16 fa5, fa1
-; RV64IZFBFMIN-NEXT:    fmv.w.x fa4, zero
-; RV64IZFBFMIN-NEXT:    fadd.s fa5, fa5, fa4
-; RV64IZFBFMIN-NEXT:    fcvt.bf16.s fa5, fa5
-; RV64IZFBFMIN-NEXT:    fsh fa5, 0(sp)
-; RV64IZFBFMIN-NEXT:    lbu a0, 1(sp)
-; RV64IZFBFMIN-NEXT:    fcvt.s.bf16 fa5, fa2
-; RV64IZFBFMIN-NEXT:    fadd.s fa5, fa5, fa4
-; RV64IZFBFMIN-NEXT:    fcvt.bf16.s f...
[truncated]

Member

@dtcxzyw dtcxzyw left a comment

LG. BTW we have the same problem with f16/bf16 fcopysign.

; bin/llc -mtriple=riscv64 -mattr=+zfhmin test.ll -o -
define half @test(half %a, half %b) nounwind {
  %t = call half @llvm.copysign.f16(half %a, half %b)
  ret half %t
}
        addi    sp, sp, -16
        fsh     fa1, 8(sp)
        fsh     fa0, 0(sp)
        lbu     a0, 9(sp)
        lbu     a1, 1(sp)
        andi    a0, a0, 128
        andi    a1, a1, 127
        or      a0, a1, a0
        sb      a0, 1(sp)
        flh     fa0, 0(sp)
        addi    sp, sp, 16
        ret
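
For readers less familiar with the assembly, the stack round trip above can be modeled in C++ roughly as follows. This is only an illustrative sketch: `copysignViaMemory` is a hypothetical name, and the two byte arrays stand in for the little-endian stack slots at `0(sp)` and `8(sp)`.

```cpp
#include <cassert>
#include <cstdint>

// Models the expanded sequence: store both halves, patch the high byte of
// the magnitude slot with the sign bit of the other slot, then reload.
inline std::uint16_t copysignViaMemory(std::uint16_t a, std::uint16_t b) {
  unsigned char slotA[2];  // fsh fa0, 0(sp)
  unsigned char slotB[2];  // fsh fa1, 8(sp)
  slotA[0] = static_cast<unsigned char>(a & 0xFFu);
  slotA[1] = static_cast<unsigned char>(a >> 8);
  slotB[0] = static_cast<unsigned char>(b & 0xFFu);
  slotB[1] = static_cast<unsigned char>(b >> 8);

  unsigned signByte = slotB[1] & 128u;  // lbu a0, 9(sp); andi a0, a0, 128
  unsigned magByte = slotA[1] & 127u;   // lbu a1, 1(sp); andi a1, a1, 127
  slotA[1] = static_cast<unsigned char>(magByte | signByte);  // or; sb

  // flh fa0, 0(sp): reassemble the patched halfword.
  std::uint16_t result =
      static_cast<std::uint16_t>(slotA[0] | (slotA[1] << 8));
  assert((result & 0x7FFFu) == (a & 0x7FFFu));  // magnitude bits unchanged
  return result;
}
```

The result is an ordinary bitwise copysign; the cost is the two stores, two byte loads, one byte store, and one reload needed to reach bit 15 without a legal i16 type.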

@topperc
Collaborator Author

topperc commented Sep 1, 2024

> LG. BTW we have the same problem with f16/bf16 fcopysign.
>
> ; bin/llc -mtriple=riscv64 -mattr=+zfhmin test.ll -o -
> define half @test(half %a, half %b) nounwind {
>   %t = call half @llvm.copysign.f16(half %a, half %b)
>   ret half %t
> }
>         addi    sp, sp, -16
>         fsh     fa1, 8(sp)
>         fsh     fa0, 0(sp)
>         lbu     a0, 9(sp)
>         lbu     a1, 1(sp)
>         andi    a0, a0, 128
>         andi    a1, a1, 127
>         or      a0, a1, a0
>         sb      a0, 1(sp)
>         flh     fa0, 0(sp)
>         addi    sp, sp, 16
>         ret

I know. It needs a more complex fix because there are two operands and the sign operand can be a different type. I'll work on it soon.
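
As a rough illustration of why the differing sign type complicates things, the sketch below does the copysign on integer bit patterns, taking the sign from a value of arbitrary width. It is an assumption-laden model, not the eventual patch: `copysignBits16` and its `signWidth` parameter are hypothetical.

```cpp
#include <cassert>
#include <cstdint>

// Copy the sign bit of a signWidth-bit pattern onto a 16-bit magnitude.
// The sign bit sits at position signWidth - 1 in the source, so it has to
// be moved to bit 15 before it can be merged.
constexpr std::uint16_t copysignBits16(std::uint16_t mag,
                                       std::uint64_t signBits,
                                       unsigned signWidth) {
  assert(signWidth >= 1 && signWidth <= 64);
  return static_cast<std::uint16_t>(
      (mag & 0x7FFFu) | (((signBits >> (signWidth - 1)) & 1u) << 15));
}

// Sign taken from a 32-bit float pattern (-0.0f is 0x80000000).
static_assert(copysignBits16(0x3C00u, 0x80000000ull, 32) == 0xBC00u,
              "f32 sign onto half 1.0");
// Sign taken from a same-width half pattern.
static_assert(copysignBits16(0xBC00u, 0x3C00ull, 16) == 0x3C00u,
              "half sign onto half -1.0");
```

With FNEG and FABS the mask is a fixed 16-bit constant; here the shift amount depends on the sign operand's type, which is the extra case the lowering has to handle.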

We should also stop promoting ISD::SELECT.

@topperc topperc merged commit 3bdec31 into llvm:main Sep 1, 2024
10 checks passed
@topperc topperc deleted the pr/f16-fneg-fabs branch September 1, 2024 06:57
topperc added a commit to topperc/llvm-project that referenced this pull request Sep 3, 2024
The LegalizeDAG expansion will go through memory since i16 isn't a legal
type. Avoid this by using FMV nodes.

Similar to what we did for llvm#106886 for FNEG and FABS. Special care
is needed to handle the Sign operand being a different type.
topperc added a commit that referenced this pull request Sep 3, 2024
The LegalizeDAG expansion will go through memory since i16 isn't a legal
type. Avoid this by using FMV nodes.

Similar to what we did for #106886 for FNEG and FABS. Special care is
needed to handle the Sign operand being a different type.