-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[InstCombine] Handle scalable splats of constants in getMinimumFPType #132960
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[InstCombine] Handle scalable splats of constants in getMinimumFPType #132960
Conversation
We previously handled ConstantExpr scalable splats in 5d92979, but only fpexts. ConstantExpr fpexts have since been removed, and simultaneously we didn't handle splats of constants that weren't extended. This updates it to remove the fpext check and instead see if we can shrink the result of getSplatValue. Note that the test case doesn't get completely folded away due to llvm#132922
@llvm/pr-subscribers-llvm-transforms Author: Luke Lau (lukel97) ChangesWe previously handled ConstantExpr scalable splats in 5d92979, but only fpexts. ConstantExpr fpexts have since been removed, and simultaneously we didn't handle splats of constants that weren't extended. This updates it to remove the fpext check and instead see if we can shrink the result of getSplatValue. Note that the test case doesn't get completely folded away due to #132922 Full diff: https://github.com/llvm/llvm-project/pull/132960.diff 2 Files Affected:
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
index 4ec1af394464b..3faaf1e52db26 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -1685,11 +1685,12 @@ static Type *getMinimumFPType(Value *V, bool PreferBFloat) {
return T;
// We can only correctly find a minimum type for a scalable vector when it is
- // a splat. For splats of constant values the fpext is wrapped up as a
- // ConstantExpr.
- if (auto *FPCExt = dyn_cast<ConstantExpr>(V))
- if (FPCExt->getOpcode() == Instruction::FPExt)
- return FPCExt->getOperand(0)->getType();
+ // a splat.
+ if (auto *FPCE = dyn_cast<ConstantExpr>(V))
+ if (isa<ScalableVectorType>(V->getType()))
+ if (auto *Splat = dyn_cast<ConstantFP>(FPCE->getSplatValue()))
+ if (Type *T = shrinkFPConstant(Splat, PreferBFloat))
+ return T;
// Try to shrink a vector of FP constants. This returns nullptr on scalable
// vectors
diff --git a/llvm/test/Transforms/InstCombine/scalable-const-fp-splat.ll b/llvm/test/Transforms/InstCombine/scalable-const-fp-splat.ll
index 731b079881f08..595486361d16e 100644
--- a/llvm/test/Transforms/InstCombine/scalable-const-fp-splat.ll
+++ b/llvm/test/Transforms/InstCombine/scalable-const-fp-splat.ll
@@ -13,3 +13,16 @@ define <vscale x 2 x float> @shrink_splat_scalable_extend(<vscale x 2 x float> %
%5 = fptrunc <vscale x 2 x double> %4 to <vscale x 2 x float>
ret <vscale x 2 x float> %5
}
+
+define <vscale x 2 x float> @shrink_splat_scalable_extend_rhs_constexpr(<vscale x 2 x float> %a) {
+; CHECK-LABEL: define <vscale x 2 x float> @shrink_splat_scalable_extend_rhs_constexpr(
+; CHECK-SAME: <vscale x 2 x float> [[A:%.*]]) {
+; CHECK-NEXT: [[TMP1:%.*]] = fptrunc <vscale x 2 x double> splat (double -1.000000e+00) to <vscale x 2 x float>
+; CHECK-NEXT: [[TMP3:%.*]] = fadd <vscale x 2 x float> [[A]], [[TMP1]]
+; CHECK-NEXT: ret <vscale x 2 x float> [[TMP3]]
+;
+ %2 = fpext <vscale x 2 x float> %a to <vscale x 2 x double>
+ %4 = fadd <vscale x 2 x double> %2, splat (double -1.000000e+00)
+ %5 = fptrunc <vscale x 2 x double> %4 to <vscale x 2 x float>
+ ret <vscale x 2 x float> %5
+}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM.
// a splat. | ||
if (auto *FPCE = dyn_cast<ConstantExpr>(V)) | ||
if (isa<ScalableVectorType>(V->getType())) | ||
if (auto *Splat = dyn_cast<ConstantFP>(FPCE->getSplatValue())) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In InstCombinerImpl::visitCallInst
there is a similar example of using splats:
// Handle mul by one:
if (Constant *CV1 = dyn_cast<Constant>(Arg1))
if (ConstantInt *Splat =
dyn_cast_or_null<ConstantInt>(CV1->getSplatValue()))
and I wondered if it was worth doing the same here. Perhaps I'm wrong, but I thought splats aren't restricted to scalable vectors I think and would apply to fixed-width too? For example,
if (auto *CFP = dyn_cast<Constant>(V))
if (auto *Splat = dyn_cast_or_null<ConstantFP>(CFP->getSplatValue()))
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed-width vectors are handled by shrinkFPConstantVector
below.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, shrinkFPConstantVector will handle splatted and non-splatted fixed vectors too.
But I think it should be NFC if we also handle fixed splats here, I guess it would act as a sort of fast path to prevent having to check shrinkFPConstant on each element? I've added it to this PR but let me know if you want to leave it out to keep it closer to the original
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for this. I personally think it looks neater to handle all splats in one place.
@@ -13,3 +13,16 @@ define <vscale x 2 x float> @shrink_splat_scalable_extend(<vscale x 2 x float> % | |||
%5 = fptrunc <vscale x 2 x double> %4 to <vscale x 2 x float> | |||
ret <vscale x 2 x float> %5 | |||
} | |||
|
|||
define <vscale x 2 x float> @shrink_splat_scalable_extend_rhs_constexpr(<vscale x 2 x float> %a) { | |||
; CHECK-LABEL: define <vscale x 2 x float> @shrink_splat_scalable_extend_rhs_constexpr( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would be good to have a fixed-width variant of this too, if there isn't one already?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added to fpextend.ll in d26e672
…struction Stacked on llvm#132960 to prevent a regression Previously only fixed vector splats were handled. This adds supports for scalable vectors too by allowing ConstantExpr splats. We need to add the extra V->getType()->isVectorTy() check because a ConstantExpr might be a scalar to vector bitcast. I believe this will also allow casts of fixed vector ConstantExprs to be folded but I couldn't come up with a test case for this, the ConstantExprs seem to be folded away before reaching InstCombine. Fixes llvm#132922
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
We previously handled ConstantExpr scalable splats in 5d92979, but only fpexts.
ConstantExpr fpexts have since been removed, and simultaneously we didn't handle splats of constants that weren't extended.
This updates it to remove the fpext check and instead see if we can shrink the result of getSplatValue.
Note that the test case doesn't get completely folded away due to #132922