Skip to content

Commit 5587627

Browse files
authored
[NVPTX] Add support for f16 fabs (#116107)
Add support for f16 and f16x2 support for abs. See PTX ISA 9.7.4.6. Half Precision Floating Point Instructions: abs https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-abs
1 parent 31aa7f3 commit 5587627

File tree

2 files changed

+110
-9
lines changed

2 files changed

+110
-9
lines changed

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -862,16 +862,19 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
862862
setOperationAction(Op, MVT::bf16, Promote);
863863
AddPromotedToType(Op, MVT::bf16, MVT::f32);
864864
}
865-
for (const auto &Op : {ISD::FABS}) {
866-
setOperationAction(Op, MVT::f16, Promote);
867-
setOperationAction(Op, MVT::f32, Legal);
868-
setOperationAction(Op, MVT::f64, Legal);
869-
setOperationAction(Op, MVT::v2f16, Expand);
870-
setBF16OperationAction(Op, MVT::v2bf16, Legal, Expand);
871-
setBF16OperationAction(Op, MVT::bf16, Legal, Promote);
872-
if (getOperationAction(Op, MVT::bf16) == Promote)
873-
AddPromotedToType(Op, MVT::bf16, MVT::f32);
865+
866+
setOperationAction(ISD::FABS, {MVT::f32, MVT::f64}, Legal);
867+
if (STI.getPTXVersion() >= 65) {
868+
setFP16OperationAction(ISD::FABS, MVT::f16, Legal, Promote);
869+
setFP16OperationAction(ISD::FABS, MVT::v2f16, Legal, Expand);
870+
} else {
871+
setOperationAction(ISD::FABS, MVT::f16, Promote);
872+
setOperationAction(ISD::FABS, MVT::v2f16, Expand);
874873
}
874+
setBF16OperationAction(ISD::FABS, MVT::v2bf16, Legal, Expand);
875+
setBF16OperationAction(ISD::FABS, MVT::bf16, Legal, Promote);
876+
if (getOperationAction(ISD::FABS, MVT::bf16) == Promote)
877+
AddPromotedToType(ISD::FABS, MVT::bf16, MVT::f32);
875878

876879
for (const auto &Op : {ISD::FMINNUM, ISD::FMAXNUM}) {
877880
setOperationAction(Op, MVT::f32, Legal);

llvm/test/CodeGen/NVPTX/f16-abs.ll

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
2+
3+
; ## FP16 abs is not supported by PTX version (PTX < 65).
4+
; RUN: llc < %s -mcpu=sm_53 -mattr=+ptx60 \
5+
; RUN: -O0 -disable-post-ra -verify-machineinstrs \
6+
; RUN: | FileCheck -check-prefix CHECK-NOF16 %s
7+
; RUN: %if ptxas %{ \
8+
; RUN: llc < %s -mcpu=sm_53 -mattr=+ptx60 \
9+
; RUN: -O0 -disable-post-ra -verify-machineinstrs \
10+
; RUN: | %ptxas-verify -arch=sm_53 \
11+
; RUN: %}
12+
13+
; ## FP16 support explicitly disabled (--nvptx-no-f16-math).
14+
; RUN: llc < %s -mcpu=sm_53 -mattr=+ptx65 --nvptx-no-f16-math \
15+
; RUN: -O0 -disable-post-ra -verify-machineinstrs \
16+
; RUN: | FileCheck -check-prefix CHECK-NOF16 %s
17+
; RUN: %if ptxas %{ \
18+
; RUN: llc < %s -mcpu=sm_53 -mattr=+ptx65 --nvptx-no-f16-math \
19+
; RUN: -O0 -disable-post-ra -verify-machineinstrs \
20+
; RUN: | %ptxas-verify -arch=sm_53 \
21+
; RUN: %}
22+
23+
; ## FP16 is not supported by hardware (SM < 53).
24+
; RUN: llc < %s -mcpu=sm_52 -mattr=+ptx65 \
25+
; RUN: -O0 -disable-post-ra -verify-machineinstrs \
26+
; RUN: | FileCheck -check-prefix CHECK-NOF16 %s
27+
; RUN: %if ptxas %{ \
28+
; RUN: llc < %s -mcpu=sm_52 -mattr=+ptx65 \
29+
; RUN: -O0 -disable-post-ra -verify-machineinstrs \
30+
; RUN: | %ptxas-verify -arch=sm_52 \
31+
; RUN: %}
32+
33+
; ## Full FP16 abs support.
34+
; RUN: llc < %s -mcpu=sm_53 -mattr=+ptx65 \
35+
; RUN: -O0 -disable-post-ra -verify-machineinstrs \
36+
; RUN: | FileCheck -check-prefix CHECK-F16-ABS %s
37+
; RUN: %if ptxas %{ \
38+
; RUN: llc < %s -mcpu=sm_53 -mattr=+ptx65 \
39+
; RUN: -O0 -disable-post-ra -verify-machineinstrs \
40+
; RUN: | %ptxas-verify -arch=sm_53 \
41+
; RUN: %}
42+
43+
target triple = "nvptx64-nvidia-cuda"
44+
45+
declare half @llvm.fabs.f16(half %a)
46+
declare <2 x half> @llvm.fabs.v2f16(<2 x half> %a)
47+
48+
define half @test_fabs(half %a) {
49+
; CHECK-NOF16-LABEL: test_fabs(
50+
; CHECK-NOF16: {
51+
; CHECK-NOF16-NEXT: .reg .b16 %rs<3>;
52+
; CHECK-NOF16-NEXT: .reg .f32 %f<3>;
53+
; CHECK-NOF16-EMPTY:
54+
; CHECK-NOF16-NEXT: // %bb.0:
55+
; CHECK-NOF16-NEXT: ld.param.b16 %rs1, [test_fabs_param_0];
56+
; CHECK-NOF16-NEXT: cvt.f32.f16 %f1, %rs1;
57+
; CHECK-NOF16-NEXT: abs.f32 %f2, %f1;
58+
; CHECK-NOF16-NEXT: cvt.rn.f16.f32 %rs2, %f2;
59+
; CHECK-NOF16-NEXT: st.param.b16 [func_retval0], %rs2;
60+
; CHECK-NOF16-NEXT: ret;
61+
;
62+
; CHECK-F16-ABS-LABEL: test_fabs(
63+
; CHECK-F16-ABS: {
64+
; CHECK-F16-ABS-NEXT: .reg .b16 %rs<3>;
65+
; CHECK-F16-ABS-EMPTY:
66+
; CHECK-F16-ABS-NEXT: // %bb.0:
67+
; CHECK-F16-ABS-NEXT: ld.param.b16 %rs1, [test_fabs_param_0];
68+
; CHECK-F16-ABS-NEXT: abs.f16 %rs2, %rs1;
69+
; CHECK-F16-ABS-NEXT: st.param.b16 [func_retval0], %rs2;
70+
; CHECK-F16-ABS-NEXT: ret;
71+
%r = call half @llvm.fabs.f16(half %a)
72+
ret half %r
73+
}
74+
75+
define <2 x half> @test_fabs_2(<2 x half> %a) #0 {
76+
; CHECK-F16-LABEL: test_fabs_2(
77+
; CHECK-F16: {
78+
; CHECK-F16-NEXT: .reg .b32 %r<5>;
79+
; CHECK-F16-EMPTY:
80+
; CHECK-F16-NEXT: // %bb.0:
81+
; CHECK-F16-NEXT: ld.param.b32 %r1, [test_fabs_2_param_0];
82+
; CHECK-F16-NEXT: and.b32 %r3, %r1, 2147450879;
83+
; CHECK-F16-NEXT: st.param.b32 [func_retval0], %r3;
84+
; CHECK-F16-NEXT: ret;
85+
;
86+
; CHECK-F16-ABS-LABEL: test_fabs_2(
87+
; CHECK-F16-ABS: {
88+
; CHECK-F16-ABS-NEXT: .reg .b32 %r<3>;
89+
; CHECK-F16-ABS-EMPTY:
90+
; CHECK-F16-ABS-NEXT: // %bb.0:
91+
; CHECK-F16-ABS-NEXT: ld.param.b32 %r1, [test_fabs_2_param_0];
92+
; CHECK-F16-ABS-NEXT: abs.f16x2 %r2, %r1;
93+
; CHECK-F16-ABS-NEXT: st.param.b32 [func_retval0], %r2;
94+
; CHECK-F16-ABS-NEXT: ret;
95+
%r = call <2 x half> @llvm.fabs.v2f16(<2 x half> %a)
96+
ret <2 x half> %r
97+
}
98+

0 commit comments

Comments
 (0)