Skip to content

Commit e42f473

Browse files
authored
Reland "[NVPTX] Support copysign PTX instruction" (#108125)
Lower `fcopysign` SDNodes into `copysign` PTX instructions where possible. See [PTX ISA: 9.7.3.2. Floating Point Instructions: copysign] (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-copysign). Copysign SDNodes with mismatched types are expanded as before, since the PTX instruction requires the types to match.
1 parent 859b785 commit e42f473

File tree

5 files changed

+179
-15
lines changed

5 files changed

+179
-15
lines changed

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -838,8 +838,8 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
838838
setOperationAction(ISD::FCOPYSIGN, MVT::v2f16, Expand);
839839
setOperationAction(ISD::FCOPYSIGN, MVT::bf16, Expand);
840840
setOperationAction(ISD::FCOPYSIGN, MVT::v2bf16, Expand);
841-
setOperationAction(ISD::FCOPYSIGN, MVT::f32, Expand);
842-
setOperationAction(ISD::FCOPYSIGN, MVT::f64, Expand);
841+
setOperationAction(ISD::FCOPYSIGN, MVT::f32, Custom);
842+
setOperationAction(ISD::FCOPYSIGN, MVT::f64, Custom);
843843

844844
// These map to corresponding instructions for f32/f64. f16 must be
845845
// promoted to f32. v2f16 is expanded to f16, which is then promoted
@@ -964,6 +964,7 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
964964
MAKE_CASE(NVPTXISD::BFE)
965965
MAKE_CASE(NVPTXISD::BFI)
966966
MAKE_CASE(NVPTXISD::PRMT)
967+
MAKE_CASE(NVPTXISD::FCOPYSIGN)
967968
MAKE_CASE(NVPTXISD::DYNAMIC_STACKALLOC)
968969
MAKE_CASE(NVPTXISD::SETP_F16X2)
969970
MAKE_CASE(NVPTXISD::SETP_BF16X2)
@@ -2560,6 +2561,23 @@ SDValue NVPTXTargetLowering::LowerShiftLeftParts(SDValue Op,
25602561
}
25612562
}
25622563

