
[SelectionDAG] Let ComputeKnownSignBits handle (shl (ext X), C) #97695

Merged 2 commits on Jul 5, 2024
29 changes: 25 additions & 4 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -4615,12 +4615,33 @@ unsigned SelectionDAG::ComputeNumSignBits(SDValue Op, const APInt &DemandedElts,
       Tmp = std::min<uint64_t>(Tmp + *ShAmt, VTBits);
     return Tmp;
   case ISD::SHL:
-    if (std::optional<uint64_t> ShAmt =
-            getValidMaximumShiftAmount(Op, DemandedElts, Depth + 1)) {
+    if (std::optional<ConstantRange> ShAmtRange =
+            getValidShiftAmountRange(Op, DemandedElts, Depth + 1)) {
+      uint64_t MaxShAmt = ShAmtRange->getUnsignedMax().getZExtValue();
+      uint64_t MinShAmt = ShAmtRange->getUnsignedMin().getZExtValue();
+      // Try to look through ZERO/SIGN/ANY_EXTEND. If all extended bits are
+      // shifted out, then we can compute the number of sign bits for the
+      // operand being extended. A future improvement could be to pass along the
+      // "shifted left by" information in the recursive calls to
+      // ComputeKnownSignBits. Allowing us to handle this more generically.
+      if (ISD::isExtOpcode(Op.getOperand(0).getOpcode())) {
+        SDValue Ext = Op.getOperand(0);
+        EVT ExtVT = Ext.getValueType();
+        SDValue Extendee = Ext.getOperand(0);
+        EVT ExtendeeVT = Extendee.getValueType();
+        uint64_t SizeDifference =
+            ExtVT.getScalarSizeInBits() - ExtendeeVT.getScalarSizeInBits();
+        if (SizeDifference <= MinShAmt) {
(Inline review comment from a Collaborator on the line above: "Add an explanation comment (similar to the patch summary).")
+          Tmp = SizeDifference +
+                ComputeNumSignBits(Extendee, DemandedElts, Depth + 1);
+          if (MaxShAmt < Tmp)
+            return Tmp - MaxShAmt;
+        }
+      }
       // shl destroys sign bits, ensure it doesn't shift out all sign bits.
       Tmp = ComputeNumSignBits(Op.getOperand(0), DemandedElts, Depth + 1);
-      if (*ShAmt < Tmp)
-        return Tmp - *ShAmt;
+      if (MaxShAmt < Tmp)
+        return Tmp - MaxShAmt;
     }
     break;
   case ISD::AND:
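To make the arithmetic concrete, here is a minimal standalone sketch (plain C++, not the LLVM API; the helper name numSignBitsForShlOfExt and its integer-only interface are illustrative assumptions) reproducing the computation for the first test case below, where (ashr i8 %x, 5) has 6 known sign bits, is zero-extended from i8 to i16, and is then shifted left by 10:

#include <cstdint>
#include <cstdio>

// Illustrative stand-in for the new SHL handling, using plain integers
// instead of SDValues. NumSignBitsExtendee models the result of the recursive
// ComputeNumSignBits(Extendee, ...) call; the other names mirror the patch.
static uint64_t numSignBitsForShlOfExt(uint64_t ExtVTBits, uint64_t ExtendeeVTBits,
                                       uint64_t NumSignBitsExtendee,
                                       uint64_t MinShAmt, uint64_t MaxShAmt) {
  // Bits added by the ZERO/SIGN/ANY_EXTEND.
  uint64_t SizeDifference = ExtVTBits - ExtendeeVTBits;
  // If even the smallest possible shift pushes all extended bits out, the
  // result depends only on the extendee: count its sign bits plus the
  // extension width, minus whatever the largest possible shift destroys.
  if (SizeDifference <= MinShAmt) {
    uint64_t Tmp = SizeDifference + NumSignBitsExtendee;
    if (MaxShAmt < Tmp)
      return Tmp - MaxShAmt;
  }
  return 1; // Conservative fallback for the sketch: one sign bit is always known.
}

int main() {
  // computeNumSignBits_shl_zext_1: i8 -> i16 zext, constant shift amount 10.
  uint64_t NSB = numSignBitsForShlOfExt(/*ExtVTBits=*/16, /*ExtendeeVTBits=*/8,
                                        /*NumSignBitsExtendee=*/6,
                                        /*MinShAmt=*/10, /*MaxShAmt=*/10);
  printf("known sign bits: %llu\n", (unsigned long long)NSB); // prints 4
  return 0;
}

With 4 known sign bits, (sshlsat %nsb4, c) can be folded to a plain shl for c = 1, 2 and 3, which is exactly what the first test below checks.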
158 changes: 158 additions & 0 deletions llvm/test/CodeGen/X86/known-signbits-shl.ll
@@ -0,0 +1,158 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
; RUN: llc < %s -mtriple=x86_64-linux | FileCheck %s --check-prefix=X64

; Verify that we can look through a ZERO_EXTEND/ANY_EXTEND when doing
; ComputeNumSignBits for SHL.
; We use the (sshlsat x, c) -> (shl x, c) fold as verification.
; That fold should happen if c is less than the number of sign bits in x.

define void @computeNumSignBits_shl_zext_1(i8 %x, ptr %p) nounwind {
; X64-LABEL: computeNumSignBits_shl_zext_1:
; X64: # %bb.0:
; X64-NEXT: sarb $5, %dil
; X64-NEXT: movzbl %dil, %eax
; X64-NEXT: movl %eax, %ecx
; X64-NEXT: shll $11, %ecx
; X64-NEXT: movw %cx, (%rsi)
; X64-NEXT: movl %eax, %ecx
; X64-NEXT: shll $12, %ecx
; X64-NEXT: movw %cx, (%rsi)
; X64-NEXT: shll $13, %eax
; X64-NEXT: movw %ax, (%rsi)
; X64-NEXT: retq
%ashr = ashr i8 %x, 5
%zext = zext i8 %ashr to i16
%nsb4 = shl i16 %zext, 10
; Expecting (sshlsat x, c) -> (shl x, c) fold.
%tmp1 = call i16 @llvm.sshl.sat.i16(i16 %nsb4, i16 1)
store volatile i16 %tmp1, ptr %p
; Expecting (sshlsat x, c) -> (shl x, c) fold.
%tmp2 = call i16 @llvm.sshl.sat.i16(i16 %nsb4, i16 2)
store volatile i16 %tmp2, ptr %p
; Expecting (sshlsat x, c) -> (shl x, c) fold.
%tmp3 = call i16 @llvm.sshl.sat.i16(i16 %nsb4, i16 3)
store volatile i16 %tmp3, ptr %p
ret void
}

define void @computeNumSignBits_shl_zext_2(i8 %x, ptr %p) nounwind {
; X64-LABEL: computeNumSignBits_shl_zext_2:
; X64: # %bb.0:
; X64-NEXT: sarb $5, %dil
; X64-NEXT: movzbl %dil, %eax
; X64-NEXT: movl %eax, %ecx
; X64-NEXT: shll $10, %ecx
; X64-NEXT: xorl %edx, %edx
; X64-NEXT: testw %cx, %cx
; X64-NEXT: sets %dl
; X64-NEXT: addl $32767, %edx # imm = 0x7FFF
; X64-NEXT: shll $14, %eax
; X64-NEXT: movswl %ax, %edi
; X64-NEXT: shrl $4, %edi
; X64-NEXT: cmpw %di, %cx
; X64-NEXT: cmovnel %edx, %eax
; X64-NEXT: movw %ax, (%rsi)
; X64-NEXT: retq
%ashr = ashr i8 %x, 5
%zext = zext i8 %ashr to i16
%nsb4 = shl i16 %zext, 10
; 4 sign bits. Not expecting (sshlsat x, c) -> (shl x, c) fold.
%tmp4 = call i16 @llvm.sshl.sat.i16(i16 %nsb4, i16 4)
store volatile i16 %tmp4, ptr %p
ret void
}

define void @computeNumSignBits_shl_zext_vec_1(<2 x i8> %x, ptr %p) nounwind {
; X64-LABEL: computeNumSignBits_shl_zext_vec_1:
; X64: # %bb.0:
; X64-NEXT: psrlw $5, %xmm0
; X64-NEXT: pand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0
; X64-NEXT: movdqa {{.*#+}} xmm1 = [4,4,4,4,4,4,4,4,4,4,4,4,4,4,4,4]
; X64-NEXT: pxor %xmm1, %xmm0
; X64-NEXT: psubb %xmm1, %xmm0
; X64-NEXT: punpcklbw {{.*#+}} xmm0 = xmm0[0,0,1,1,2,2,3,3,4,4,5,5,6,6,7,7]
; X64-NEXT: pmullw {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0 # [2048,8192,u,u,u,u,u,u]
; X64-NEXT: movd %xmm0, (%rdi)
; X64-NEXT: retq
%ashr = ashr <2 x i8> %x, <i8 5, i8 5>
%zext = zext <2 x i8> %ashr to <2 x i16>
%nsb4_2 = shl <2 x i16> %zext, <i16 10, i16 12>
; Expecting (sshlsat x, c) -> (shl x, c) fold.
%tmp1 = call <2 x i16> @llvm.sshl.sat.v2i16(<2 x i16> %nsb4_2, <2 x i16> <i16 1, i16 1>)
store volatile <2 x i16> %tmp1, ptr %p
ret void
}

define void @computeNumSignBits_shl_zext_vec_2(<2 x i8> %x, ptr %p) nounwind {
; X64-LABEL: computeNumSignBits_shl_zext_vec_2:
; X64: # %bb.0:
; X64-NEXT: psrlw $5, %xmm0
; X64-NEXT: pand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0
; X64-NEXT: movdqa {{.*#+}} xmm1 = [4,4,4,4,4,4,4,4,4,4,4,4,4,4,4,4]
; X64-NEXT: pxor %xmm1, %xmm0
; X64-NEXT: psubb %xmm1, %xmm0
; X64-NEXT: pxor %xmm1, %xmm1
; X64-NEXT: punpcklbw {{.*#+}} xmm0 = xmm0[0],xmm1[0],xmm0[1],xmm1[1],xmm0[2],xmm1[2],xmm0[3],xmm1[3],xmm0[4],xmm1[4],xmm0[5],xmm1[5],xmm0[6],xmm1[6],xmm0[7],xmm1[7]
; X64-NEXT: pmullw {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0 # [1024,4096,u,u,u,u,u,u]
; X64-NEXT: movdqa {{.*#+}} xmm2 = [32768,32768,32768,32768,32768,32768,32768,32768]
; X64-NEXT: pand %xmm0, %xmm2
; X64-NEXT: pcmpgtw %xmm0, %xmm1
; X64-NEXT: pandn {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1
; X64-NEXT: por %xmm2, %xmm1
; X64-NEXT: movdqa %xmm0, %xmm2
; X64-NEXT: psllw $2, %xmm2
; X64-NEXT: movdqa %xmm2, %xmm3
; X64-NEXT: psraw $2, %xmm3
; X64-NEXT: pcmpeqw %xmm0, %xmm3
; X64-NEXT: movdqa %xmm3, %xmm0
; X64-NEXT: pandn %xmm1, %xmm0
; X64-NEXT: pand %xmm2, %xmm3
; X64-NEXT: por %xmm0, %xmm3
; X64-NEXT: movd %xmm3, (%rdi)
; X64-NEXT: retq
%ashr = ashr <2 x i8> %x, <i8 5, i8 5>
%zext = zext <2 x i8> %ashr to <2 x i16>
%nsb4_2 = shl <2 x i16> %zext, <i16 10, i16 12>
; Not expecting (sshlsat x, c) -> (shl x, c) fold.
; Because there are only 2 sign bits in element 1.
%tmp1 = call <2 x i16> @llvm.sshl.sat.v2i16(<2 x i16> %nsb4_2, <2 x i16> <i16 2, i16 2>)
store volatile <2 x i16> %tmp1, ptr %p
ret void
}

define void @computeNumSignBits_shl_zext_vec_3(<2 x i8> %x, ptr %p) nounwind {
; X64-LABEL: computeNumSignBits_shl_zext_vec_3:
; X64: # %bb.0:
; X64-NEXT: psrlw $5, %xmm0
; X64-NEXT: pand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0
; X64-NEXT: movdqa {{.*#+}} xmm1 = [4,4,4,4,4,4,4,4,4,4,4,4,4,4,4,4]
; X64-NEXT: pxor %xmm1, %xmm0
; X64-NEXT: psubb %xmm1, %xmm0
; X64-NEXT: pxor %xmm1, %xmm1
; X64-NEXT: punpcklbw {{.*#+}} xmm0 = xmm0[0],xmm1[0],xmm0[1],xmm1[1],xmm0[2],xmm1[2],xmm0[3],xmm1[3],xmm0[4],xmm1[4],xmm0[5],xmm1[5],xmm0[6],xmm1[6],xmm0[7],xmm1[7]
; X64-NEXT: pmullw {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0 # [16384,4096,u,u,u,u,u,u]
; X64-NEXT: movdqa {{.*#+}} xmm2 = [32768,32768,32768,32768,32768,32768,32768,32768]
; X64-NEXT: pand %xmm0, %xmm2
; X64-NEXT: pcmpgtw %xmm0, %xmm1
; X64-NEXT: pandn {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1
; X64-NEXT: por %xmm2, %xmm1
; X64-NEXT: movdqa %xmm0, %xmm2
; X64-NEXT: paddw %xmm0, %xmm2
; X64-NEXT: movdqa %xmm2, %xmm3
; X64-NEXT: psraw $1, %xmm3
; X64-NEXT: pcmpeqw %xmm0, %xmm3
; X64-NEXT: movdqa %xmm3, %xmm0
; X64-NEXT: pandn %xmm1, %xmm0
; X64-NEXT: pand %xmm2, %xmm3
; X64-NEXT: por %xmm0, %xmm3
; X64-NEXT: movd %xmm3, (%rdi)
; X64-NEXT: retq
%ashr = ashr <2 x i8> %x, <i8 5, i8 5>
%zext = zext <2 x i8> %ashr to <2 x i16>
%nsb1_2 = shl <2 x i16> %zext, <i16 14, i16 12>
; Not expecting (sshlsat x, c) -> (shl x, c) fold.
; Because all sign bits are shifted out for element 0.
%tmp1 = call <2 x i16> @llvm.sshl.sat.v2i16(<2 x i16> %nsb1_2, <2 x i16> <i16 1, i16 1>)
store volatile <2 x i16> %tmp1, ptr %p
ret void
}
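For the vector tests the shift amounts differ per element (<i16 10, i16 12>), so the patch only knows the range [MinShAmt, MaxShAmt] = [10, 12] that holds for every demanded element. A short sketch of that arithmetic (again plain C++ with illustrative names, not the LLVM API) shows why computeNumSignBits_shl_zext_vec_1 above folds a saturating shift by 1 while computeNumSignBits_shl_zext_vec_2 cannot fold a shift by 2:

#include <cstdint>
#include <cstdio>

int main() {
  // shl <2 x i16> (zext <2 x i8> %ashr), <i16 10, i16 12>, where each element
  // of %ashr = ashr <2 x i8> %x, <i8 5, i8 5> has 6 known sign bits.
  const uint64_t ExtVTBits = 16, ExtendeeVTBits = 8;
  const uint64_t NumSignBitsExtendee = 6;
  const uint64_t MinShAmt = 10, MaxShAmt = 12;

  uint64_t SizeDifference = ExtVTBits - ExtendeeVTBits;   // 8 extended bits
  uint64_t Known = 1;                                     // conservative default
  if (SizeDifference <= MinShAmt) {        // every element shifts the zext bits out
    uint64_t Tmp = SizeDifference + NumSignBitsExtendee;  // 14
    if (MaxShAmt < Tmp)
      Known = Tmp - MaxShAmt;              // 2 sign bits guaranteed per element
  }
  // (sshlsat x, c) -> (shl x, c) requires c < number of known sign bits.
  printf("fold sshl.sat by 1: %s\n", 1 < Known ? "yes" : "no"); // vec_1: yes
  printf("fold sshl.sat by 2: %s\n", 2 < Known ? "yes" : "no"); // vec_2: no
  return 0;
}

In computeNumSignBits_shl_zext_vec_3 the larger shift amount is 14, which is not less than Tmp (also 14), so the new path proves nothing extra and the fold is not performed.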