Skip to content

Reland "[NVPTX] Support copysign PTX instruction" #108125

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Sep 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 22 additions & 2 deletions llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -838,8 +838,8 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
setOperationAction(ISD::FCOPYSIGN, MVT::v2f16, Expand);
setOperationAction(ISD::FCOPYSIGN, MVT::bf16, Expand);
setOperationAction(ISD::FCOPYSIGN, MVT::v2bf16, Expand);
setOperationAction(ISD::FCOPYSIGN, MVT::f32, Expand);
setOperationAction(ISD::FCOPYSIGN, MVT::f64, Expand);
setOperationAction(ISD::FCOPYSIGN, MVT::f32, Custom);
setOperationAction(ISD::FCOPYSIGN, MVT::f64, Custom);

// These map to corresponding instructions for f32/f64. f16 must be
// promoted to f32. v2f16 is expanded to f16, which is then promoted
Expand Down Expand Up @@ -964,6 +964,7 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
MAKE_CASE(NVPTXISD::BFE)
MAKE_CASE(NVPTXISD::BFI)
MAKE_CASE(NVPTXISD::PRMT)
MAKE_CASE(NVPTXISD::FCOPYSIGN)
MAKE_CASE(NVPTXISD::DYNAMIC_STACKALLOC)
MAKE_CASE(NVPTXISD::SETP_F16X2)
MAKE_CASE(NVPTXISD::SETP_BF16X2)
Expand Down Expand Up @@ -2560,6 +2561,23 @@ SDValue NVPTXTargetLowering::LowerShiftLeftParts(SDValue Op,
}
}

/// Custom lowering for ISD::FCOPYSIGN.
///
/// When the second operand has the same bit width as the result, the generic
/// node is rewritten to NVPTXISD::FCOPYSIGN. Otherwise an empty SDValue is
/// returned so that mismatched cases are properly expanded instead.
SDValue NVPTXTargetLowering::LowerFCOPYSIGN(SDValue Op,
                                            SelectionDAG &DAG) const {
  const EVT ResultVT = Op.getValueType();
  const SDValue Magnitude = Op.getOperand(0);
  const SDValue Sign = Op.getOperand(1);

  // Only the same-width form is turned into the target node.
  if (Sign.getValueType().bitsEq(ResultVT))
    return DAG.getNode(NVPTXISD::FCOPYSIGN, SDLoc(Op), ResultVT, Magnitude,
                       Sign);

  // Width mismatch: bail out with an empty value.
  return SDValue();
}