2564+
/// If the types match, convert the generic copysign to the NVPTXISD version,
2565+
/// otherwise bail ensuring that mismatched cases are properly expaned.
2566+
SDValue NVPTXTargetLowering::LowerFCOPYSIGN(SDValue Op,
2567+
SelectionDAG &DAG) const {
2568+
EVT VT = Op.getValueType();
2569+
SDLoc DL(Op);
2570+
2571+
SDValue In1 = Op.getOperand(0);
2572+
SDValue In2 = Op.getOperand(1);
2573+
EVT SrcVT = In2.getValueType();
2574+
2575+
if (!SrcVT.bitsEq(VT))
2576+
return SDValue();
2577+
2578+
return DAG.getNode(NVPTXISD::FCOPYSIGN, DL, VT, In1, In2);
2579+
}
2580+
25632581
SDValue NVPTXTargetLowering::LowerFROUND(SDValue Op, SelectionDAG &DAG) const {
25642582
EVT VT = Op.getValueType();
25652583

@@ -2803,6 +2821,8 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
28032821
return LowerSelect(Op, DAG);
28042822
case ISD::FROUND:
28052823
return LowerFROUND(Op, DAG);
2824+
case ISD::FCOPYSIGN:
2825+
return LowerFCOPYSIGN(Op, DAG);
28062826
case ISD::SINT_TO_FP:
28072827
case ISD::UINT_TO_FP:
28082828
return LowerINT_TO_FP(Op, DAG);

llvm/lib/Target/NVPTX/NVPTXISelLowering.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ enum NodeType : unsigned {
6161
BFE,
6262
BFI,
6363
PRMT,
64+
FCOPYSIGN,
6465
DYNAMIC_STACKALLOC,
6566
BrxStart,
6667
BrxItem,
@@ -623,6 +624,8 @@ class NVPTXTargetLowering : public TargetLowering {
623624
SDValue LowerINSERT_VECTOR_ELT(SDValue Op, SelectionDAG &DAG) const;
624625
SDValue LowerVECTOR_SHUFFLE(SDValue Op, SelectionDAG &DAG) const;
625626

627+
SDValue LowerFCOPYSIGN(SDValue Op, SelectionDAG &DAG) const;
628+
626629
SDValue LowerFROUND(SDValue Op, SelectionDAG &DAG) const;
627630
SDValue LowerFROUND32(SDValue Op, SelectionDAG &DAG) const;
628631
SDValue LowerFROUND64(SDValue Op, SelectionDAG &DAG) const;

llvm/lib/Target/NVPTX/NVPTXIntrinsics.td

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -977,6 +977,22 @@ def INT_NVVM_FABS_F : F_MATH_1<"abs.f32 \t$dst, $src0;", Float32Regs,
977977
def INT_NVVM_FABS_D : F_MATH_1<"abs.f64 \t$dst, $src0;", Float64Regs,
978978
Float64Regs, int_nvvm_fabs_d>;
979979

980+
//
981+
// copysign
982+
//
983+
984+
def fcopysign_nvptx : SDNode<"NVPTXISD::FCOPYSIGN", SDTFPBinOp>;
985+
986+
def COPYSIGN_F :
987+
NVPTXInst<(outs Float32Regs:$dst), (ins Float32Regs:$src0, Float32Regs:$src1),
988+
"copysign.f32 \t$dst, $src0, $src1;",
989+
[(set Float32Regs:$dst, (fcopysign_nvptx Float32Regs:$src1, Float32Regs:$src0))]>;
990+
991+
def COPYSIGN_D :
992+
NVPTXInst<(outs Float64Regs:$dst), (ins Float64Regs:$src0, Float64Regs:$src1),
993+
"copysign.f64 \t$dst, $src0, $src1;",
994+
[(set Float64Regs:$dst, (fcopysign_nvptx Float64Regs:$src1, Float64Regs:$src0))]>;
995+
980996
//
981997
// Abs, Neg bf16, bf16x2
982998
//

llvm/test/CodeGen/NVPTX/copysign.ll

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
2+
; RUN: llc < %s -march=nvptx64 -mcpu=sm_20 -verify-machineinstrs | FileCheck %s
3+
; RUN: %if ptxas %{ llc < %s -march=nvptx64 -mcpu=sm_20 -verify-machineinstrs | %ptxas-verify %}
4+
5+
target triple = "nvptx64-nvidia-cuda"
6+
target datalayout = "e-p:32:32:32-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64"
7+
8+
define float @fcopysign_f_f(float %a, float %b) {
9+
; CHECK-LABEL: fcopysign_f_f(
10+
; CHECK: {
11+
; CHECK-NEXT: .reg .f32 %f<4>;
12+
; CHECK-EMPTY:
13+
; CHECK-NEXT: // %bb.0:
14+
; CHECK-NEXT: ld.param.f32 %f1, [fcopysign_f_f_param_0];
15+
; CHECK-NEXT: ld.param.f32 %f2, [fcopysign_f_f_param_1];
16+
; CHECK-NEXT: copysign.f32 %f3, %f2, %f1;
17+
; CHECK-NEXT: st.param.f32 [func_retval0+0], %f3;
18+
; CHECK-NEXT: ret;
19+
%val = call float @llvm.copysign.f32(float %a, float %b)
20+
ret float %val
21+
}
22+
23+
define double @fcopysign_d_d(double %a, double %b) {
24+
; CHECK-LABEL: fcopysign_d_d(
25+
; CHECK: {
26+
; CHECK-NEXT: .reg .f64 %fd<4>;
27+
; CHECK-EMPTY:
28+
; CHECK-NEXT: // %bb.0:
29+
; CHECK-NEXT: ld.param.f64 %fd1, [fcopysign_d_d_param_0];
30+
; CHECK-NEXT: ld.param.f64 %fd2, [fcopysign_d_d_param_1];
31+
; CHECK-NEXT: copysign.f64 %fd3, %fd2, %fd1;
32+
; CHECK-NEXT: st.param.f64 [func_retval0+0], %fd3;
33+
; CHECK-NEXT: ret;
34+
%val = call double @llvm.copysign.f64(double %a, double %b)
35+
ret double %val
36+
}
37+
38+
define float @fcopysign_f_d(float %a, double %b) {
39+
; CHECK-LABEL: fcopysign_f_d(
40+
; CHECK: {
41+
; CHECK-NEXT: .reg .pred %p<2>;
42+
; CHECK-NEXT: .reg .f32 %f<5>;
43+
; CHECK-NEXT: .reg .b64 %rd<4>;
44+
; CHECK-EMPTY:
45+
; CHECK-NEXT: // %bb.0:
46+
; CHECK-NEXT: ld.param.f32 %f1, [fcopysign_f_d_param_0];
47+
; CHECK-NEXT: abs.f32 %f2, %f1;
48+
; CHECK-NEXT: neg.f32 %f3, %f2;
49+
; CHECK-NEXT: ld.param.u64 %rd1, [fcopysign_f_d_param_1];
50+
; CHECK-NEXT: shr.u64 %rd2, %rd1, 63;
51+
; CHECK-NEXT: and.b64 %rd3, %rd2, 1;
52+
; CHECK-NEXT: setp.eq.b64 %p1, %rd3, 1;
53+
; CHECK-NEXT: selp.f32 %f4, %f3, %f2, %p1;
54+
; CHECK-NEXT: st.param.f32 [func_retval0+0], %f4;
55+
; CHECK-NEXT: ret;
56+
%c = fptrunc double %b to float
57+
%val = call float @llvm.copysign.f32(float %a, float %c)
58+
ret float %val
59+
}
60+
61+
define float @fcopysign_f_h(float %a, half %b) {
62+
; CHECK-LABEL: fcopysign_f_h(
63+
; CHECK: {
64+
; CHECK-NEXT: .reg .pred %p<2>;
65+
; CHECK-NEXT: .reg .b16 %rs<4>;
66+
; CHECK-NEXT: .reg .f32 %f<5>;
67+
; CHECK-EMPTY:
68+
; CHECK-NEXT: // %bb.0:
69+
; CHECK-NEXT: ld.param.f32 %f1, [fcopysign_f_h_param_0];
70+
; CHECK-NEXT: abs.f32 %f2, %f1;
71+
; CHECK-NEXT: neg.f32 %f3, %f2;
72+
; CHECK-NEXT: ld.param.u16 %rs1, [fcopysign_f_h_param_1];
73+
; CHECK-NEXT: shr.u16 %rs2, %rs1, 15;
74+
; CHECK-NEXT: and.b16 %rs3, %rs2, 1;
75+
; CHECK-NEXT: setp.eq.b16 %p1, %rs3, 1;
76+
; CHECK-NEXT: selp.f32 %f4, %f3, %f2, %p1;
77+
; CHECK-NEXT: st.param.f32 [func_retval0+0], %f4;
78+
; CHECK-NEXT: ret;
79+
%c = fpext half %b to float
80+
%val = call float @llvm.copysign.f32(float %a, float %c)
81+
ret float %val
82+
}
83+
84+
define double @fcopysign_d_f(double %a, float %b) {
85+
; CHECK-LABEL: fcopysign_d_f(
86+
; CHECK: {
87+
; CHECK-NEXT: .reg .pred %p<2>;
88+
; CHECK-NEXT: .reg .b32 %r<4>;
89+
; CHECK-NEXT: .reg .f64 %fd<5>;
90+
; CHECK-EMPTY:
91+
; CHECK-NEXT: // %bb.0:
92+
; CHECK-NEXT: ld.param.f64 %fd1, [fcopysign_d_f_param_0];
93+
; CHECK-NEXT: abs.f64 %fd2, %fd1;
94+
; CHECK-NEXT: neg.f64 %fd3, %fd2;
95+
; CHECK-NEXT: ld.param.u32 %r1, [fcopysign_d_f_param_1];
96+
; CHECK-NEXT: shr.u32 %r2, %r1, 31;
97+
; CHECK-NEXT: and.b32 %r3, %r2, 1;
98+
; CHECK-NEXT: setp.eq.b32 %p1, %r3, 1;
99+
; CHECK-NEXT: selp.f64 %fd4, %fd3, %fd2, %p1;
100+
; CHECK-NEXT: st.param.f64 [func_retval0+0], %fd4;
101+
; CHECK-NEXT: ret;
102+
%c = fpext float %b to double
103+
%val = call double @llvm.copysign.f64(double %a, double %c)
104+
ret double %val
105+
}
106+
107+
define double @fcopysign_d_h(double %a, half %b) {
108+
; CHECK-LABEL: fcopysign_d_h(
109+
; CHECK: {
110+
; CHECK-NEXT: .reg .pred %p<2>;
111+
; CHECK-NEXT: .reg .b16 %rs<4>;
112+
; CHECK-NEXT: .reg .f64 %fd<5>;
113+
; CHECK-EMPTY:
114+
; CHECK-NEXT: // %bb.0:
115+
; CHECK-NEXT: ld.param.f64 %fd1, [fcopysign_d_h_param_0];
116+
; CHECK-NEXT: abs.f64 %fd2, %fd1;
117+
; CHECK-NEXT: neg.f64 %fd3, %fd2;
118+
; CHECK-NEXT: ld.param.u16 %rs1, [fcopysign_d_h_param_1];
119+
; CHECK-NEXT: shr.u16 %rs2, %rs1, 15;
120+
; CHECK-NEXT: and.b16 %rs3, %rs2, 1;
121+
; CHECK-NEXT: setp.eq.b16 %p1, %rs3, 1;
122+
; CHECK-NEXT: selp.f64 %fd4, %fd3, %fd2, %p1;
123+
; CHECK-NEXT: st.param.f64 [func_retval0+0], %fd4;
124+
; CHECK-NEXT: ret;
125+
%c = fpext half %b to double
126+
%val = call double @llvm.copysign.f64(double %a, double %c)
127+
ret double %val
128+
}
129+
130+
131+
declare float @llvm.copysign.f32(float, float)
132+
declare double @llvm.copysign.f64(double, double)

llvm/test/CodeGen/NVPTX/math-intrins.ll

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -195,9 +195,8 @@ define double @round_double(double %a) {
195195
; check the use of 0.5 to implement round
196196
; CHECK-LABEL: round_double(
197197
; CHECK: {
198-
; CHECK-NEXT: .reg .pred %p<4>;
199-
; CHECK-NEXT: .reg .b64 %rd<4>;
200-
; CHECK-NEXT: .reg .f64 %fd<10>;
198+
; CHECK-NEXT: .reg .pred %p<3>;
199+
; CHECK-NEXT: .reg .f64 %fd<8>;
201200
; CHECK-EMPTY:
202201
; CHECK-NEXT: // %bb.0:
203202
; CHECK-NEXT: ld.param.f64 %fd1, [round_double_param_0];
@@ -206,16 +205,10 @@ define double @round_double(double %a) {
206205
; CHECK-NEXT: add.rn.f64 %fd3, %fd2, 0d3FE0000000000000;
207206
; CHECK-NEXT: cvt.rzi.f64.f64 %fd4, %fd3;
208207
; CHECK-NEXT: selp.f64 %fd5, 0d0000000000000000, %fd4, %p1;
209-
; CHECK-NEXT: abs.f64 %fd6, %fd5;
210-
; CHECK-NEXT: neg.f64 %fd7, %fd6;
211-
; CHECK-NEXT: mov.b64 %rd1, %fd1;
212-
; CHECK-NEXT: shr.u64 %rd2, %rd1, 63;
213-
; CHECK-NEXT: and.b64 %rd3, %rd2, 1;
214-
; CHECK-NEXT: setp.eq.b64 %p2, %rd3, 1;
215-
; CHECK-NEXT: selp.f64 %fd8, %fd7, %fd6, %p2;
216-
; CHECK-NEXT: setp.gt.f64 %p3, %fd2, 0d4330000000000000;
217-
; CHECK-NEXT: selp.f64 %fd9, %fd1, %fd8, %p3;
218-
; CHECK-NEXT: st.param.f64 [func_retval0+0], %fd9;
208+
; CHECK-NEXT: copysign.f64 %fd6, %fd1, %fd5;
209+
; CHECK-NEXT: setp.gt.f64 %p2, %fd2, 0d4330000000000000;
210+
; CHECK-NEXT: selp.f64 %fd7, %fd1, %fd6, %p2;
211+
; CHECK-NEXT: st.param.f64 [func_retval0+0], %fd7;
219212
; CHECK-NEXT: ret;
220213
%b = call double @llvm.round.f64(double %a)
221214
ret double %b

0 commit comments

Comments
 (0)