[RISCV][ISel] Fold FSGNJX idioms #100718


Merged

merged 2 commits into llvm:main on Jul 27, 2024

Conversation

dtcxzyw
Member

@dtcxzyw dtcxzyw commented Jul 26, 2024

This patch folds `fmul X, (fcopysign 1.0, Y)` into `fsgnjx X, Y`. This pattern shows up in some graphics applications and math libraries.
Alive2 proof: https://alive2.llvm.org/ce/z/epyL33

Since the fpimm `+1.0` is lowered to a load from the constant pool after operation legalization, I have to introduce a new RISCVISD node `FSGNJX` and fold this pattern in DAGCombine.

Closes dtcxzyw/llvm-opt-benchmark#1072.
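
As an illustration (not part of the patch), the semantics of the fold can be sketched in Python: `fsgnjx` keeps the first operand's magnitude and sets the sign bit to sign(x) XOR sign(y), which is exactly what multiplying by `copysign(1.0, Y)` computes for finite inputs. The helper names below are hypothetical.

```python
import math
import struct

SIGN_BIT = 1 << 63  # sign bit of an IEEE-754 double

def fsgnjx_bits(x: float, y: float) -> float:
    # fsgnjx.d semantics: keep x's magnitude/payload, set the sign bit
    # to sign(x) XOR sign(y). A pure bit operation, no FP arithmetic.
    xb = struct.unpack("<Q", struct.pack("<d", x))[0]
    yb = struct.unpack("<Q", struct.pack("<d", y))[0]
    return struct.unpack("<d", struct.pack("<Q", xb ^ (yb & SIGN_BIT)))[0]

def idiom(x: float, y: float) -> float:
    # The pattern this patch matches: fmul X, (fcopysign 1.0, Y).
    return x * math.copysign(1.0, y)
```

For finite `x` the two agree, including for negative-zero `y`; the Alive2 link above carries the general correctness proof.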

Contributor

@wangpc-pp wangpc-pp left a comment


I don't know whether there is an impact, but can you check that this combine works well with Zfa?

@llvmbot
Member

llvmbot commented Jul 26, 2024

@llvm/pr-subscribers-backend-risc-v

Author: Yingwei Zheng (dtcxzyw)

Changes



Full diff: https://github.com/llvm/llvm-project/pull/100718.diff

8 Files Affected:

  • (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+21-1)
  • (modified) llvm/lib/Target/RISCV/RISCVISelLowering.h (+1)
  • (modified) llvm/lib/Target/RISCV/RISCVInstrInfoD.td (+3)
  • (modified) llvm/lib/Target/RISCV/RISCVInstrInfoF.td (+8-1)
  • (modified) llvm/lib/Target/RISCV/RISCVInstrInfoZfh.td (+2)
  • (modified) llvm/test/CodeGen/RISCV/double-arith.ll (+48)
  • (modified) llvm/test/CodeGen/RISCV/float-arith.ll (+41)
  • (modified) llvm/test/CodeGen/RISCV/half-arith.ll (+109)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index d40d4997d7614..8e1990a6080e7 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -1473,7 +1473,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
     setTargetDAGCombine(ISD::SRA);
 
   if (Subtarget.hasStdExtFOrZfinx())
-    setTargetDAGCombine({ISD::FADD, ISD::FMAXNUM, ISD::FMINNUM});
+    setTargetDAGCombine({ISD::FADD, ISD::FMAXNUM, ISD::FMINNUM, ISD::FMUL});
 
   if (Subtarget.hasStdExtZbb())
     setTargetDAGCombine({ISD::UMAX, ISD::UMIN, ISD::SMAX, ISD::SMIN});
@@ -16711,6 +16711,25 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
     if (SDValue V = combineBinOpOfZExt(N, DAG))
       return V;
     break;
+  case ISD::FMUL: {
+    // fmul X, (copysign 1.0, Y) -> fsgnjx X, Y
+    SDValue N0 = N->getOperand(0);
+    SDValue N1 = N->getOperand(1);
+    if (N0->getOpcode() != ISD::FCOPYSIGN)
+      std::swap(N0, N1);
+    if (N0->getOpcode() != ISD::FCOPYSIGN)
+      return SDValue();
+    ConstantFPSDNode *C = dyn_cast<ConstantFPSDNode>(N0->getOperand(0));
+    if (!C || !C->getValueAPF().isExactlyValue(+1.0))
+      return SDValue();
+    EVT VT = N->getValueType(0);
+    if (VT.isVector() || !isOperationLegal(ISD::FCOPYSIGN, VT))
+      return SDValue();
+    SDValue Sign = N0->getOperand(1);
+    if (Sign.getValueType() != VT)
+      return SDValue();
+    return DAG.getNode(RISCVISD::FSGNJX, SDLoc(N), VT, N1, N0->getOperand(1));
+  }
   case ISD::FADD:
   case ISD::UMAX:
   case ISD::UMIN:
@@ -20261,6 +20280,7 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
   NODE_NAME_CASE(FP_EXTEND_BF16)
   NODE_NAME_CASE(FROUND)
   NODE_NAME_CASE(FCLASS)
+  NODE_NAME_CASE(FSGNJX)
   NODE_NAME_CASE(FMAX)
   NODE_NAME_CASE(FMIN)
   NODE_NAME_CASE(READ_COUNTER_WIDE)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index e469a4b1238c7..498c77f1875ed 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -128,6 +128,7 @@ enum NodeType : unsigned {
   FROUND,
 
   FCLASS,
+  FSGNJX,
 
   // Floating point fmax and fmin matching the RISC-V instruction semantics.
   FMAX, FMIN,
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoD.td b/llvm/lib/Target/RISCV/RISCVInstrInfoD.td
index 8efefee383a6a..35ab277fa3505 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoD.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoD.td
@@ -282,6 +282,7 @@ def : Pat<(fabs FPR64:$rs1), (FSGNJX_D $rs1, $rs1)>;
 def : Pat<(riscv_fclass FPR64:$rs1), (FCLASS_D $rs1)>;
 
 def : PatFprFpr<fcopysign, FSGNJ_D, FPR64, f64>;
+def : PatFprFpr<riscv_fsgnjx, FSGNJX_D, FPR64, f64>;
 def : Pat<(fcopysign FPR64:$rs1, (fneg FPR64:$rs2)), (FSGNJN_D $rs1, $rs2)>;
 def : Pat<(fcopysign FPR64:$rs1, FPR32:$rs2), (FSGNJ_D $rs1, (FCVT_D_S $rs2,
                                                               FRM_RNE))>;
@@ -318,6 +319,7 @@ def : Pat<(fabs FPR64INX:$rs1), (FSGNJX_D_INX $rs1, $rs1)>;
 def : Pat<(riscv_fclass FPR64INX:$rs1), (FCLASS_D_INX $rs1)>;
 
 def : PatFprFpr<fcopysign, FSGNJ_D_INX, FPR64INX, f64>;
+def : PatFprFpr<riscv_fsgnjx, FSGNJX_D_INX, FPR64INX, f64>;
 def : Pat<(fcopysign FPR64INX:$rs1, (fneg FPR64INX:$rs2)),
           (FSGNJN_D_INX $rs1, $rs2)>;
 def : Pat<(fcopysign FPR64INX:$rs1, FPR32INX:$rs2),
@@ -355,6 +357,7 @@ def : Pat<(fabs FPR64IN32X:$rs1), (FSGNJX_D_IN32X $rs1, $rs1)>;
 def : Pat<(riscv_fclass FPR64IN32X:$rs1), (FCLASS_D_IN32X $rs1)>;
 
 def : PatFprFpr<fcopysign, FSGNJ_D_IN32X, FPR64IN32X, f64>;
+def : PatFprFpr<riscv_fsgnjx, FSGNJX_D_IN32X, FPR64IN32X, f64>;
 def : Pat<(fcopysign FPR64IN32X:$rs1, (fneg FPR64IN32X:$rs2)),
           (FSGNJN_D_IN32X $rs1, $rs2)>;
 def : Pat<(fcopysign FPR64IN32X:$rs1, FPR32INX:$rs2),
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoF.td b/llvm/lib/Target/RISCV/RISCVInstrInfoF.td
index 7d89608de1223..e6c25e0844fb2 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoF.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoF.td
@@ -31,6 +31,8 @@ def SDT_RISCVFROUND
                            SDTCisVT<3, XLenVT>]>;
 def SDT_RISCVFCLASS
     : SDTypeProfile<1, 1, [SDTCisVT<0, XLenVT>, SDTCisFP<1>]>;
+def SDT_RISCVFSGNJX
+    : SDTypeProfile<1, 2, [SDTCisFP<0>, SDTCisSameAs<0, 1>, SDTCisSameAs<0, 2>]>;
 
 def riscv_fclass
     : SDNode<"RISCVISD::FCLASS", SDT_RISCVFCLASS>;
@@ -38,6 +40,9 @@ def riscv_fclass
 def riscv_fround
     : SDNode<"RISCVISD::FROUND", SDT_RISCVFROUND>;
 
+def riscv_fsgnjx
+    : SDNode<"RISCVISD::FSGNJX", SDT_RISCVFSGNJX>;
+
 def riscv_fmv_w_x_rv64
     : SDNode<"RISCVISD::FMV_W_X_RV64", SDT_RISCVFMV_W_X_RV64>;
 def riscv_fmv_x_anyextw_rv64
@@ -539,8 +544,10 @@ def : Pat<(fabs FPR32INX:$rs1), (FSGNJX_S_INX $rs1, $rs1)>;
 def : Pat<(riscv_fclass FPR32INX:$rs1), (FCLASS_S_INX $rs1)>;
 } // Predicates = [HasStdExtZfinx]
 
-foreach Ext = FExts in
+foreach Ext = FExts in {
 defm : PatFprFpr_m<fcopysign, FSGNJ_S, Ext>;
+defm : PatFprFpr_m<riscv_fsgnjx, FSGNJX_S, Ext>;
+}
 
 let Predicates = [HasStdExtF] in {
 def : Pat<(fcopysign FPR32:$rs1, (fneg FPR32:$rs2)), (FSGNJN_S $rs1, $rs2)>;
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoZfh.td b/llvm/lib/Target/RISCV/RISCVInstrInfoZfh.td
index e0f1c71548344..85715ca9145c3 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoZfh.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoZfh.td
@@ -272,6 +272,7 @@ def : Pat<(f16 (fabs FPR16:$rs1)), (FSGNJX_H $rs1, $rs1)>;
 def : Pat<(riscv_fclass (f16 FPR16:$rs1)), (FCLASS_H $rs1)>;
 
 def : PatFprFpr<fcopysign, FSGNJ_H, FPR16, f16>;
+def : PatFprFpr<riscv_fsgnjx, FSGNJX_H, FPR16, f16>;
 def : Pat<(f16 (fcopysign FPR16:$rs1, (f16 (fneg FPR16:$rs2)))), (FSGNJN_H $rs1, $rs2)>;
 def : Pat<(f16 (fcopysign FPR16:$rs1, FPR32:$rs2)),
           (FSGNJ_H $rs1, (FCVT_H_S $rs2, FRM_DYN))>;
@@ -314,6 +315,7 @@ def : Pat<(fabs FPR16INX:$rs1), (FSGNJX_H_INX $rs1, $rs1)>;
 def : Pat<(riscv_fclass FPR16INX:$rs1), (FCLASS_H_INX $rs1)>;
 
 def : PatFprFpr<fcopysign, FSGNJ_H_INX, FPR16INX, f16>;
+def : PatFprFpr<riscv_fsgnjx, FSGNJX_H_INX, FPR16INX, f16>;
 def : Pat<(fcopysign FPR16INX:$rs1, (fneg FPR16INX:$rs2)), (FSGNJN_H_INX $rs1, $rs2)>;
 def : Pat<(fcopysign FPR16INX:$rs1, FPR32INX:$rs2),
           (FSGNJ_H_INX $rs1, (FCVT_H_S_INX $rs2, FRM_DYN))>;
diff --git a/llvm/test/CodeGen/RISCV/double-arith.ll b/llvm/test/CodeGen/RISCV/double-arith.ll
index ced6ff66ef678..ee54501fe59a8 100644
--- a/llvm/test/CodeGen/RISCV/double-arith.ll
+++ b/llvm/test/CodeGen/RISCV/double-arith.ll
@@ -1497,3 +1497,51 @@ define double @fnmsub_d_contract(double %a, double %b, double %c) nounwind {
   %2 = fsub contract double %c, %1
   ret double %2
 }
+
+define double @fsgnjx_f64(double %x, double %y) nounwind {
+; CHECKIFD-LABEL: fsgnjx_f64:
+; CHECKIFD:       # %bb.0:
+; CHECKIFD-NEXT:    fsgnjx.d fa0, fa1, fa0
+; CHECKIFD-NEXT:    ret
+;
+; RV32IZFINXZDINX-LABEL: fsgnjx_f64:
+; RV32IZFINXZDINX:       # %bb.0:
+; RV32IZFINXZDINX-NEXT:    fsgnjx.d a0, a2, a0
+; RV32IZFINXZDINX-NEXT:    ret
+;
+; RV64IZFINXZDINX-LABEL: fsgnjx_f64:
+; RV64IZFINXZDINX:       # %bb.0:
+; RV64IZFINXZDINX-NEXT:    fsgnjx.d a0, a1, a0
+; RV64IZFINXZDINX-NEXT:    ret
+;
+; RV32I-LABEL: fsgnjx_f64:
+; RV32I:       # %bb.0:
+; RV32I-NEXT:    addi sp, sp, -16
+; RV32I-NEXT:    sw ra, 12(sp) # 4-byte Folded Spill
+; RV32I-NEXT:    lui a0, 524288
+; RV32I-NEXT:    and a0, a1, a0
+; RV32I-NEXT:    lui a1, 261888
+; RV32I-NEXT:    or a1, a0, a1
+; RV32I-NEXT:    li a0, 0
+; RV32I-NEXT:    call __muldf3
+; RV32I-NEXT:    lw ra, 12(sp) # 4-byte Folded Reload
+; RV32I-NEXT:    addi sp, sp, 16
+; RV32I-NEXT:    ret
+;
+; RV64I-LABEL: fsgnjx_f64:
+; RV64I:       # %bb.0:
+; RV64I-NEXT:    addi sp, sp, -16
+; RV64I-NEXT:    sd ra, 8(sp) # 8-byte Folded Spill
+; RV64I-NEXT:    srli a0, a0, 63
+; RV64I-NEXT:    slli a0, a0, 63
+; RV64I-NEXT:    li a2, 1023
+; RV64I-NEXT:    slli a2, a2, 52
+; RV64I-NEXT:    or a0, a0, a2
+; RV64I-NEXT:    call __muldf3
+; RV64I-NEXT:    ld ra, 8(sp) # 8-byte Folded Reload
+; RV64I-NEXT:    addi sp, sp, 16
+; RV64I-NEXT:    ret
+  %z = call double @llvm.copysign.f64(double 1.0, double %x)
+  %mul = fmul double %z, %y
+  ret double %mul
+}
diff --git a/llvm/test/CodeGen/RISCV/float-arith.ll b/llvm/test/CodeGen/RISCV/float-arith.ll
index 7a7ebe651c08e..931f73a94170a 100644
--- a/llvm/test/CodeGen/RISCV/float-arith.ll
+++ b/llvm/test/CodeGen/RISCV/float-arith.ll
@@ -1195,3 +1195,44 @@ define float @fnmsub_s_contract(float %a, float %b, float %c) nounwind {
   %2 = fsub contract float %c, %1
   ret float %2
 }
+
+define float @fsgnjx_f32(float %x, float %y) nounwind {
+; CHECKIF-LABEL: fsgnjx_f32:
+; CHECKIF:       # %bb.0:
+; CHECKIF-NEXT:    fsgnjx.s fa0, fa1, fa0
+; CHECKIF-NEXT:    ret
+;
+; CHECKIZFINX-LABEL: fsgnjx_f32:
+; CHECKIZFINX:       # %bb.0:
+; CHECKIZFINX-NEXT:    fsgnjx.s a0, a1, a0
+; CHECKIZFINX-NEXT:    ret
+;
+; RV32I-LABEL: fsgnjx_f32:
+; RV32I:       # %bb.0:
+; RV32I-NEXT:    addi sp, sp, -16
+; RV32I-NEXT:    sw ra, 12(sp) # 4-byte Folded Spill
+; RV32I-NEXT:    lui a2, 524288
+; RV32I-NEXT:    and a0, a0, a2
+; RV32I-NEXT:    lui a2, 260096
+; RV32I-NEXT:    or a0, a0, a2
+; RV32I-NEXT:    call __mulsf3
+; RV32I-NEXT:    lw ra, 12(sp) # 4-byte Folded Reload
+; RV32I-NEXT:    addi sp, sp, 16
+; RV32I-NEXT:    ret
+;
+; RV64I-LABEL: fsgnjx_f32:
+; RV64I:       # %bb.0:
+; RV64I-NEXT:    addi sp, sp, -16
+; RV64I-NEXT:    sd ra, 8(sp) # 8-byte Folded Spill
+; RV64I-NEXT:    lui a2, 524288
+; RV64I-NEXT:    and a0, a0, a2
+; RV64I-NEXT:    lui a2, 260096
+; RV64I-NEXT:    or a0, a0, a2
+; RV64I-NEXT:    call __mulsf3
+; RV64I-NEXT:    ld ra, 8(sp) # 8-byte Folded Reload
+; RV64I-NEXT:    addi sp, sp, 16
+; RV64I-NEXT:    ret
+  %z = call float @llvm.copysign.f32(float 1.0, float %x)
+  %mul = fmul float %z, %y
+  ret float %mul
+}
diff --git a/llvm/test/CodeGen/RISCV/half-arith.ll b/llvm/test/CodeGen/RISCV/half-arith.ll
index f54adaa24b1bb..10e63e3a9f748 100644
--- a/llvm/test/CodeGen/RISCV/half-arith.ll
+++ b/llvm/test/CodeGen/RISCV/half-arith.ll
@@ -3104,3 +3104,112 @@ define half @fnmsub_s_contract(half %a, half %b, half %c) nounwind {
   %2 = fsub contract half %c, %1
   ret half %2
 }
+
+define half @fsgnjx_f16(half %x, half %y) nounwind {
+; CHECKIZFH-LABEL: fsgnjx_f16:
+; CHECKIZFH:       # %bb.0:
+; CHECKIZFH-NEXT:    fsgnjx.h fa0, fa1, fa0
+; CHECKIZFH-NEXT:    ret
+;
+; CHECK-ZHINX-LABEL: fsgnjx_f16:
+; CHECK-ZHINX:       # %bb.0:
+; CHECK-ZHINX-NEXT:    fsgnjx.h a0, a1, a0
+; CHECK-ZHINX-NEXT:    ret
+;
+; RV32I-LABEL: fsgnjx_f16:
+; RV32I:       # %bb.0:
+; RV32I-NEXT:    addi sp, sp, -16
+; RV32I-NEXT:    sw ra, 12(sp) # 4-byte Folded Spill
+; RV32I-NEXT:    sw s0, 8(sp) # 4-byte Folded Spill
+; RV32I-NEXT:    sw s1, 4(sp) # 4-byte Folded Spill
+; RV32I-NEXT:    li a2, 15
+; RV32I-NEXT:    slli a2, a2, 10
+; RV32I-NEXT:    or s1, a0, a2
+; RV32I-NEXT:    slli a0, a1, 16
+; RV32I-NEXT:    srli a0, a0, 16
+; RV32I-NEXT:    call __extendhfsf2
+; RV32I-NEXT:    mv s0, a0
+; RV32I-NEXT:    lui a0, 12
+; RV32I-NEXT:    addi a0, a0, -1024
+; RV32I-NEXT:    and a0, s1, a0
+; RV32I-NEXT:    call __extendhfsf2
+; RV32I-NEXT:    mv a1, s0
+; RV32I-NEXT:    call __mulsf3
+; RV32I-NEXT:    call __truncsfhf2
+; RV32I-NEXT:    lw ra, 12(sp) # 4-byte Folded Reload
+; RV32I-NEXT:    lw s0, 8(sp) # 4-byte Folded Reload
+; RV32I-NEXT:    lw s1, 4(sp) # 4-byte Folded Reload
+; RV32I-NEXT:    addi sp, sp, 16
+; RV32I-NEXT:    ret
+;
+; RV64I-LABEL: fsgnjx_f16:
+; RV64I:       # %bb.0:
+; RV64I-NEXT:    addi sp, sp, -32
+; RV64I-NEXT:    sd ra, 24(sp) # 8-byte Folded Spill
+; RV64I-NEXT:    sd s0, 16(sp) # 8-byte Folded Spill
+; RV64I-NEXT:    sd s1, 8(sp) # 8-byte Folded Spill
+; RV64I-NEXT:    li a2, 15
+; RV64I-NEXT:    slli a2, a2, 10
+; RV64I-NEXT:    or s1, a0, a2
+; RV64I-NEXT:    slli a0, a1, 48
+; RV64I-NEXT:    srli a0, a0, 48
+; RV64I-NEXT:    call __extendhfsf2
+; RV64I-NEXT:    mv s0, a0
+; RV64I-NEXT:    lui a0, 12
+; RV64I-NEXT:    addiw a0, a0, -1024
+; RV64I-NEXT:    and a0, s1, a0
+; RV64I-NEXT:    call __extendhfsf2
+; RV64I-NEXT:    mv a1, s0
+; RV64I-NEXT:    call __mulsf3
+; RV64I-NEXT:    call __truncsfhf2
+; RV64I-NEXT:    ld ra, 24(sp) # 8-byte Folded Reload
+; RV64I-NEXT:    ld s0, 16(sp) # 8-byte Folded Reload
+; RV64I-NEXT:    ld s1, 8(sp) # 8-byte Folded Reload
+; RV64I-NEXT:    addi sp, sp, 32
+; RV64I-NEXT:    ret
+;
+; CHECK-RV32-FSGNJ-LABEL: fsgnjx_f16:
+; CHECK-RV32-FSGNJ:       # %bb.0:
+; CHECK-RV32-FSGNJ-NEXT:    addi sp, sp, -16
+; CHECK-RV32-FSGNJ-NEXT:    lui a0, %hi(.LCPI23_0)
+; CHECK-RV32-FSGNJ-NEXT:    flh fa5, %lo(.LCPI23_0)(a0)
+; CHECK-RV32-FSGNJ-NEXT:    fsh fa0, 12(sp)
+; CHECK-RV32-FSGNJ-NEXT:    fsh fa5, 8(sp)
+; CHECK-RV32-FSGNJ-NEXT:    lbu a0, 13(sp)
+; CHECK-RV32-FSGNJ-NEXT:    lbu a1, 9(sp)
+; CHECK-RV32-FSGNJ-NEXT:    andi a0, a0, 128
+; CHECK-RV32-FSGNJ-NEXT:    andi a1, a1, 127
+; CHECK-RV32-FSGNJ-NEXT:    or a0, a1, a0
+; CHECK-RV32-FSGNJ-NEXT:    sb a0, 9(sp)
+; CHECK-RV32-FSGNJ-NEXT:    flh fa5, 8(sp)
+; CHECK-RV32-FSGNJ-NEXT:    fcvt.s.h fa4, fa1
+; CHECK-RV32-FSGNJ-NEXT:    fcvt.s.h fa5, fa5
+; CHECK-RV32-FSGNJ-NEXT:    fmul.s fa5, fa5, fa4
+; CHECK-RV32-FSGNJ-NEXT:    fcvt.h.s fa0, fa5
+; CHECK-RV32-FSGNJ-NEXT:    addi sp, sp, 16
+; CHECK-RV32-FSGNJ-NEXT:    ret
+;
+; CHECK-RV64-FSGNJ-LABEL: fsgnjx_f16:
+; CHECK-RV64-FSGNJ:       # %bb.0:
+; CHECK-RV64-FSGNJ-NEXT:    addi sp, sp, -16
+; CHECK-RV64-FSGNJ-NEXT:    lui a0, %hi(.LCPI23_0)
+; CHECK-RV64-FSGNJ-NEXT:    flh fa5, %lo(.LCPI23_0)(a0)
+; CHECK-RV64-FSGNJ-NEXT:    fsh fa0, 8(sp)
+; CHECK-RV64-FSGNJ-NEXT:    fsh fa5, 0(sp)
+; CHECK-RV64-FSGNJ-NEXT:    lbu a0, 9(sp)
+; CHECK-RV64-FSGNJ-NEXT:    lbu a1, 1(sp)
+; CHECK-RV64-FSGNJ-NEXT:    andi a0, a0, 128
+; CHECK-RV64-FSGNJ-NEXT:    andi a1, a1, 127
+; CHECK-RV64-FSGNJ-NEXT:    or a0, a1, a0
+; CHECK-RV64-FSGNJ-NEXT:    sb a0, 1(sp)
+; CHECK-RV64-FSGNJ-NEXT:    flh fa5, 0(sp)
+; CHECK-RV64-FSGNJ-NEXT:    fcvt.s.h fa4, fa1
+; CHECK-RV64-FSGNJ-NEXT:    fcvt.s.h fa5, fa5
+; CHECK-RV64-FSGNJ-NEXT:    fmul.s fa5, fa5, fa4
+; CHECK-RV64-FSGNJ-NEXT:    fcvt.h.s fa0, fa5
+; CHECK-RV64-FSGNJ-NEXT:    addi sp, sp, 16
+; CHECK-RV64-FSGNJ-NEXT:    ret
+  %z = call half @llvm.copysign.f16(half 1.0, half %x)
+  %mul = fmul half %z, %y
+  ret half %mul
+}

Collaborator

@topperc topperc left a comment


LGTM

@preames
Collaborator

preames commented Jul 26, 2024

Any reason this can't be done by e.g. instcombine?

@topperc
Collaborator

topperc commented Jul 26, 2024

> Any reason this can't be done by e.g. instcombine?

What IR could InstCombine generate?

@preames
Collaborator

preames commented Jul 26, 2024

> Any reason this can't be done by e.g. instcombine?
>
> What IR could InstCombine generate?

We have a llvm.copysign, so wouldn't this just be (copysign X, Y)? Or is there some corner case here I'm missing?

@topperc
Collaborator

topperc commented Jul 26, 2024

> Any reason this can't be done by e.g. instcombine?
>
> What IR could InstCombine generate?
>
> We have a llvm.copysign, so wouldn't this just be (copysign X, Y)? Or is there some corner case here I'm missing?

`fmul X, (fcopysign 1.0, Y)` multiplies X by 1.0 if Y is positive and by -1.0 if Y is negative. This has the effect of toggling X's sign bit if Y is negative and leaving it alone otherwise. So it's not a copy.
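
The distinction can be seen concretely in a small sketch (hypothetical helpers, not code from the patch): copysign imposes Y's sign, while the idiom XORs it in.

```python
import math

def via_copysign(x: float, y: float) -> float:
    # (copysign X, Y): the result always takes Y's sign outright.
    return math.copysign(x, y)

def via_idiom(x: float, y: float) -> float:
    # fmul X, (fcopysign 1.0, Y): toggles X's sign when Y is negative
    # and leaves it unchanged when Y is positive.
    return x * math.copysign(1.0, y)
```

With `x = -3.0` and `y = -2.0`, `via_copysign` returns -3.0 but `via_idiom` returns 3.0, so the two operations are not interchangeable.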

@dtcxzyw dtcxzyw merged commit 1399637 into llvm:main Jul 27, 2024
9 checks passed
@dtcxzyw dtcxzyw deleted the riscv-fsgnjx-idiom branch July 27, 2024 04:52
Successfully merging this pull request may close these issues.

Grep FSGNJX