@@ -11803,13 +11803,13 @@ class BoUpSLP::ShuffleInstructionBuilder final : public BaseShuffleAnalysis {
11803
11803
void add(const TreeEntry &E1, const TreeEntry &E2, ArrayRef<int> Mask) {
11804
11804
Value *V1 = E1.VectorizedValue;
11805
11805
if (V1->getType()->isIntOrIntVectorTy())
11806
- V1 = castToScalarTyElem(V1, all_of (E1.Scalars, [&](Value *V) {
11806
+ V1 = castToScalarTyElem(V1, any_of (E1.Scalars, [&](Value *V) {
11807
11807
return !isKnownNonNegative(
11808
11808
V, SimplifyQuery(*R.DL));
11809
11809
}));
11810
11810
Value *V2 = E2.VectorizedValue;
11811
11811
if (V2->getType()->isIntOrIntVectorTy())
11812
- V2 = castToScalarTyElem(V2, all_of (E2.Scalars, [&](Value *V) {
11812
+ V2 = castToScalarTyElem(V2, any_of (E2.Scalars, [&](Value *V) {
11813
11813
return !isKnownNonNegative(
11814
11814
V, SimplifyQuery(*R.DL));
11815
11815
}));
@@ -11820,7 +11820,7 @@ class BoUpSLP::ShuffleInstructionBuilder final : public BaseShuffleAnalysis {
11820
11820
void add(const TreeEntry &E1, ArrayRef<int> Mask) {
11821
11821
Value *V1 = E1.VectorizedValue;
11822
11822
if (V1->getType()->isIntOrIntVectorTy())
11823
- V1 = castToScalarTyElem(V1, all_of (E1.Scalars, [&](Value *V) {
11823
+ V1 = castToScalarTyElem(V1, any_of (E1.Scalars, [&](Value *V) {
11824
11824
return !isKnownNonNegative(
11825
11825
V, SimplifyQuery(*R.DL));
11826
11826
}));
@@ -14900,24 +14900,30 @@ bool BoUpSLP::collectValuesToDemote(
14900
14900
// If the value is not a vectorized instruction in the expression and not used
14901
14901
// by the insertelement instruction and not used in multiple vector nodes, it
14902
14902
// cannot be demoted.
14903
+ bool IsSignedNode = any_of(E.Scalars, [&](Value *R) {
14904
+ return !isKnownNonNegative(R, SimplifyQuery(*DL));
14905
+ });
14903
14906
auto IsPotentiallyTruncated = [&](Value *V, unsigned &BitWidth) -> bool {
14904
14907
if (MultiNodeScalars.contains(V))
14905
14908
return false;
14906
- if (OrigBitWidth > BitWidth) {
14909
+ // For last shuffle of sext/zext with many uses, need to check the extra bit
14910
+ // for unsigned values, otherwise may have incorrect casting for reused
14911
+ // scalars.
14912
+ bool IsSignedVal = !isKnownNonNegative(V, SimplifyQuery(*DL));
14913
+ if ((!IsSignedNode || IsSignedVal) && OrigBitWidth > BitWidth) {
14907
14914
APInt Mask = APInt::getBitsSetFrom(OrigBitWidth, BitWidth);
14908
14915
if (MaskedValueIsZero(V, Mask, SimplifyQuery(*DL)))
14909
14916
return true;
14910
14917
}
14911
- auto NumSignBits = ComputeNumSignBits(V, *DL, 0, AC, nullptr, DT);
14918
+ unsigned NumSignBits = ComputeNumSignBits(V, *DL, 0, AC, nullptr, DT);
14912
14919
unsigned BitWidth1 = OrigBitWidth - NumSignBits;
14913
- bool IsSigned = !isKnownNonNegative(V, SimplifyQuery(*DL));
14914
- if (IsSigned)
14920
+ if (IsSignedNode)
14915
14921
++BitWidth1;
14916
14922
if (auto *I = dyn_cast<Instruction>(V)) {
14917
14923
APInt Mask = DB->getDemandedBits(I);
14918
14924
unsigned BitWidth2 =
14919
14925
std::max<unsigned>(1, Mask.getBitWidth() - Mask.countl_zero());
14920
- while (!IsSigned && BitWidth2 < OrigBitWidth) {
14926
+ while (!IsSignedNode && BitWidth2 < OrigBitWidth) {
14921
14927
APInt Mask = APInt::getBitsSetFrom(OrigBitWidth, BitWidth2 - 1);
14922
14928
if (MaskedValueIsZero(V, Mask, SimplifyQuery(*DL)))
14923
14929
break;
0 commit comments