Skip to content

[RISCV][ISel] Fold FSGNJX idioms #100718

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1473,7 +1473,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
setTargetDAGCombine(ISD::SRA);

if (Subtarget.hasStdExtFOrZfinx())
setTargetDAGCombine({ISD::FADD, ISD::FMAXNUM, ISD::FMINNUM});
setTargetDAGCombine({ISD::FADD, ISD::FMAXNUM, ISD::FMINNUM, ISD::FMUL});

if (Subtarget.hasStdExtZbb())
setTargetDAGCombine({ISD::UMAX, ISD::UMIN, ISD::SMAX, ISD::SMIN});
Expand Down Expand Up @@ -16711,6 +16711,25 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
if (SDValue V = combineBinOpOfZExt(N, DAG))
return V;
break;
case ISD::FMUL: {
// fmul X, (copysign 1.0, Y) -> fsgnjx X, Y
SDValue N0 = N->getOperand(0);
SDValue N1 = N->getOperand(1);
if (N0->getOpcode() != ISD::FCOPYSIGN)
std::swap(N0, N1);
if (N0->getOpcode() != ISD::FCOPYSIGN)
return SDValue();
ConstantFPSDNode *C = dyn_cast<ConstantFPSDNode>(N0->getOperand(0));
if (!C || !C->getValueAPF().isExactlyValue(+1.0))
return SDValue();
EVT VT = N->getValueType(0);
if (VT.isVector() || !isOperationLegal(ISD::FCOPYSIGN, VT))
return SDValue();
SDValue Sign = N0->getOperand(1);
if (Sign.getValueType() != VT)
return SDValue();
return DAG.getNode(RISCVISD::FSGNJX, SDLoc(N), VT, N1, N0->getOperand(1));
}
case ISD::FADD:
case ISD::UMAX:
case ISD::UMIN:
Expand Down Expand Up @@ -20261,6 +20280,7 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
NODE_NAME_CASE(FP_EXTEND_BF16)
NODE_NAME_CASE(FROUND)
NODE_NAME_CASE(FCLASS)
NODE_NAME_CASE(FSGNJX)
NODE_NAME_CASE(FMAX)
NODE_NAME_CASE(FMIN)
NODE_NAME_CASE(READ_COUNTER_WIDE)
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Target/RISCV/RISCVISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ enum NodeType : unsigned {
FROUND,

FCLASS,
FSGNJX,

// Floating point fmax and fmin matching the RISC-V instruction semantics.
FMAX, FMIN,
Expand Down
3 changes: 3 additions & 0 deletions llvm/lib/Target/RISCV/RISCVInstrInfoD.td
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,7 @@ def : Pat<(fabs FPR64:$rs1), (FSGNJX_D $rs1, $rs1)>;
def : Pat<(riscv_fclass FPR64:$rs1), (FCLASS_D $rs1)>;

def : PatFprFpr<fcopysign, FSGNJ_D, FPR64, f64>;
def : PatFprFpr<riscv_fsgnjx, FSGNJX_D, FPR64, f64>;
def : Pat<(fcopysign FPR64:$rs1, (fneg FPR64:$rs2)), (FSGNJN_D $rs1, $rs2)>;
def : Pat<(fcopysign FPR64:$rs1, FPR32:$rs2), (FSGNJ_D $rs1, (FCVT_D_S $rs2,
FRM_RNE))>;
Expand Down Expand Up @@ -318,6 +319,7 @@ def : Pat<(fabs FPR64INX:$rs1), (FSGNJX_D_INX $rs1, $rs1)>;
def : Pat<(riscv_fclass FPR64INX:$rs1), (FCLASS_D_INX $rs1)>;

def : PatFprFpr<fcopysign, FSGNJ_D_INX, FPR64INX, f64>;
def : PatFprFpr<riscv_fsgnjx, FSGNJX_D_INX, FPR64INX, f64>;
def : Pat<(fcopysign FPR64INX:$rs1, (fneg FPR64INX:$rs2)),
(FSGNJN_D_INX $rs1, $rs2)>;
def : Pat<(fcopysign FPR64INX:$rs1, FPR32INX:$rs2),
Expand Down Expand Up @@ -355,6 +357,7 @@ def : Pat<(fabs FPR64IN32X:$rs1), (FSGNJX_D_IN32X $rs1, $rs1)>;
def : Pat<(riscv_fclass FPR64IN32X:$rs1), (FCLASS_D_IN32X $rs1)>;

def : PatFprFpr<fcopysign, FSGNJ_D_IN32X, FPR64IN32X, f64>;
def : PatFprFpr<riscv_fsgnjx, FSGNJX_D_IN32X, FPR64IN32X, f64>;
def : Pat<(fcopysign FPR64IN32X:$rs1, (fneg FPR64IN32X:$rs2)),
(FSGNJN_D_IN32X $rs1, $rs2)>;
def : Pat<(fcopysign FPR64IN32X:$rs1, FPR32INX:$rs2),
Expand Down
9 changes: 8 additions & 1 deletion llvm/lib/Target/RISCV/RISCVInstrInfoF.td
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,18 @@ def SDT_RISCVFROUND
SDTCisVT<3, XLenVT>]>;
def SDT_RISCVFCLASS
: SDTypeProfile<1, 1, [SDTCisVT<0, XLenVT>, SDTCisFP<1>]>;
def SDT_RISCVFSGNJX
: SDTypeProfile<1, 2, [SDTCisFP<0>, SDTCisSameAs<0, 1>, SDTCisSameAs<0, 2>]>;

