[DAGCombiner] Eliminate fp casts if we have the right fast math flags #131345


Merged (3 commits) on Apr 28, 2025
45 changes: 45 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -18455,7 +18455,45 @@ SDValue DAGCombiner::visitFP_ROUND(SDNode *N) {
return SDValue();
}

// Eliminate a floating-point widening of a narrowed value if the fast math
// flags allow it.
static SDValue eliminateFPCastPair(SDNode *N) {
SDValue N0 = N->getOperand(0);
EVT VT = N->getValueType(0);

unsigned NarrowingOp;
switch (N->getOpcode()) {
case ISD::FP16_TO_FP:
NarrowingOp = ISD::FP_TO_FP16;
break;
case ISD::BF16_TO_FP:
NarrowingOp = ISD::FP_TO_BF16;
break;
Comment on lines +18466 to +18471

Contributor: Can you leave these cases for a separate patch?

Collaborator (Author): Is there a specific reason for doing it that way?

Contributor: Because these cases are a pain, and it's not obvious to me this is even tested in the patch.

Collaborator (Author): FP16_TO_FP is tested in llvm/test/CodeGen/ARM/fp16_fast_math.ll - ARM uses FP16_TO_FP for fp16-to-float conversion as it doesn't have registers for fp16 types; AArch64 uses FP_EXTEND as it does have such registers. We don't have any BF16 tests though, so I've now added these.

case ISD::FP_EXTEND:
NarrowingOp = ISD::FP_ROUND;
break;
default:
llvm_unreachable("Expected widening FP cast");
}
Comment on lines +18464 to +18477

Contributor: This should not be a switch/assign/break. Surely there's a utility function for this somewhere already?

Collaborator (Author): There doesn't seem to be one as far as I can tell. I could move this out into a separate function like e.g. ISD::getInverseMinMaxOpcode, but that would just be moving the switch to a separate location.

Contributor: You could change each call site to pass in NarrowingOp as an argument.

Collaborator (Author): That then means visitFP16_TO_FP needs that same switch since, despite the name, it is not only called for ISD::FP16_TO_FP. That doesn't get rid of the switch, it just moves it.


if (N0.getOpcode() == NarrowingOp && N0.getOperand(0).getValueType() == VT) {
const SDNodeFlags NarrowFlags = N0->getFlags();
const SDNodeFlags WidenFlags = N->getFlags();
// Narrowing can introduce inf and change the encoding of a nan, so the
// widen must have the nnan and ninf flags to indicate that we don't need to
// care about that. We are also removing a rounding step, and that requires
// both the narrow and widen to allow contraction.
if (WidenFlags.hasNoNaNs() && WidenFlags.hasNoInfs() &&
NarrowFlags.hasAllowContract() && WidenFlags.hasAllowContract()) {
return N0.getOperand(0);
}
}

return SDValue();
}
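The comment in eliminateFPCastPair above names the two hazards: narrowing can overflow a finite value to inf (and requantize a NaN), and the narrow/widen pair is itself a rounding step. Both effects can be sketched outside of LLVM with the Python stdlib struct module's half-precision 'e' format, used here as a stand-in for the hardware narrowing conversion (an illustration, not part of the patch; note that struct raises an exception where IEEE round-to-nearest would silently produce infinity):

```python
import struct

def narrow_to_fp16(x):
    """Round a Python float to fp16 and widen it back, like an fcvt h/s pair."""
    return struct.unpack('<e', struct.pack('<e', x))[0]

# Removing the narrow/widen pair removes a rounding step: the round-trip
# through fp16 visibly changes the value, hence the contract requirement
# on both the narrowing and the widening node.
print(narrow_to_fp16(0.1))  # 0.0999755859375, not 0.1

# Narrowing can also introduce inf: 70000.0 is finite in f32/f64 but above
# the fp16 maximum (65504.0). Hardware conversion gives +inf; Python's
# struct signals the same out-of-range condition by raising instead.
try:
    struct.pack('<e', 70000.0)
except (OverflowError, struct.error):
    print("70000.0 does not fit in fp16")
```

This is why the combine demands nnan and ninf on the widening node and allow-contract on both nodes before deleting the pair.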

SDValue DAGCombiner::visitFP_EXTEND(SDNode *N) {
SelectionDAG::FlagInserter FlagsInserter(DAG, N);
SDValue N0 = N->getOperand(0);
EVT VT = N->getValueType(0);
SDLoc DL(N);
@@ -18507,6 +18545,9 @@ SDValue DAGCombiner::visitFP_EXTEND(SDNode *N) {
if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N))
return NewVSel;

if (SDValue CastEliminated = eliminateFPCastPair(N))
return CastEliminated;

return SDValue();
}

@@ -27209,6 +27250,7 @@ SDValue DAGCombiner::visitFP_TO_FP16(SDNode *N) {
}

SDValue DAGCombiner::visitFP16_TO_FP(SDNode *N) {
SelectionDAG::FlagInserter FlagsInserter(DAG, N);
auto Op = N->getOpcode();
assert((Op == ISD::FP16_TO_FP || Op == ISD::BF16_TO_FP) &&
"opcode should be FP16_TO_FP or BF16_TO_FP.");
@@ -27223,6 +27265,9 @@ SDValue DAGCombiner::visitFP16_TO_FP(SDNode *N) {
}
}

if (SDValue CastEliminated = eliminateFPCastPair(N))
return CastEliminated;

// Sometimes constants manage to survive very late in the pipeline, e.g.,
// because they are wrapped inside the <1 x f16> type. Try one last time to
// get rid of them.
400 changes: 400 additions & 0 deletions llvm/test/CodeGen/AArch64/bf16_fast_math.ll

Large diffs are not rendered by default.

16 changes: 11 additions & 5 deletions llvm/test/CodeGen/AArch64/f16-instructions.ll
@@ -84,11 +84,8 @@ define half @test_fmadd(half %a, half %b, half %c) #0 {
 ; CHECK-CVT-SD: // %bb.0:
 ; CHECK-CVT-SD-NEXT: fcvt s1, h1
 ; CHECK-CVT-SD-NEXT: fcvt s0, h0
-; CHECK-CVT-SD-NEXT: fmul s0, s0, s1
-; CHECK-CVT-SD-NEXT: fcvt s1, h2
-; CHECK-CVT-SD-NEXT: fcvt h0, s0
-; CHECK-CVT-SD-NEXT: fcvt s0, h0
-; CHECK-CVT-SD-NEXT: fadd s0, s0, s1
+; CHECK-CVT-SD-NEXT: fcvt s2, h2
+; CHECK-CVT-SD-NEXT: fmadd s0, s0, s1, s2
 ; CHECK-CVT-SD-NEXT: fcvt h0, s0
 ; CHECK-CVT-SD-NEXT: ret
 ;
@@ -1248,6 +1245,15 @@
}

define half @test_atan2(half %a, half %b) #0 {
; CHECK-LABEL: test_atan2:
; CHECK: // %bb.0:
; CHECK-NEXT: str x30, [sp, #-16]! // 8-byte Folded Spill
; CHECK-NEXT: fcvt s0, h0
; CHECK-NEXT: fcvt s1, h1
; CHECK-NEXT: bl atan2f
; CHECK-NEXT: fcvt h0, s0
; CHECK-NEXT: ldr x30, [sp], #16 // 8-byte Folded Reload
; CHECK-NEXT: ret
%r = call half @llvm.atan2.f16(half %a, half %b)
ret half %r
}
7 changes: 2 additions & 5 deletions llvm/test/CodeGen/AArch64/fmla.ll
@@ -1114,11 +1114,8 @@ define half @fmul_f16(half %a, half %b, half %c) {
 ; CHECK-SD-NOFP16: // %bb.0: // %entry
 ; CHECK-SD-NOFP16-NEXT: fcvt s1, h1
 ; CHECK-SD-NOFP16-NEXT: fcvt s0, h0
-; CHECK-SD-NOFP16-NEXT: fmul s0, s0, s1
-; CHECK-SD-NOFP16-NEXT: fcvt s1, h2
-; CHECK-SD-NOFP16-NEXT: fcvt h0, s0
-; CHECK-SD-NOFP16-NEXT: fcvt s0, h0
-; CHECK-SD-NOFP16-NEXT: fadd s0, s0, s1
+; CHECK-SD-NOFP16-NEXT: fcvt s2, h2
+; CHECK-SD-NOFP16-NEXT: fmadd s0, s0, s1, s2
 ; CHECK-SD-NOFP16-NEXT: fcvt h0, s0
 ; CHECK-SD-NOFP16-NEXT: ret
 ;
109 changes: 109 additions & 0 deletions llvm/test/CodeGen/AArch64/fp16_fast_math.ll
@@ -88,3 +88,112 @@ entry:
%add = fadd ninf half %x, %y
ret half %add
}

; Check that when we have the right fast math flags the converts in between the
; two fadds are removed.

define half @normal_fadd_sequence(half %x, half %y, half %z) {
; CHECK-CVT-LABEL: name: normal_fadd_sequence
; CHECK-CVT: bb.0.entry:
; CHECK-CVT-NEXT: liveins: $h0, $h1, $h2
; CHECK-CVT-NEXT: {{ $}}
; CHECK-CVT-NEXT: [[COPY:%[0-9]+]]:fpr16 = COPY $h2
; CHECK-CVT-NEXT: [[COPY1:%[0-9]+]]:fpr16 = COPY $h1
; CHECK-CVT-NEXT: [[COPY2:%[0-9]+]]:fpr16 = COPY $h0
; CHECK-CVT-NEXT: [[FCVTSHr:%[0-9]+]]:fpr32 = nofpexcept FCVTSHr [[COPY1]], implicit $fpcr
; CHECK-CVT-NEXT: [[FCVTSHr1:%[0-9]+]]:fpr32 = nofpexcept FCVTSHr [[COPY2]], implicit $fpcr
; CHECK-CVT-NEXT: [[FADDSrr:%[0-9]+]]:fpr32 = nofpexcept FADDSrr killed [[FCVTSHr1]], killed [[FCVTSHr]], implicit $fpcr
; CHECK-CVT-NEXT: [[FCVTHSr:%[0-9]+]]:fpr16 = nofpexcept FCVTHSr killed [[FADDSrr]], implicit $fpcr
; CHECK-CVT-NEXT: [[FCVTSHr2:%[0-9]+]]:fpr32 = nofpexcept FCVTSHr killed [[FCVTHSr]], implicit $fpcr
; CHECK-CVT-NEXT: [[FCVTSHr3:%[0-9]+]]:fpr32 = nofpexcept FCVTSHr [[COPY]], implicit $fpcr
; CHECK-CVT-NEXT: [[FADDSrr1:%[0-9]+]]:fpr32 = nofpexcept FADDSrr killed [[FCVTSHr2]], killed [[FCVTSHr3]], implicit $fpcr
; CHECK-CVT-NEXT: [[FCVTHSr1:%[0-9]+]]:fpr16 = nofpexcept FCVTHSr killed [[FADDSrr1]], implicit $fpcr
; CHECK-CVT-NEXT: $h0 = COPY [[FCVTHSr1]]
; CHECK-CVT-NEXT: RET_ReallyLR implicit $h0
;
; CHECK-FP16-LABEL: name: normal_fadd_sequence
; CHECK-FP16: bb.0.entry:
; CHECK-FP16-NEXT: liveins: $h0, $h1, $h2
; CHECK-FP16-NEXT: {{ $}}
; CHECK-FP16-NEXT: [[COPY:%[0-9]+]]:fpr16 = COPY $h2
; CHECK-FP16-NEXT: [[COPY1:%[0-9]+]]:fpr16 = COPY $h1
; CHECK-FP16-NEXT: [[COPY2:%[0-9]+]]:fpr16 = COPY $h0
; CHECK-FP16-NEXT: [[FADDHrr:%[0-9]+]]:fpr16 = nofpexcept FADDHrr [[COPY2]], [[COPY1]], implicit $fpcr
; CHECK-FP16-NEXT: [[FADDHrr1:%[0-9]+]]:fpr16 = nofpexcept FADDHrr killed [[FADDHrr]], [[COPY]], implicit $fpcr
; CHECK-FP16-NEXT: $h0 = COPY [[FADDHrr1]]
; CHECK-FP16-NEXT: RET_ReallyLR implicit $h0
entry:
%add1 = fadd half %x, %y
%add2 = fadd half %add1, %z
ret half %add2
}
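The normal_fadd_sequence test above must keep the intermediate fcvt pair because, without fast math flags, the extra fp16 rounding step between the two fadds is observable. A small Python illustration of the difference, with the stdlib struct 'e' format standing in for the fcvt h/s pair (the inputs are chosen so the double rounding is visible):

```python
import struct

def fp16_round_trip(x):
    """Narrow to fp16 and widen back, standing in for an fcvt h/s pair."""
    return struct.unpack('<e', struct.pack('<e', x))[0]

x, y, z = 1.0, 2.0**-11, 2.0**-11

# The unoptimized sequence: each fadd result is narrowed back to half
# before being widened again for the next fadd.
kept_casts = fp16_round_trip(fp16_round_trip(x + y) + z)

# The combined sequence with the cast pair removed: both adds happen in
# the wider type and only the final result is narrowed.
removed_casts = fp16_round_trip((x + y) + z)

print(kept_casts)     # 1.0           (1 + 2^-11 ties to even, twice)
print(removed_casts)  # 1.0009765625  (1 + 2^-10 is exact in fp16)
```

With nnan, ninf and contract on the operations, as in the next test, the combine is allowed to pick the second result, which is what lets the converts disappear.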

define half @nnan_ninf_contract_fadd_sequence(half %x, half %y, half %z) {
; CHECK-CVT-LABEL: name: nnan_ninf_contract_fadd_sequence
; CHECK-CVT: bb.0.entry:
; CHECK-CVT-NEXT: liveins: $h0, $h1, $h2
; CHECK-CVT-NEXT: {{ $}}
; CHECK-CVT-NEXT: [[COPY:%[0-9]+]]:fpr16 = COPY $h2
; CHECK-CVT-NEXT: [[COPY1:%[0-9]+]]:fpr16 = COPY $h1
; CHECK-CVT-NEXT: [[COPY2:%[0-9]+]]:fpr16 = COPY $h0
; CHECK-CVT-NEXT: [[FCVTSHr:%[0-9]+]]:fpr32 = nnan ninf contract nofpexcept FCVTSHr [[COPY1]], implicit $fpcr
; CHECK-CVT-NEXT: [[FCVTSHr1:%[0-9]+]]:fpr32 = nnan ninf contract nofpexcept FCVTSHr [[COPY2]], implicit $fpcr
; CHECK-CVT-NEXT: [[FADDSrr:%[0-9]+]]:fpr32 = nnan ninf contract nofpexcept FADDSrr killed [[FCVTSHr1]], killed [[FCVTSHr]], implicit $fpcr
; CHECK-CVT-NEXT: [[FCVTSHr2:%[0-9]+]]:fpr32 = nnan ninf contract nofpexcept FCVTSHr [[COPY]], implicit $fpcr
; CHECK-CVT-NEXT: [[FADDSrr1:%[0-9]+]]:fpr32 = nnan ninf contract nofpexcept FADDSrr killed [[FADDSrr]], killed [[FCVTSHr2]], implicit $fpcr
; CHECK-CVT-NEXT: [[FCVTHSr:%[0-9]+]]:fpr16 = nnan ninf contract nofpexcept FCVTHSr killed [[FADDSrr1]], implicit $fpcr
; CHECK-CVT-NEXT: $h0 = COPY [[FCVTHSr]]
; CHECK-CVT-NEXT: RET_ReallyLR implicit $h0
;
; CHECK-FP16-LABEL: name: nnan_ninf_contract_fadd_sequence
; CHECK-FP16: bb.0.entry:
; CHECK-FP16-NEXT: liveins: $h0, $h1, $h2
; CHECK-FP16-NEXT: {{ $}}
; CHECK-FP16-NEXT: [[COPY:%[0-9]+]]:fpr16 = COPY $h2
; CHECK-FP16-NEXT: [[COPY1:%[0-9]+]]:fpr16 = COPY $h1
; CHECK-FP16-NEXT: [[COPY2:%[0-9]+]]:fpr16 = COPY $h0
; CHECK-FP16-NEXT: [[FADDHrr:%[0-9]+]]:fpr16 = nnan ninf contract nofpexcept FADDHrr [[COPY2]], [[COPY1]], implicit $fpcr
; CHECK-FP16-NEXT: [[FADDHrr1:%[0-9]+]]:fpr16 = nnan ninf contract nofpexcept FADDHrr killed [[FADDHrr]], [[COPY]], implicit $fpcr
; CHECK-FP16-NEXT: $h0 = COPY [[FADDHrr1]]
; CHECK-FP16-NEXT: RET_ReallyLR implicit $h0
entry:
%add1 = fadd nnan ninf contract half %x, %y
%add2 = fadd nnan ninf contract half %add1, %z
ret half %add2
}

define half @ninf_fadd_sequence(half %x, half %y, half %z) {
; CHECK-CVT-LABEL: name: ninf_fadd_sequence
; CHECK-CVT: bb.0.entry:
; CHECK-CVT-NEXT: liveins: $h0, $h1, $h2
; CHECK-CVT-NEXT: {{ $}}
; CHECK-CVT-NEXT: [[COPY:%[0-9]+]]:fpr16 = COPY $h2
; CHECK-CVT-NEXT: [[COPY1:%[0-9]+]]:fpr16 = COPY $h1
; CHECK-CVT-NEXT: [[COPY2:%[0-9]+]]:fpr16 = COPY $h0
; CHECK-CVT-NEXT: [[FCVTSHr:%[0-9]+]]:fpr32 = ninf nofpexcept FCVTSHr [[COPY1]], implicit $fpcr
; CHECK-CVT-NEXT: [[FCVTSHr1:%[0-9]+]]:fpr32 = ninf nofpexcept FCVTSHr [[COPY2]], implicit $fpcr
; CHECK-CVT-NEXT: [[FADDSrr:%[0-9]+]]:fpr32 = ninf nofpexcept FADDSrr killed [[FCVTSHr1]], killed [[FCVTSHr]], implicit $fpcr
; CHECK-CVT-NEXT: [[FCVTHSr:%[0-9]+]]:fpr16 = ninf nofpexcept FCVTHSr killed [[FADDSrr]], implicit $fpcr
; CHECK-CVT-NEXT: [[FCVTSHr2:%[0-9]+]]:fpr32 = ninf nofpexcept FCVTSHr killed [[FCVTHSr]], implicit $fpcr
; CHECK-CVT-NEXT: [[FCVTSHr3:%[0-9]+]]:fpr32 = ninf nofpexcept FCVTSHr [[COPY]], implicit $fpcr
; CHECK-CVT-NEXT: [[FADDSrr1:%[0-9]+]]:fpr32 = ninf nofpexcept FADDSrr killed [[FCVTSHr2]], killed [[FCVTSHr3]], implicit $fpcr
; CHECK-CVT-NEXT: [[FCVTHSr1:%[0-9]+]]:fpr16 = ninf nofpexcept FCVTHSr killed [[FADDSrr1]], implicit $fpcr
; CHECK-CVT-NEXT: $h0 = COPY [[FCVTHSr1]]
; CHECK-CVT-NEXT: RET_ReallyLR implicit $h0
;
; CHECK-FP16-LABEL: name: ninf_fadd_sequence
; CHECK-FP16: bb.0.entry:
; CHECK-FP16-NEXT: liveins: $h0, $h1, $h2
; CHECK-FP16-NEXT: {{ $}}
; CHECK-FP16-NEXT: [[COPY:%[0-9]+]]:fpr16 = COPY $h2
; CHECK-FP16-NEXT: [[COPY1:%[0-9]+]]:fpr16 = COPY $h1
; CHECK-FP16-NEXT: [[COPY2:%[0-9]+]]:fpr16 = COPY $h0
; CHECK-FP16-NEXT: [[FADDHrr:%[0-9]+]]:fpr16 = ninf nofpexcept FADDHrr [[COPY2]], [[COPY1]], implicit $fpcr
; CHECK-FP16-NEXT: [[FADDHrr1:%[0-9]+]]:fpr16 = ninf nofpexcept FADDHrr killed [[FADDHrr]], [[COPY]], implicit $fpcr
; CHECK-FP16-NEXT: $h0 = COPY [[FADDHrr1]]
; CHECK-FP16-NEXT: RET_ReallyLR implicit $h0
entry:
%add1 = fadd ninf half %x, %y
%add2 = fadd ninf half %add1, %z
ret half %add2
}