SDValue NVPTXTargetLowering::LowerFROUND(SDValue Op, SelectionDAG &DAG) const {
EVT VT = Op.getValueType();

Expand Down Expand Up @@ -2803,6 +2821,8 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
return LowerSelect(Op, DAG);
case ISD::FROUND:
return LowerFROUND(Op, DAG);
case ISD::FCOPYSIGN:
return LowerFCOPYSIGN(Op, DAG);
case ISD::SINT_TO_FP:
case ISD::UINT_TO_FP:
return LowerINT_TO_FP(Op, DAG);
Expand Down
3 changes: 3 additions & 0 deletions llvm/lib/Target/NVPTX/NVPTXISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ enum NodeType : unsigned {
BFE,
BFI,
PRMT,
FCOPYSIGN,
DYNAMIC_STACKALLOC,
BrxStart,
BrxItem,
Expand Down Expand Up @@ -623,6 +624,8 @@ class NVPTXTargetLowering : public TargetLowering {
SDValue LowerINSERT_VECTOR_ELT(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerVECTOR_SHUFFLE(SDValue Op, SelectionDAG &DAG) const;

SDValue LowerFCOPYSIGN(SDValue Op, SelectionDAG &DAG) const;

SDValue LowerFROUND(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerFROUND32(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerFROUND64(SDValue Op, SelectionDAG &DAG) const;
Expand Down
16 changes: 16 additions & 0 deletions llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
Original file line number Diff line number Diff line change
Expand Up @@ -977,6 +977,22 @@ def INT_NVVM_FABS_F : F_MATH_1<"abs.f32 \t$dst, $src0;", Float32Regs,
def INT_NVVM_FABS_D : F_MATH_1<"abs.f64 \t$dst, $src0;", Float64Regs,
Float64Regs, int_nvvm_fabs_d>;

//
// copysign
//

// Target-specific copysign node. LowerFCOPYSIGN creates it only when the
// second operand has the same bit width as the result.
def fcopysign_nvptx : SDNode<"NVPTXISD::FCOPYSIGN", SDTFPBinOp>;

// NOTE: the operands are deliberately swapped between the pattern and the
// assembly string. The node's first operand is bound to $src1 and its second
// operand to $src0, so the instruction is printed with the node's second
// operand first.
def COPYSIGN_F :
NVPTXInst<(outs Float32Regs:$dst), (ins Float32Regs:$src0, Float32Regs:$src1),
"copysign.f32 \t$dst, $src0, $src1;",
[(set Float32Regs:$dst, (fcopysign_nvptx Float32Regs:$src1, Float32Regs:$src0))]>;

// f64 variant; same operand swap as COPYSIGN_F.
def COPYSIGN_D :
NVPTXInst<(outs Float64Regs:$dst), (ins Float64Regs:$src0, Float64Regs:$src1),
"copysign.f64 \t$dst, $src0, $src1;",
[(set Float64Regs:$dst, (fcopysign_nvptx Float64Regs:$src1, Float64Regs:$src0))]>;

//
// Abs, Neg bf16, bf16x2
//
Expand Down
132 changes: 132 additions & 0 deletions llvm/test/CodeGen/NVPTX/copysign.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
; RUN: llc < %s -march=nvptx64 -mcpu=sm_20 -verify-machineinstrs | FileCheck %s
; RUN: %if ptxas %{ llc < %s -march=nvptx64 -mcpu=sm_20 -verify-machineinstrs | %ptxas-verify %}

target triple = "nvptx64-nvidia-cuda"
target datalayout = "e-p:32:32:32-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64"

; Both operands are float, so a single copysign.f32 is expected. In the emitted
; instruction the second IR argument (%b, param_1) is printed before the first
; (%a, param_0).
define float @fcopysign_f_f(float %a, float %b) {
; CHECK-LABEL: fcopysign_f_f(
; CHECK: {
; CHECK-NEXT: .reg .f32 %f<4>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.f32 %f1, [fcopysign_f_f_param_0];
; CHECK-NEXT: ld.param.f32 %f2, [fcopysign_f_f_param_1];
; CHECK-NEXT: copysign.f32 %f3, %f2, %f1;
; CHECK-NEXT: st.param.f32 [func_retval0+0], %f3;
; CHECK-NEXT: ret;
%val = call float @llvm.copysign.f32(float %a, float %b)
ret float %val
}

; Both operands are double, so a single copysign.f64 is expected. In the
; emitted instruction the second IR argument (%b, param_1) is printed before
; the first (%a, param_0).
define double @fcopysign_d_d(double %a, double %b) {
; CHECK-LABEL: fcopysign_d_d(
; CHECK: {
; CHECK-NEXT: .reg .f64 %fd<4>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.f64 %fd1, [fcopysign_d_d_param_0];
; CHECK-NEXT: ld.param.f64 %fd2, [fcopysign_d_d_param_1];
; CHECK-NEXT: copysign.f64 %fd3, %fd2, %fd1;
; CHECK-NEXT: st.param.f64 [func_retval0+0], %fd3;
; CHECK-NEXT: ret;
%val = call double @llvm.copysign.f64(double %a, double %b)
ret double %val
}

; The sign source is a double truncated to float. No copysign.f32 is expected
; here: the checks read the sign from bit 63 of the 64-bit parameter and use it
; to select between abs(%a) and its negation.
define float @fcopysign_f_d(float %a, double %b) {
; CHECK-LABEL: fcopysign_f_d(
; CHECK: {
; CHECK-NEXT: .reg .pred %p<2>;
; CHECK-NEXT: .reg .f32 %f<5>;
; CHECK-NEXT: .reg .b64 %rd<4>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.f32 %f1, [fcopysign_f_d_param_0];
; CHECK-NEXT: abs.f32 %f2, %f1;
; CHECK-NEXT: neg.f32 %f3, %f2;
; CHECK-NEXT: ld.param.u64 %rd1, [fcopysign_f_d_param_1];
; CHECK-NEXT: shr.u64 %rd2, %rd1, 63;
; CHECK-NEXT: and.b64 %rd3, %rd2, 1;
; CHECK-NEXT: setp.eq.b64 %p1, %rd3, 1;
; CHECK-NEXT: selp.f32 %f4, %f3, %f2, %p1;
; CHECK-NEXT: st.param.f32 [func_retval0+0], %f4;
; CHECK-NEXT: ret;
%c = fptrunc double %b to float
%val = call float @llvm.copysign.f32(float %a, float %c)
ret float %val
}

; The sign source is a half extended to float. No copysign.f32 is expected
; here: the checks read the sign from bit 15 of the 16-bit parameter and use it
; to select between abs(%a) and its negation.
define float @fcopysign_f_h(float %a, half %b) {
; CHECK-LABEL: fcopysign_f_h(
; CHECK: {
; CHECK-NEXT: .reg .pred %p<2>;
; CHECK-NEXT: .reg .b16 %rs<4>;
; CHECK-NEXT: .reg .f32 %f<5>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.f32 %f1, [fcopysign_f_h_param_0];
; CHECK-NEXT: abs.f32 %f2, %f1;
; CHECK-NEXT: neg.f32 %f3, %f2;
; CHECK-NEXT: ld.param.u16 %rs1, [fcopysign_f_h_param_1];
; CHECK-NEXT: shr.u16 %rs2, %rs1, 15;
; CHECK-NEXT: and.b16 %rs3, %rs2, 1;
; CHECK-NEXT: setp.eq.b16 %p1, %rs3, 1;
; CHECK-NEXT: selp.f32 %f4, %f3, %f2, %p1;
; CHECK-NEXT: st.param.f32 [func_retval0+0], %f4;
; CHECK-NEXT: ret;
%c = fpext half %b to float
%val = call float @llvm.copysign.f32(float %a, float %c)
ret float %val
}

; The sign source is a float extended to double. No copysign.f64 is expected
; here: the checks read the sign from bit 31 of the 32-bit parameter and use it
; to select between abs(%a) and its negation.
define double @fcopysign_d_f(double %a, float %b) {
; CHECK-LABEL: fcopysign_d_f(
; CHECK: {
; CHECK-NEXT: .reg .pred %p<2>;
; CHECK-NEXT: .reg .b32 %r<4>;
; CHECK-NEXT: .reg .f64 %fd<5>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.f64 %fd1, [fcopysign_d_f_param_0];
; CHECK-NEXT: abs.f64 %fd2, %fd1;
; CHECK-NEXT: neg.f64 %fd3, %fd2;
; CHECK-NEXT: ld.param.u32 %r1, [fcopysign_d_f_param_1];
; CHECK-NEXT: shr.u32 %r2, %r1, 31;
; CHECK-NEXT: and.b32 %r3, %r2, 1;
; CHECK-NEXT: setp.eq.b32 %p1, %r3, 1;
; CHECK-NEXT: selp.f64 %fd4, %fd3, %fd2, %p1;
; CHECK-NEXT: st.param.f64 [func_retval0+0], %fd4;
; CHECK-NEXT: ret;
%c = fpext float %b to double
%val = call double @llvm.copysign.f64(double %a, double %c)
ret double %val
}

; The sign source is a half extended to double. No copysign.f64 is expected
; here: the checks read the sign from bit 15 of the 16-bit parameter and use it
; to select between abs(%a) and its negation.
define double @fcopysign_d_h(double %a, half %b) {
; CHECK-LABEL: fcopysign_d_h(
; CHECK: {
; CHECK-NEXT: .reg .pred %p<2>;
; CHECK-NEXT: .reg .b16 %rs<4>;
; CHECK-NEXT: .reg .f64 %fd<5>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.f64 %fd1, [fcopysign_d_h_param_0];
; CHECK-NEXT: abs.f64 %fd2, %fd1;
; CHECK-NEXT: neg.f64 %fd3, %fd2;
; CHECK-NEXT: ld.param.u16 %rs1, [fcopysign_d_h_param_1];
; CHECK-NEXT: shr.u16 %rs2, %rs1, 15;
; CHECK-NEXT: and.b16 %rs3, %rs2, 1;
; CHECK-NEXT: setp.eq.b16 %p1, %rs3, 1;
; CHECK-NEXT: selp.f64 %fd4, %fd3, %fd2, %p1;
; CHECK-NEXT: st.param.f64 [func_retval0+0], %fd4;
; CHECK-NEXT: ret;
%c = fpext half %b to double
%val = call double @llvm.copysign.f64(double %a, double %c)
ret double %val
}


declare float @llvm.copysign.f32(float, float)
declare double @llvm.copysign.f64(double, double)
19 changes: 6 additions & 13 deletions llvm/test/CodeGen/NVPTX/math-intrins.ll
Original file line number Diff line number Diff line change
Expand Up @@ -195,9 +195,8 @@ define double @round_double(double %a) {
; check the use of 0.5 to implement round
; CHECK-LABEL: round_double(
; CHECK: {
; CHECK-NEXT: .reg .pred %p<4>;
; CHECK-NEXT: .reg .b64 %rd<4>;
; CHECK-NEXT: .reg .f64 %fd<10>;
; CHECK-NEXT: .reg .pred %p<3>;
; CHECK-NEXT: .reg .f64 %fd<8>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.f64 %fd1, [round_double_param_0];
Expand All @@ -206,16 +205,10 @@ define double @round_double(double %a) {
; CHECK-NEXT: add.rn.f64 %fd3, %fd2, 0d3FE0000000000000;
; CHECK-NEXT: cvt.rzi.f64.f64 %fd4, %fd3;
; CHECK-NEXT: selp.f64 %fd5, 0d0000000000000000, %fd4, %p1;
; CHECK-NEXT: abs.f64 %fd6, %fd5;
; CHECK-NEXT: neg.f64 %fd7, %fd6;
; CHECK-NEXT: mov.b64 %rd1, %fd1;
; CHECK-NEXT: shr.u64 %rd2, %rd1, 63;
; CHECK-NEXT: and.b64 %rd3, %rd2, 1;
; CHECK-NEXT: setp.eq.b64 %p2, %rd3, 1;
; CHECK-NEXT: selp.f64 %fd8, %fd7, %fd6, %p2;
; CHECK-NEXT: setp.gt.f64 %p3, %fd2, 0d4330000000000000;
; CHECK-NEXT: selp.f64 %fd9, %fd1, %fd8, %p3;
; CHECK-NEXT: st.param.f64 [func_retval0+0], %fd9;
; CHECK-NEXT: copysign.f64 %fd6, %fd1, %fd5;
; CHECK-NEXT: setp.gt.f64 %p2, %fd2, 0d4330000000000000;
; CHECK-NEXT: selp.f64 %fd7, %fd1, %fd6, %p2;
; CHECK-NEXT: st.param.f64 [func_retval0+0], %fd7;
; CHECK-NEXT: ret;
%b = call double @llvm.round.f64(double %a)
ret double %b
Expand Down
Loading