def riscv_fclass
: SDNode<"RISCVISD::FCLASS", SDT_RISCVFCLASS>;

def riscv_fround
: SDNode<"RISCVISD::FROUND", SDT_RISCVFROUND>;

def riscv_fsgnjx
: SDNode<"RISCVISD::FSGNJX", SDT_RISCVFSGNJX>;

def riscv_fmv_w_x_rv64
: SDNode<"RISCVISD::FMV_W_X_RV64", SDT_RISCVFMV_W_X_RV64>;
def riscv_fmv_x_anyextw_rv64
Expand Down Expand Up @@ -539,8 +544,10 @@ def : Pat<(fabs FPR32INX:$rs1), (FSGNJX_S_INX $rs1, $rs1)>;
def : Pat<(riscv_fclass FPR32INX:$rs1), (FCLASS_S_INX $rs1)>;
} // Predicates = [HasStdExtZfinx]

foreach Ext = FExts in
foreach Ext = FExts in {
defm : PatFprFpr_m<fcopysign, FSGNJ_S, Ext>;
defm : PatFprFpr_m<riscv_fsgnjx, FSGNJX_S, Ext>;
}

let Predicates = [HasStdExtF] in {
def : Pat<(fcopysign FPR32:$rs1, (fneg FPR32:$rs2)), (FSGNJN_S $rs1, $rs2)>;
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/Target/RISCV/RISCVInstrInfoZfh.td
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,7 @@ def : Pat<(f16 (fabs FPR16:$rs1)), (FSGNJX_H $rs1, $rs1)>;
def : Pat<(riscv_fclass (f16 FPR16:$rs1)), (FCLASS_H $rs1)>;

def : PatFprFpr<fcopysign, FSGNJ_H, FPR16, f16>;
def : PatFprFpr<riscv_fsgnjx, FSGNJX_H, FPR16, f16>;
def : Pat<(f16 (fcopysign FPR16:$rs1, (f16 (fneg FPR16:$rs2)))), (FSGNJN_H $rs1, $rs2)>;
def : Pat<(f16 (fcopysign FPR16:$rs1, FPR32:$rs2)),
(FSGNJ_H $rs1, (FCVT_H_S $rs2, FRM_DYN))>;
Expand Down Expand Up @@ -314,6 +315,7 @@ def : Pat<(fabs FPR16INX:$rs1), (FSGNJX_H_INX $rs1, $rs1)>;
def : Pat<(riscv_fclass FPR16INX:$rs1), (FCLASS_H_INX $rs1)>;

def : PatFprFpr<fcopysign, FSGNJ_H_INX, FPR16INX, f16>;
def : PatFprFpr<riscv_fsgnjx, FSGNJX_H_INX, FPR16INX, f16>;
def : Pat<(fcopysign FPR16INX:$rs1, (fneg FPR16INX:$rs2)), (FSGNJN_H_INX $rs1, $rs2)>;
def : Pat<(fcopysign FPR16INX:$rs1, FPR32INX:$rs2),
(FSGNJ_H_INX $rs1, (FCVT_H_S_INX $rs2, FRM_DYN))>;
Expand Down
48 changes: 48 additions & 0 deletions llvm/test/CodeGen/RISCV/double-arith.ll
Original file line number Diff line number Diff line change
Expand Up @@ -1497,3 +1497,51 @@ define double @fnmsub_d_contract(double %a, double %b, double %c) nounwind {
%2 = fsub contract double %c, %1
ret double %2
}

define double @fsgnjx_f64(double %x, double %y) nounwind {
; CHECKIFD-LABEL: fsgnjx_f64:
; CHECKIFD: # %bb.0:
; CHECKIFD-NEXT: fsgnjx.d fa0, fa1, fa0
; CHECKIFD-NEXT: ret
;
; RV32IZFINXZDINX-LABEL: fsgnjx_f64:
; RV32IZFINXZDINX: # %bb.0:
; RV32IZFINXZDINX-NEXT: fsgnjx.d a0, a2, a0
; RV32IZFINXZDINX-NEXT: ret
;
; RV64IZFINXZDINX-LABEL: fsgnjx_f64:
; RV64IZFINXZDINX: # %bb.0:
; RV64IZFINXZDINX-NEXT: fsgnjx.d a0, a1, a0
; RV64IZFINXZDINX-NEXT: ret
;
; RV32I-LABEL: fsgnjx_f64:
; RV32I: # %bb.0:
; RV32I-NEXT: addi sp, sp, -16
; RV32I-NEXT: sw ra, 12(sp) # 4-byte Folded Spill
; RV32I-NEXT: lui a0, 524288
; RV32I-NEXT: and a0, a1, a0
; RV32I-NEXT: lui a1, 261888
; RV32I-NEXT: or a1, a0, a1
; RV32I-NEXT: li a0, 0
; RV32I-NEXT: call __muldf3
; RV32I-NEXT: lw ra, 12(sp) # 4-byte Folded Reload
; RV32I-NEXT: addi sp, sp, 16
; RV32I-NEXT: ret
;
; RV64I-LABEL: fsgnjx_f64:
; RV64I: # %bb.0:
; RV64I-NEXT: addi sp, sp, -16
; RV64I-NEXT: sd ra, 8(sp) # 8-byte Folded Spill
; RV64I-NEXT: srli a0, a0, 63
; RV64I-NEXT: slli a0, a0, 63
; RV64I-NEXT: li a2, 1023
; RV64I-NEXT: slli a2, a2, 52
; RV64I-NEXT: or a0, a0, a2
; RV64I-NEXT: call __muldf3
; RV64I-NEXT: ld ra, 8(sp) # 8-byte Folded Reload
; RV64I-NEXT: addi sp, sp, 16
; RV64I-NEXT: ret
%z = call double @llvm.copysign.f64(double 1.0, double %x)
%mul = fmul double %z, %y
ret double %mul
}
41 changes: 41 additions & 0 deletions llvm/test/CodeGen/RISCV/float-arith.ll
Original file line number Diff line number Diff line change
Expand Up @@ -1195,3 +1195,44 @@ define float @fnmsub_s_contract(float %a, float %b, float %c) nounwind {
%2 = fsub contract float %c, %1
ret float %2
}

define float @fsgnjx_f32(float %x, float %y) nounwind {
; CHECKIF-LABEL: fsgnjx_f32:
; CHECKIF: # %bb.0:
; CHECKIF-NEXT: fsgnjx.s fa0, fa1, fa0
; CHECKIF-NEXT: ret
;
; CHECKIZFINX-LABEL: fsgnjx_f32:
; CHECKIZFINX: # %bb.0:
; CHECKIZFINX-NEXT: fsgnjx.s a0, a1, a0
; CHECKIZFINX-NEXT: ret
;
; RV32I-LABEL: fsgnjx_f32:
; RV32I: # %bb.0:
; RV32I-NEXT: addi sp, sp, -16
; RV32I-NEXT: sw ra, 12(sp) # 4-byte Folded Spill
; RV32I-NEXT: lui a2, 524288
; RV32I-NEXT: and a0, a0, a2
; RV32I-NEXT: lui a2, 260096
; RV32I-NEXT: or a0, a0, a2
; RV32I-NEXT: call __mulsf3
; RV32I-NEXT: lw ra, 12(sp) # 4-byte Folded Reload
; RV32I-NEXT: addi sp, sp, 16
; RV32I-NEXT: ret
;
; RV64I-LABEL: fsgnjx_f32:
; RV64I: # %bb.0:
; RV64I-NEXT: addi sp, sp, -16
; RV64I-NEXT: sd ra, 8(sp) # 8-byte Folded Spill
; RV64I-NEXT: lui a2, 524288
; RV64I-NEXT: and a0, a0, a2
; RV64I-NEXT: lui a2, 260096
; RV64I-NEXT: or a0, a0, a2
; RV64I-NEXT: call __mulsf3
; RV64I-NEXT: ld ra, 8(sp) # 8-byte Folded Reload
; RV64I-NEXT: addi sp, sp, 16
; RV64I-NEXT: ret
%z = call float @llvm.copysign.f32(float 1.0, float %x)
%mul = fmul float %z, %y
ret float %mul
}
109 changes: 109 additions & 0 deletions llvm/test/CodeGen/RISCV/half-arith.ll
Original file line number Diff line number Diff line change
Expand Up @@ -3104,3 +3104,112 @@ define half @fnmsub_s_contract(half %a, half %b, half %c) nounwind {
%2 = fsub contract half %c, %1
ret half %2
}

define half @fsgnjx_f16(half %x, half %y) nounwind {
; CHECKIZFH-LABEL: fsgnjx_f16:
; CHECKIZFH: # %bb.0:
; CHECKIZFH-NEXT: fsgnjx.h fa0, fa1, fa0
; CHECKIZFH-NEXT: ret
;
; CHECK-ZHINX-LABEL: fsgnjx_f16:
; CHECK-ZHINX: # %bb.0:
; CHECK-ZHINX-NEXT: fsgnjx.h a0, a1, a0
; CHECK-ZHINX-NEXT: ret
;
; RV32I-LABEL: fsgnjx_f16:
; RV32I: # %bb.0:
; RV32I-NEXT: addi sp, sp, -16
; RV32I-NEXT: sw ra, 12(sp) # 4-byte Folded Spill
; RV32I-NEXT: sw s0, 8(sp) # 4-byte Folded Spill
; RV32I-NEXT: sw s1, 4(sp) # 4-byte Folded Spill
; RV32I-NEXT: li a2, 15
; RV32I-NEXT: slli a2, a2, 10
; RV32I-NEXT: or s1, a0, a2
; RV32I-NEXT: slli a0, a1, 16
; RV32I-NEXT: srli a0, a0, 16
; RV32I-NEXT: call __extendhfsf2
; RV32I-NEXT: mv s0, a0
; RV32I-NEXT: lui a0, 12
; RV32I-NEXT: addi a0, a0, -1024
; RV32I-NEXT: and a0, s1, a0
; RV32I-NEXT: call __extendhfsf2
; RV32I-NEXT: mv a1, s0
; RV32I-NEXT: call __mulsf3
; RV32I-NEXT: call __truncsfhf2
; RV32I-NEXT: lw ra, 12(sp) # 4-byte Folded Reload
; RV32I-NEXT: lw s0, 8(sp) # 4-byte Folded Reload
; RV32I-NEXT: lw s1, 4(sp) # 4-byte Folded Reload
; RV32I-NEXT: addi sp, sp, 16
; RV32I-NEXT: ret
;
; RV64I-LABEL: fsgnjx_f16:
; RV64I: # %bb.0:
; RV64I-NEXT: addi sp, sp, -32
; RV64I-NEXT: sd ra, 24(sp) # 8-byte Folded Spill
; RV64I-NEXT: sd s0, 16(sp) # 8-byte Folded Spill
; RV64I-NEXT: sd s1, 8(sp) # 8-byte Folded Spill
; RV64I-NEXT: li a2, 15
; RV64I-NEXT: slli a2, a2, 10
; RV64I-NEXT: or s1, a0, a2
; RV64I-NEXT: slli a0, a1, 48
; RV64I-NEXT: srli a0, a0, 48
; RV64I-NEXT: call __extendhfsf2
; RV64I-NEXT: mv s0, a0
; RV64I-NEXT: lui a0, 12
; RV64I-NEXT: addiw a0, a0, -1024
; RV64I-NEXT: and a0, s1, a0
; RV64I-NEXT: call __extendhfsf2
; RV64I-NEXT: mv a1, s0
; RV64I-NEXT: call __mulsf3
; RV64I-NEXT: call __truncsfhf2
; RV64I-NEXT: ld ra, 24(sp) # 8-byte Folded Reload
; RV64I-NEXT: ld s0, 16(sp) # 8-byte Folded Reload
; RV64I-NEXT: ld s1, 8(sp) # 8-byte Folded Reload
; RV64I-NEXT: addi sp, sp, 32
; RV64I-NEXT: ret
;
; CHECK-RV32-FSGNJ-LABEL: fsgnjx_f16:
; CHECK-RV32-FSGNJ: # %bb.0:
; CHECK-RV32-FSGNJ-NEXT: addi sp, sp, -16
; CHECK-RV32-FSGNJ-NEXT: lui a0, %hi(.LCPI23_0)
; CHECK-RV32-FSGNJ-NEXT: flh fa5, %lo(.LCPI23_0)(a0)
; CHECK-RV32-FSGNJ-NEXT: fsh fa0, 12(sp)
; CHECK-RV32-FSGNJ-NEXT: fsh fa5, 8(sp)
; CHECK-RV32-FSGNJ-NEXT: lbu a0, 13(sp)
; CHECK-RV32-FSGNJ-NEXT: lbu a1, 9(sp)
; CHECK-RV32-FSGNJ-NEXT: andi a0, a0, 128
; CHECK-RV32-FSGNJ-NEXT: andi a1, a1, 127
; CHECK-RV32-FSGNJ-NEXT: or a0, a1, a0
; CHECK-RV32-FSGNJ-NEXT: sb a0, 9(sp)
; CHECK-RV32-FSGNJ-NEXT: flh fa5, 8(sp)
; CHECK-RV32-FSGNJ-NEXT: fcvt.s.h fa4, fa1
; CHECK-RV32-FSGNJ-NEXT: fcvt.s.h fa5, fa5
; CHECK-RV32-FSGNJ-NEXT: fmul.s fa5, fa5, fa4
; CHECK-RV32-FSGNJ-NEXT: fcvt.h.s fa0, fa5
; CHECK-RV32-FSGNJ-NEXT: addi sp, sp, 16
; CHECK-RV32-FSGNJ-NEXT: ret
;
; CHECK-RV64-FSGNJ-LABEL: fsgnjx_f16:
; CHECK-RV64-FSGNJ: # %bb.0:
; CHECK-RV64-FSGNJ-NEXT: addi sp, sp, -16
; CHECK-RV64-FSGNJ-NEXT: lui a0, %hi(.LCPI23_0)
; CHECK-RV64-FSGNJ-NEXT: flh fa5, %lo(.LCPI23_0)(a0)
; CHECK-RV64-FSGNJ-NEXT: fsh fa0, 8(sp)
; CHECK-RV64-FSGNJ-NEXT: fsh fa5, 0(sp)
; CHECK-RV64-FSGNJ-NEXT: lbu a0, 9(sp)
; CHECK-RV64-FSGNJ-NEXT: lbu a1, 1(sp)
; CHECK-RV64-FSGNJ-NEXT: andi a0, a0, 128
; CHECK-RV64-FSGNJ-NEXT: andi a1, a1, 127
; CHECK-RV64-FSGNJ-NEXT: or a0, a1, a0
; CHECK-RV64-FSGNJ-NEXT: sb a0, 1(sp)
; CHECK-RV64-FSGNJ-NEXT: flh fa5, 0(sp)
; CHECK-RV64-FSGNJ-NEXT: fcvt.s.h fa4, fa1
; CHECK-RV64-FSGNJ-NEXT: fcvt.s.h fa5, fa5
; CHECK-RV64-FSGNJ-NEXT: fmul.s fa5, fa5, fa4
; CHECK-RV64-FSGNJ-NEXT: fcvt.h.s fa0, fa5
; CHECK-RV64-FSGNJ-NEXT: addi sp, sp, 16
; CHECK-RV64-FSGNJ-NEXT: ret
%z = call half @llvm.copysign.f16(half 1.0, half %x)
%mul = fmul half %z, %y
ret half %mul
}
Loading