Skip to content

[X86] Fix miscompile in combineShiftRightArithmetic #86597

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 16 additions & 13 deletions llvm/lib/Target/X86/X86ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47406,10 +47406,13 @@ static SDValue combineShiftRightArithmetic(SDNode *N, SelectionDAG &DAG,
return DAG.getNode(X86ISD::VSRAV, DL, N->getVTList(), N0, ShrAmtVal);
}

// fold (ashr (shl, a, [56,48,32,24,16]), SarConst)
// into (shl, (sext (a), [56,48,32,24,16] - SarConst)) or
// into (lshr, (sext (a), SarConst - [56,48,32,24,16]))
// depending on sign of (SarConst - [56,48,32,24,16])
// fold (SRA (SHL X, ShlConst), SraConst)
// into (SHL (sext_in_reg X), ShlConst - SraConst)
// or (sext_in_reg X)
// or (SRA (sext_in_reg X), SraConst - ShlConst)
// depending on relation between SraConst and ShlConst.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

between*

// We only do this if (Size - ShlConst) is equal to 8, 16 or 32. That allows
// us to do the sext_in_reg from corresponding bit.

// sexts in X86 are MOVs. The MOVs have the same code size
// as above SHIFTs (only SHIFT on 1 has lower code size).
Expand All @@ -47425,29 +47428,29 @@ static SDValue combineShiftRightArithmetic(SDNode *N, SelectionDAG &DAG,
SDValue N00 = N0.getOperand(0);
SDValue N01 = N0.getOperand(1);
APInt ShlConst = N01->getAsAPIntVal();
APInt SarConst = N1->getAsAPIntVal();
APInt SraConst = N1->getAsAPIntVal();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you fix this comment on line 47412 depending on sign of (SarConst - [56,48,32,24,16])

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you also fix 47411 to say ashr instead of lshr to match what the code does

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've cleaned up a bit now:

  • fixed the lshr->ashr thing (actually using SRA/SHL now, instead of the IR names when describing the folds)
  • renamed SarConst -> SraConst
  • got rid of the [56,48,32,24,16] comments (I did not really understand those comments and they did not fully match what the code was doing afaict)

EVT CVT = N1.getValueType();

if (SarConst.isNegative())
if (CVT != N01.getValueType())
return SDValue();
if (SraConst.isNegative())
return SDValue();

for (MVT SVT : { MVT::i8, MVT::i16, MVT::i32 }) {
unsigned ShiftSize = SVT.getSizeInBits();
// skipping types without corresponding sext/zext and
// ShlConst that is not one of [56,48,32,24,16]
// Only deal with (Size - ShlConst) being equal to 8, 16 or 32.
if (ShiftSize >= Size || ShlConst != Size - ShiftSize)
continue;
SDLoc DL(N);
SDValue NN =
DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, VT, N00, DAG.getValueType(SVT));
SarConst = SarConst - (Size - ShiftSize);
if (SarConst == 0)
if (SraConst.eq(ShlConst))
return NN;
if (SarConst.isNegative())
if (SraConst.ult(ShlConst))
return DAG.getNode(ISD::SHL, DL, VT, NN,
DAG.getConstant(-SarConst, DL, CVT));
DAG.getConstant(ShlConst - SraConst, DL, CVT));
return DAG.getNode(ISD::SRA, DL, VT, NN,
DAG.getConstant(SarConst, DL, CVT));
DAG.getConstant(SraConst - ShlConst, DL, CVT));
}
return SDValue();
}
Expand Down
13 changes: 5 additions & 8 deletions llvm/test/CodeGen/X86/sar_fold.ll
Original file line number Diff line number Diff line change
Expand Up @@ -67,20 +67,17 @@ define void @shl144sar48(ptr %p) #0 {
ret void
}

; This is incorrect. The 142 least significant bits in the stored value should
; be zero, and but 142-157 should be taken from %a with a sign-extend into the
; two most significant bits.
define void @shl144sar2(ptr %p) #0 {
; CHECK-LABEL: shl144sar2:
; CHECK: # %bb.0:
; CHECK-NEXT: movl {{[0-9]+}}(%esp), %eax
; CHECK-NEXT: movswl (%eax), %ecx
; CHECK-NEXT: sarl $31, %ecx
; CHECK-NEXT: shll $14, %ecx
; CHECK-NEXT: movl %ecx, 16(%eax)
; CHECK-NEXT: movl %ecx, 8(%eax)
; CHECK-NEXT: movl %ecx, 12(%eax)
; CHECK-NEXT: movl %ecx, 4(%eax)
; CHECK-NEXT: movl %ecx, (%eax)
; CHECK-NEXT: movl $0, 8(%eax)
; CHECK-NEXT: movl $0, 12(%eax)
; CHECK-NEXT: movl $0, 4(%eax)
; CHECK-NEXT: movl $0, (%eax)
; CHECK-NEXT: retl
%a = load i160, ptr %p
%1 = shl i160 %a, 144
Expand Down