[NVPTX] Add support for f16 fabs #116107

Merged (3 commits, Nov 18, 2024)
21 changes: 12 additions & 9 deletions llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -862,16 +862,19 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
       setOperationAction(Op, MVT::bf16, Promote);
       AddPromotedToType(Op, MVT::bf16, MVT::f32);
     }
-    for (const auto &Op : {ISD::FABS}) {
-      setOperationAction(Op, MVT::f16, Promote);
-      setOperationAction(Op, MVT::f32, Legal);
-      setOperationAction(Op, MVT::f64, Legal);
-      setOperationAction(Op, MVT::v2f16, Expand);
-      setBF16OperationAction(Op, MVT::v2bf16, Legal, Expand);
-      setBF16OperationAction(Op, MVT::bf16, Legal, Promote);
-      if (getOperationAction(Op, MVT::bf16) == Promote)
-        AddPromotedToType(Op, MVT::bf16, MVT::f32);

+    setOperationAction(ISD::FABS, {MVT::f32, MVT::f64}, Legal);
+    if (STI.getPTXVersion() >= 65) {
+      setFP16OperationAction(ISD::FABS, MVT::f16, Legal, Promote);
+      setFP16OperationAction(ISD::FABS, MVT::v2f16, Legal, Expand);
+    } else {
+      setOperationAction(ISD::FABS, MVT::f16, Promote);
+      setOperationAction(ISD::FABS, MVT::v2f16, Expand);
+    }
+    setBF16OperationAction(ISD::FABS, MVT::v2bf16, Legal, Expand);
+    setBF16OperationAction(ISD::FABS, MVT::bf16, Legal, Promote);
+    if (getOperationAction(ISD::FABS, MVT::bf16) == Promote)
+      AddPromotedToType(ISD::FABS, MVT::bf16, MVT::f32);

     for (const auto &Op : {ISD::FMINNUM, ISD::FMAXNUM}) {
       setOperationAction(Op, MVT::f32, Legal);
98 changes: 98 additions & 0 deletions llvm/test/CodeGen/NVPTX/f16-abs.ll
@@ -0,0 +1,98 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5

; ## FP16 abs is not supported by PTX version (PTX < 65).
; RUN: llc < %s -mcpu=sm_53 -mattr=+ptx60 \
; RUN: -O0 -disable-post-ra -verify-machineinstrs \
; RUN: | FileCheck -check-prefix CHECK-NOF16 %s
; RUN: %if ptxas %{ \
; RUN: llc < %s -mcpu=sm_53 -mattr=+ptx60 \
; RUN: -O0 -disable-post-ra -verify-machineinstrs \
; RUN: | %ptxas-verify -arch=sm_53 \
; RUN: %}

; ## FP16 support explicitly disabled (--nvptx-no-f16-math).
; RUN: llc < %s -mcpu=sm_53 -mattr=+ptx65 --nvptx-no-f16-math \
; RUN: -O0 -disable-post-ra -verify-machineinstrs \
; RUN: | FileCheck -check-prefix CHECK-NOF16 %s
; RUN: %if ptxas %{ \
; RUN: llc < %s -mcpu=sm_53 -mattr=+ptx65 --nvptx-no-f16-math \
; RUN: -O0 -disable-post-ra -verify-machineinstrs \
; RUN: | %ptxas-verify -arch=sm_53 \
; RUN: %}

; ## FP16 is not supported by hardware (SM < 53).
; RUN: llc < %s -mcpu=sm_52 -mattr=+ptx65 \
; RUN: -O0 -disable-post-ra -verify-machineinstrs \
; RUN: | FileCheck -check-prefix CHECK-NOF16 %s
; RUN: %if ptxas %{ \
; RUN: llc < %s -mcpu=sm_52 -mattr=+ptx65 \
; RUN: -O0 -disable-post-ra -verify-machineinstrs \
; RUN: | %ptxas-verify -arch=sm_52 \
; RUN: %}

; ## Full FP16 abs support.
; RUN: llc < %s -mcpu=sm_53 -mattr=+ptx65 \
; RUN: -O0 -disable-post-ra -verify-machineinstrs \
; RUN: | FileCheck -check-prefix CHECK-F16-ABS %s
; RUN: %if ptxas %{ \
; RUN: llc < %s -mcpu=sm_53 -mattr=+ptx65 \
; RUN: -O0 -disable-post-ra -verify-machineinstrs \
; RUN: | %ptxas-verify -arch=sm_53 \
; RUN: %}

target triple = "nvptx64-nvidia-cuda"

declare half @llvm.fabs.f16(half %a)
declare <2 x half> @llvm.fabs.v2f16(<2 x half> %a)

define half @test_fabs(half %a) {
; CHECK-NOF16-LABEL: test_fabs(
; CHECK-NOF16: {
; CHECK-NOF16-NEXT: .reg .b16 %rs<3>;
; CHECK-NOF16-NEXT: .reg .f32 %f<3>;
; CHECK-NOF16-EMPTY:
; CHECK-NOF16-NEXT: // %bb.0:
; CHECK-NOF16-NEXT: ld.param.b16 %rs1, [test_fabs_param_0];
; CHECK-NOF16-NEXT: cvt.f32.f16 %f1, %rs1;
; CHECK-NOF16-NEXT: abs.f32 %f2, %f1;
; CHECK-NOF16-NEXT: cvt.rn.f16.f32 %rs2, %f2;
; CHECK-NOF16-NEXT: st.param.b16 [func_retval0], %rs2;
; CHECK-NOF16-NEXT: ret;
;
; CHECK-F16-ABS-LABEL: test_fabs(
; CHECK-F16-ABS: {
; CHECK-F16-ABS-NEXT: .reg .b16 %rs<3>;
; CHECK-F16-ABS-EMPTY:
; CHECK-F16-ABS-NEXT: // %bb.0:
; CHECK-F16-ABS-NEXT: ld.param.b16 %rs1, [test_fabs_param_0];
; CHECK-F16-ABS-NEXT: abs.f16 %rs2, %rs1;
; CHECK-F16-ABS-NEXT: st.param.b16 [func_retval0], %rs2;
; CHECK-F16-ABS-NEXT: ret;
%r = call half @llvm.fabs.f16(half %a)
ret half %r
}

define <2 x half> @test_fabs_2(<2 x half> %a) #0 {
; CHECK-F16-LABEL: test_fabs_2(
; CHECK-F16: {
; CHECK-F16-NEXT: .reg .b32 %r<5>;
; CHECK-F16-EMPTY:
; CHECK-F16-NEXT: // %bb.0:
; CHECK-F16-NEXT: ld.param.b32 %r1, [test_fabs_2_param_0];
; CHECK-F16-NEXT: and.b32 %r3, %r1, 2147450879;
Member commented:
Huh. This is interesting. If we are allowed to do fabs via sign masking, perhaps instead of promotion to fp32 for scalars, we should custom-lower it to a logical op for the fallback path for fp16 scalars where native fp16 abs is not supported. That's likely more efficient than abs.f32 plus two conversions.

Member commented:

To clarify -- it's an optimization opportunity. If it is easy to incorporate into this change -- great. If not, it can be done separately. This CL is good to go as is.

@AlexMaclean (Member, Author) commented on Nov 13, 2024:

I think it is fine from the perspective of LLVM IR's semantics to implement f16 abs with an and. Actually, it's probably more conformant because it will preserve NaN payloads, while the conversions may not. That being said, I'm not sure about the perf implications of going this route. Maybe in some cases the abs could be strung together with other promoted operations and result in better codegen?

Member commented:

I think the only case where it will matter is when we would have a cluster of fp16 ops promoted to f32, with a few abs in the middle. If LLVM decided to do that abs in fp16, that indeed may be slower. I'm not sure how exactly this is handled. In theory it would boil down to a cost analysis between two casts plus an and vs. fp32 fabs, with fp32 fabs winning.

That may be something to teach instcombine about, if we do not already.

In either case, sm_60 and older GPUs are nearly obsolete these days and are not worth spending much effort on. We can just leave things as is.

@AlexMaclean (Member, Author) commented:

Sounds good, will leave as is for older architectures for now.

; CHECK-F16-NEXT: st.param.b32 [func_retval0], %r3;
; CHECK-F16-NEXT: ret;
;
; CHECK-F16-ABS-LABEL: test_fabs_2(
; CHECK-F16-ABS: {
; CHECK-F16-ABS-NEXT: .reg .b32 %r<3>;
; CHECK-F16-ABS-EMPTY:
; CHECK-F16-ABS-NEXT: // %bb.0:
; CHECK-F16-ABS-NEXT: ld.param.b32 %r1, [test_fabs_2_param_0];
; CHECK-F16-ABS-NEXT: abs.f16x2 %r2, %r1;
; CHECK-F16-ABS-NEXT: st.param.b32 [func_retval0], %r2;
; CHECK-F16-ABS-NEXT: ret;
%r = call <2 x half> @llvm.fabs.v2f16(<2 x half> %a)
ret <2 x half> %r
}
