Skip to content

[ValueTracking] Improve Bitcast handling to match SDAG #125935

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Jun 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 25 additions & 2 deletions llvm/lib/Analysis/ValueTracking.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1346,6 +1346,8 @@ static void computeKnownBitsFromOperator(const Operator *I,
isa<ScalableVectorType>(I->getType()))
break;

unsigned NumElts = DemandedElts.getBitWidth();
bool IsLE = Q.DL.isLittleEndian();
// Look through a cast from narrow vector elements to wider type.
// Examples: v4i32 -> v2i64, v3i8 -> v24
unsigned SubBitWidth = SrcVecTy->getScalarSizeInBits();
Expand All @@ -1364,7 +1366,6 @@ static void computeKnownBitsFromOperator(const Operator *I,
//
// The known bits of each sub-element are then inserted into place
// (dependent on endian) to form the full result of known bits.
unsigned NumElts = DemandedElts.getBitWidth();
unsigned SubScale = BitWidth / SubBitWidth;
APInt SubDemandedElts = APInt::getZero(NumElts * SubScale);
for (unsigned i = 0; i != NumElts; ++i) {
Expand All @@ -1376,10 +1377,32 @@ static void computeKnownBitsFromOperator(const Operator *I,
for (unsigned i = 0; i != SubScale; ++i) {
computeKnownBits(I->getOperand(0), SubDemandedElts.shl(i), KnownSrc, Q,
Depth + 1);
unsigned ShiftElt = Q.DL.isLittleEndian() ? i : SubScale - 1 - i;
unsigned ShiftElt = IsLE ? i : SubScale - 1 - i;
Known.insertBits(KnownSrc, ShiftElt * SubBitWidth);
}
}
// Look through a cast from wider vector elements to narrow type.
// Examples: v2i64 -> v4i32
if (SubBitWidth % BitWidth == 0) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess the missing case v2i32 -> i64 can be handled here (as a follow-up):

if (SrcTy->isIntOrPtrTy() &&
// TODO: For now, not handling conversions like:
// (bitcast i64 %x to <2 x i32>)
!I->getType()->isVectorTy()) {
computeKnownBits(I->getOperand(0), Known, Q, Depth + 1);
break;
}

unsigned SubScale = SubBitWidth / BitWidth;
KnownBits KnownSrc(SubBitWidth);
APInt SubDemandedElts =
APIntOps::ScaleBitMask(DemandedElts, NumElts / SubScale);
computeKnownBits(I->getOperand(0), SubDemandedElts, KnownSrc, Q,
Depth + 1);

Known.Zero.setAllBits();
Known.One.setAllBits();
for (unsigned i = 0; i != SubScale; ++i) {
Copy link
Contributor

@nikic nikic Jun 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this was supposed to go up to NumElts, not SubScale. If one of the first SubScale elements are non-demanded but later ones are demanded this will miscompile.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! I've opened #145223 to reland this.

if (DemandedElts[i]) {
unsigned Shifts = IsLE ? i : NumElts - 1 - i;
unsigned Offset = (Shifts % SubScale) * BitWidth;
Known = Known.intersectWith(KnownSrc.extractBits(BitWidth, Offset));
if (Known.isUnknown())
break;
}
}
}
break;
}
case Instruction::SExt: {
Expand Down
4 changes: 2 additions & 2 deletions llvm/test/Transforms/InstCombine/X86/x86-vector-shifts.ll
Original file line number Diff line number Diff line change
Expand Up @@ -3732,7 +3732,6 @@ define <4 x i64> @test_avx2_psrl_0() {
ret <4 x i64> %16
}

; FIXME: Failure to peek through bitcasts to ensure psllq shift amount is within bounds.
define <2 x i64> @PR125228(<2 x i64> %v, <2 x i64> %s) {
; CHECK-LABEL: @PR125228(
; CHECK-NEXT: [[MASK:%.*]] = and <2 x i64> [[S:%.*]], splat (i64 63)
Expand All @@ -3741,7 +3740,8 @@ define <2 x i64> @PR125228(<2 x i64> %v, <2 x i64> %s) {
; CHECK-NEXT: [[CAST:%.*]] = bitcast <2 x i64> [[MASK]] to <16 x i8>
; CHECK-NEXT: [[PSRLDQ:%.*]] = shufflevector <16 x i8> [[CAST]], <16 x i8> poison, <16 x i32> <i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison, i32 poison>
; CHECK-NEXT: [[CAST3:%.*]] = bitcast <16 x i8> [[PSRLDQ]] to <2 x i64>
; CHECK-NEXT: [[SLL1:%.*]] = call <2 x i64> @llvm.x86.sse2.psll.q(<2 x i64> [[V]], <2 x i64> [[CAST3]])
; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <2 x i64> [[CAST3]], <2 x i64> poison, <2 x i32> zeroinitializer
; CHECK-NEXT: [[SLL1:%.*]] = shl <2 x i64> [[V]], [[TMP2]]
; CHECK-NEXT: [[SHUFP_UNCASTED:%.*]] = shufflevector <2 x i64> [[SLL0]], <2 x i64> [[SLL1]], <2 x i32> <i32 0, i32 3>
; CHECK-NEXT: ret <2 x i64> [[SHUFP_UNCASTED]]
;
Expand Down
21 changes: 7 additions & 14 deletions llvm/test/Transforms/InstCombine/bitcast-known-bits.ll
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@ define <16 x i8> @knownbits_bitcast_masked_shift(<16 x i8> %arg1, <16 x i8> %arg
; CHECK-NEXT: [[BITCAST4:%.*]] = bitcast <16 x i8> [[OR]] to <8 x i16>
; CHECK-NEXT: [[SHL5:%.*]] = shl nuw <8 x i16> [[BITCAST4]], splat (i16 2)
; CHECK-NEXT: [[BITCAST6:%.*]] = bitcast <8 x i16> [[SHL5]] to <16 x i8>
; CHECK-NEXT: [[AND7:%.*]] = and <16 x i8> [[BITCAST6]], splat (i8 -52)
; CHECK-NEXT: ret <16 x i8> [[AND7]]
; CHECK-NEXT: ret <16 x i8> [[BITCAST6]]
;
%and = and <16 x i8> %arg1, splat (i8 3)
%and3 = and <16 x i8> %arg2, splat (i8 48)
Expand All @@ -33,8 +32,7 @@ define <16 x i8> @knownbits_shuffle_masked_nibble_shift(<16 x i8> %arg) {
; CHECK-NEXT: [[BITCAST1:%.*]] = bitcast <16 x i8> [[SHUFFLEVECTOR]] to <8 x i16>
; CHECK-NEXT: [[SHL:%.*]] = shl nuw <8 x i16> [[BITCAST1]], splat (i16 4)
; CHECK-NEXT: [[BITCAST2:%.*]] = bitcast <8 x i16> [[SHL]] to <16 x i8>
; CHECK-NEXT: [[AND3:%.*]] = and <16 x i8> [[BITCAST2]], splat (i8 -16)
; CHECK-NEXT: ret <16 x i8> [[AND3]]
; CHECK-NEXT: ret <16 x i8> [[BITCAST2]]
;
%and = and <16 x i8> %arg, splat (i8 15)
%shufflevector = shufflevector <16 x i8> %and, <16 x i8> poison, <16 x i32> <i32 1, i32 0, i32 3, i32 2, i32 5, i32 4, i32 7, i32 6, i32 9, i32 8, i32 11, i32 10, i32 13, i32 12, i32 15, i32 14>
Expand All @@ -53,8 +51,7 @@ define <16 x i8> @knownbits_reverse_shuffle_masked_shift(<16 x i8> %arg) {
; CHECK-NEXT: [[BITCAST1:%.*]] = bitcast <16 x i8> [[SHUFFLEVECTOR]] to <8 x i16>
; CHECK-NEXT: [[SHL:%.*]] = shl nuw <8 x i16> [[BITCAST1]], splat (i16 4)
; CHECK-NEXT: [[BITCAST2:%.*]] = bitcast <8 x i16> [[SHL]] to <16 x i8>
; CHECK-NEXT: [[AND3:%.*]] = and <16 x i8> [[BITCAST2]], splat (i8 -16)
; CHECK-NEXT: ret <16 x i8> [[AND3]]
; CHECK-NEXT: ret <16 x i8> [[BITCAST2]]
;
%and = and <16 x i8> %arg, splat (i8 15)
%shufflevector = shufflevector <16 x i8> %and, <16 x i8> poison, <16 x i32> <i32 3, i32 2, i32 1, i32 0, i32 7, i32 6, i32 5, i32 4, i32 11, i32 10, i32 9, i32 8, i32 15, i32 14, i32 13, i32 12>
Expand All @@ -70,8 +67,7 @@ define <16 x i8> @knownbits_extract_bit(<8 x i16> %arg) {
; CHECK-SAME: <8 x i16> [[ARG:%.*]]) {
; CHECK-NEXT: [[LSHR:%.*]] = lshr <8 x i16> [[ARG]], splat (i16 15)
; CHECK-NEXT: [[BITCAST1:%.*]] = bitcast <8 x i16> [[LSHR]] to <16 x i8>
; CHECK-NEXT: [[AND:%.*]] = and <16 x i8> [[BITCAST1]], splat (i8 1)
; CHECK-NEXT: ret <16 x i8> [[AND]]
; CHECK-NEXT: ret <16 x i8> [[BITCAST1]]
;
%lshr = lshr <8 x i16> %arg, splat (i16 15)
%bitcast1 = bitcast <8 x i16> %lshr to <16 x i8>
Expand All @@ -88,7 +84,8 @@ define { i32, i1 } @knownbits_popcount_add_with_overflow(<2 x i64> %arg1, <2 x i
; CHECK-NEXT: [[CALL9:%.*]] = tail call range(i64 0, 65) <2 x i64> @llvm.ctpop.v2i64(<2 x i64> [[ARG2]])
; CHECK-NEXT: [[BITCAST10:%.*]] = bitcast <2 x i64> [[CALL9]] to <4 x i32>
; CHECK-NEXT: [[EXTRACTELEMENT11:%.*]] = extractelement <4 x i32> [[BITCAST10]], i64 0
; CHECK-NEXT: [[TMP1:%.*]] = tail call { i32, i1 } @llvm.uadd.with.overflow.i32(i32 [[EXTRACTELEMENT]], i32 [[EXTRACTELEMENT11]])
; CHECK-NEXT: [[CALL12:%.*]] = add nuw nsw i32 [[EXTRACTELEMENT]], [[EXTRACTELEMENT11]]
; CHECK-NEXT: [[TMP1:%.*]] = insertvalue { i32, i1 } { i32 poison, i1 false }, i32 [[CALL12]], 0
; CHECK-NEXT: ret { i32, i1 } [[TMP1]]
;
%call = tail call <2 x i64> @llvm.ctpop.v2i64(<2 x i64> %arg1)
Expand All @@ -110,11 +107,7 @@ define <16 x i8> @knownbits_shuffle_add_shift_v32i8(<16 x i8> %arg1, <8 x i16> %
; CHECK-NEXT: [[BITCAST11:%.*]] = bitcast <8 x i16> [[SHL10]] to <16 x i8>
; CHECK-NEXT: [[ADD12:%.*]] = add <16 x i8> [[BITCAST11]], [[BITCAST7]]
; CHECK-NEXT: [[ADD14:%.*]] = add <16 x i8> [[ADD12]], [[ARG1]]
; CHECK-NEXT: [[BITCAST14:%.*]] = bitcast <16 x i8> [[ADD12]] to <8 x i16>
; CHECK-NEXT: [[SHL15:%.*]] = shl <8 x i16> [[BITCAST14]], splat (i16 8)
; CHECK-NEXT: [[BITCAST16:%.*]] = bitcast <8 x i16> [[SHL15]] to <16 x i8>
; CHECK-NEXT: [[ADD13:%.*]] = add <16 x i8> [[ADD14]], [[BITCAST16]]
; CHECK-NEXT: ret <16 x i8> [[ADD13]]
; CHECK-NEXT: ret <16 x i8> [[ADD14]]
;
%shl6 = shl <8 x i16> %arg2, splat (i16 8)
%bitcast7 = bitcast <8 x i16> %shl6 to <16 x i8>
Expand Down