-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[InstCombine] Eliminate fptrunc/fpext if fast math flags allow it #115027
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-llvm-transforms Author: John Brawn (john-brawn-arm) ChangesWhen expressions of a floating-point type are evaluated at a higher precision (e.g. _Float16 being evaluated as float) this results in a fptrunc then fpext between each operation. With the appropriate fast math flags (nnan ninf contract) we can eliminate these cast instructions. As cast instructions don't have fast math flags it's the source and destination of the casts whose flags are checked. Full diff: https://github.com/llvm/llvm-project/pull/115027.diff 2 Files Affected:
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
index 6c2554ea73b7f8..4c2bbcfec5cf82 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -1949,6 +1949,32 @@ Instruction *InstCombinerImpl::visitFPExt(CastInst &FPExt) {
return CastInst::Create(FPCast->getOpcode(), FPCast->getOperand(0), Ty);
}
+ // fpext (fptrunc(x)) -> x, if the fast math flags allow it
+ Instruction *SrcInstr;
+ if (match(Src, m_FPTrunc(m_Instruction(SrcInstr)))) {
+ // Whether this transformation is possible depends both on the flags of the
+ // value that is truncated, and the flags on the instructions that use the
+ // fpext.
+ FastMathFlags SrcFlags = SrcInstr->getFastMathFlags();
+ FastMathFlags DstFlags = FastMathFlags::getFast();
+ for (User *U : FPExt.users())
+ if (auto *UInstr = dyn_cast<Instruction>(U))
+ DstFlags &= UInstr->getFastMathFlags();
+ // Trunc can introduce inf and change the encoding of a nan, so the
+ // destination must have the nnan and ninf flags to indicate that we don't
+ // need to care about that. We are also removing a rounding step, and that
+ // requires both the source and destination to allow contraction.
+ if (DstFlags.noNaNs() && DstFlags.noInfs() && SrcFlags.allowContract() &&
+ DstFlags.allowContract()) {
+ // We do need a single cast if the source and destination types don't
+ // match.
+ if (SrcInstr->getType() != Ty)
+ return CastInst::CreateFPCast(SrcInstr, Ty);
+ else
+ return replaceInstUsesWith(FPExt, SrcInstr);
+ }
+ }
+
return commonCastTransforms(FPExt);
}
diff --git a/llvm/test/Transforms/InstCombine/fpextend.ll b/llvm/test/Transforms/InstCombine/fpextend.ll
index c9adbe10d8db44..d3ac511b10996d 100644
--- a/llvm/test/Transforms/InstCombine/fpextend.ll
+++ b/llvm/test/Transforms/InstCombine/fpextend.ll
@@ -448,3 +448,95 @@ define bfloat @bf16_frem(bfloat %x) {
%t3 = fptrunc float %t2 to bfloat
ret bfloat %t3
}
+
+define double @fptrunc_fpextend_nofast(double %x, double %y, double %z) {
+; CHECK-LABEL: @fptrunc_fpextend_nofast(
+; CHECK-NEXT: [[ADD1:%.*]] = fadd double [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT: [[TRUNC:%.*]] = fptrunc double [[ADD1]] to float
+; CHECK-NEXT: [[EXT:%.*]] = fpext float [[TRUNC]] to double
+; CHECK-NEXT: [[ADD2:%.*]] = fadd double [[Z:%.*]], [[EXT]]
+; CHECK-NEXT: ret double [[ADD2]]
+;
+ %add1 = fadd double %x, %y
+ %trunc = fptrunc double %add1 to float
+ %ext = fpext float %trunc to double
+ %add2 = fadd double %ext, %z
+ ret double %add2
+}
+
+define double @fptrunc_fpextend_fast(double %x, double %y, double %z) {
+; CHECK-LABEL: @fptrunc_fpextend_fast(
+; CHECK-NEXT: [[ADD1:%.*]] = fadd contract double [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT: [[ADD2:%.*]] = fadd nnan ninf contract double [[ADD1]], [[Z:%.*]]
+; CHECK-NEXT: ret double [[ADD2]]
+;
+ %add1 = fadd contract double %x, %y
+ %trunc = fptrunc double %add1 to float
+ %ext = fpext float %trunc to double
+ %add2 = fadd nnan ninf contract double %ext, %z
+ ret double %add2
+}
+
+define float @fptrunc_fpextend_result_smaller(double %x, double %y, float %z) {
+; CHECK-LABEL: @fptrunc_fpextend_result_smaller(
+; CHECK-NEXT: [[ADD1:%.*]] = fadd contract double [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT: [[EXT:%.*]] = fptrunc double [[ADD1]] to float
+; CHECK-NEXT: [[ADD2:%.*]] = fadd nnan ninf contract float [[Z:%.*]], [[EXT]]
+; CHECK-NEXT: ret float [[ADD2]]
+;
+ %add1 = fadd contract double %x, %y
+ %trunc = fptrunc double %add1 to half
+ %ext = fpext half %trunc to float
+ %add2 = fadd nnan ninf contract float %ext, %z
+ ret float %add2
+}
+
+define double @fptrunc_fpextend_result_larger(float %x, float %y, double %z) {
+; CHECK-LABEL: @fptrunc_fpextend_result_larger(
+; CHECK-NEXT: [[ADD1:%.*]] = fadd contract float [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT: [[EXT:%.*]] = fpext float [[ADD1]] to double
+; CHECK-NEXT: [[ADD2:%.*]] = fadd nnan ninf contract double [[Z:%.*]], [[EXT]]
+; CHECK-NEXT: ret double [[ADD2]]
+;
+ %add1 = fadd contract float %x, %y
+ %trunc = fptrunc float %add1 to half
+ %ext = fpext half %trunc to double
+ %add2 = fadd nnan ninf contract double %ext, %z
+ ret double %add2
+}
+
+define double @fptrunc_fpextend_multiple_use(double %x, double %y, double %a, double %b) {
+; CHECK-LABEL: @fptrunc_fpextend_multiple_use(
+; CHECK-NEXT: [[ADD1:%.*]] = fadd contract double [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT: [[ADD2:%.*]] = fadd nnan ninf contract double [[ADD1]], [[A:%.*]]
+; CHECK-NEXT: [[ADD3:%.*]] = fadd nnan ninf contract double [[ADD1]], [[B:%.*]]
+; CHECK-NEXT: [[MUL:%.*]] = fmul double [[ADD2]], [[ADD3]]
+; CHECK-NEXT: ret double [[MUL]]
+;
+ %add1 = fadd contract double %x, %y
+ %trunc = fptrunc double %add1 to float
+ %ext = fpext float %trunc to double
+ %add2 = fadd nnan ninf contract double %ext, %a
+ %add3 = fadd nnan ninf contract double %ext, %b
+ %mul = fmul double %add2, %add3
+ ret double %mul
+}
+
+define double @fptrunc_fpextend_multiple_use_flag_mismatch(double %x, double %y, double %a, double %b) {
+; CHECK-LABEL: @fptrunc_fpextend_multiple_use_flag_mismatch(
+; CHECK-NEXT: [[ADD1:%.*]] = fadd contract double [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT: [[TRUNC:%.*]] = fptrunc double [[ADD1]] to float
+; CHECK-NEXT: [[EXT:%.*]] = fpext float [[TRUNC]] to double
+; CHECK-NEXT: [[ADD2:%.*]] = fadd nnan ninf contract double [[A:%.*]], [[EXT]]
+; CHECK-NEXT: [[ADD3:%.*]] = fadd nnan ninf double [[B:%.*]], [[EXT]]
+; CHECK-NEXT: [[MUL:%.*]] = fmul double [[ADD2]], [[ADD3]]
+; CHECK-NEXT: ret double [[MUL]]
+;
+ %add1 = fadd contract double %x, %y
+ %trunc = fptrunc double %add1 to float
+ %ext = fpext float %trunc to double
+ %add2 = fadd nnan ninf contract double %ext, %a
+ %add3 = fadd nnan ninf double %ext, %b
+ %mul = fmul double %add2, %add3
+ ret double %mul
+}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd much rather we bite the bullet and add the fast-math flags onto the FP conversion operations, rather than trying to play game checking the users to figure out what their fast-math flags imply the conversion operation's flags should be.
Sounds sensible. I'll get to work on that then adjust this patch to check the trunc/ext. |
For reference, there's an RFC for that here: https://discourse.llvm.org/t/rfc-fmf-on-more-instructions/82978 |
Allow fast math flags in fptrunc and fpext: #115894 |
When expressions of a floating-point type are evaluated at a higher precision (e.g. _Float16 being evaluated as float) this results in a fptrunc then fpext between each operation. With the appropriate fast math flags (nnan ninf contract) we can eliminate these cast instructions.
297b144
to
78e1634
Compare
This has now been updated to use the fast math flags on the fpext and fptrunc. |
Ping. |
2 similar comments
Ping. |
Ping. |
// destination must have the nnan and ninf flags to indicate that we don't | ||
// need to care about that. We are also removing a rounding step, and that | ||
// requires both the source and destination to allow contraction. | ||
if (DstFlags.noNaNs() && DstFlags.noInfs() && SrcFlags.allowContract() && |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think contract
is suitable here. Previously, contract
means the optimizer is allowed to fold fmul+fadd into a fma. But the transformation here is too aggressive since clang may be the first compiler to fold this pattern: https://godbolt.org/z/bYfz334Pr
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The LLVM LangRef is a bit unhelpful as to what the contract flag allows, as it just says "Allow floating-point contraction" without defining what it means by contraction. The definition of contraction in the C23 standard (in section 6.5.1) is clearer:
A floating expression may be contracted, that is, evaluated as though it were a single operation, thereby omitting rounding errors implied by the source code and the expression evaluation method. The FP_CONTRACT pragma in <math.h> provides a way to disallow contracted expressions. Otherwise, whether and how expressions are contracted is implementation-defined.
with a footnote saying that in a contracted expression the intermediate operations are as if evaluated to infinite range and precision. If the expression (double)(float)x
(where x is a double) is contracted then the double-to-float and float-to-double operations are evaluated as if to infinite range and precision, meaning the result is just x.
So long as what clang is doing is correct I don't think it matters that it's the first compiler to do it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm also not comfortable with using contract
in this way. I can see your point that this meets the C23 definition for FP_CONTRACT, but I don't think it's what users would expect. My feeling is that users probably associate the contract flag with FMA. The C23 standard extends the definition, presumably in anticipation of similar hardware instructions that are able to fuse other combinations of operations, but the transformation proposed in this PR is most likely circumventing an explicit user instruction intended to truncate a value.
In order to get the 'contract' flag in isolation from clang, for example, you'd need to use the -ffp-contract=fast
option. The documentation for this option says, "Specify when the compiler is permitted to form fused floating-point operations, such as fused multiply-add (FMA)." There's nothing there that indicates it will allow the compiler to disregard changes in precision implied (or explicitly requested) by the source code. If I were using this option, my intention would be to enable FMA formation across expressions. I would not be happy if the compiler changed my results in other ways.
This gets us into a problematic situation. I think we'd all agree that the full set of fast-math flags should enable this transformation. However, there isn't a flag for general value-changing optimizations, so if we don't allow this with contract
we're going to have to either look at the "unsafe-fp-math" function attribute or do something more or less arbitrary in requiring some other combination of flags.
@jcranmer-intel What's your opinion on this?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've got an RFC on the contract
semantics almost ready to go (let's see if I'm successful in getting it published this week). The working definition I have, that generalizes it from just FMA formation, is essentially "new expression would evaluate the same result if both old and newer were evaluated at infinite range/precision, and the new expression only has at most one instruction that causes rounding" (a few other conditions, but that's the main one).
The main goal is to generalize to cover cases like fms
instructions (a * b - c
), but when I was reviewing a lot of the documentation in C for #pragma STDC FP_CONTRACT
, I've also found that fpext; libm; fptrunc
is another case that seems to be contemplated for FP_CONTRACT
.
By this definition of contract
, then this optimization would be correct. But I'll also admit to having a vague level of unease about this being in contract
--round-tripping via a smaller value tends to feel intentional, so removing it should have some stronger level of intent. Of all of the FMFs, contract
is also the flag that is probably the most likely for users to default to enable, so it helps to be a little less aggressive in the optimizations here.
Stepping back a touch: I've increasingly come to the opinion that FMA formation via "combine add/mul into fma if some flag is present" isn't the best way to tackle optimization. It's better to instead have a "fast_fma" primitive, which does fma if there's a hardware instruction for it, and add/mul if there isn't. This is ultimately something that requires source changes, even language changes, so we may be screwed out of a path for this for C/C++, though.
If people want it, I can bring up this topic to the CFP study group.
After looking into this some more I don't think instcombine is the right place to fix this. Different targets handle fp16 differently in clang, so e.g. with aarch64 we have "fadd half" generated by clang directly so there's no fptrunc/fpext to eliminate. And instcombine also has a transform to convert "fpext, fadd float, fptrunc" info "fadd half", which means this instcombine transformation doesn't even work on the original test case I was looking at. Instead I've created #131345 to do this in dagcombine. |
When expressions of a floating-point type are evaluated at a higher precision (e.g. _Float16 being evaluated as float) this results in a fptrunc then fpext between each operation. With the appropriate fast math flags (nnan ninf contract) we can eliminate these cast instructions.