Skip to content

[Codegen][LegalizeIntegerTypes] Improve shift through stack #96151

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 40 additions & 24 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4608,14 +4608,23 @@ void DAGTypeLegalizer::ExpandIntRes_ShiftThroughStack(SDNode *N, SDValue &Lo,
SDValue ShAmt = N->getOperand(1);
EVT ShAmtVT = ShAmt.getValueType();

// This legalization is optimal when the shift is by a multiple of byte width,
// %x * 8 <-> %x << 3 so 3 low bits should be known zero.
bool ShiftByByteMultiple =
DAG.computeKnownBits(ShAmt).countMinTrailingZeros() >= 3;
EVT LoadVT = VT;
do {
LoadVT = TLI.getTypeToTransformTo(*DAG.getContext(), LoadVT);
} while (!TLI.isTypeLegal(LoadVT));

const unsigned ShiftUnitInBits = LoadVT.getStoreSizeInBits();
assert(ShiftUnitInBits <= VT.getScalarSizeInBits());
assert(isPowerOf2_32(ShiftUnitInBits) &&
"Shifting unit is not a power of two!");

const bool IsOneStepShift =
DAG.computeKnownBits(ShAmt).countMinTrailingZeros() >=
Log2_32(ShiftUnitInBits);

// If we can't do it as one step, we'll have two uses of shift amount,
// and thus must freeze it.
if (!ShiftByByteMultiple)
if (!IsOneStepShift)
ShAmt = DAG.getFreeze(ShAmt);

unsigned VTBitWidth = VT.getScalarSizeInBits();
Expand All @@ -4629,10 +4638,9 @@ void DAGTypeLegalizer::ExpandIntRes_ShiftThroughStack(SDNode *N, SDValue &Lo,

// Get a temporary stack slot 2x the width of our VT.
// FIXME: reuse stack slots?
// FIXME: should we be more picky about alignment?
Align StackSlotAlignment(1);
SDValue StackPtr = DAG.CreateStackTemporary(
TypeSize::getFixed(StackSlotByteWidth), StackSlotAlignment);
Align StackAlign = DAG.getReducedAlign(StackSlotVT, /*UseABI=*/false);
SDValue StackPtr =
DAG.CreateStackTemporary(StackSlotVT.getStoreSize(), StackAlign);
EVT PtrTy = StackPtr.getValueType();
SDValue Ch = DAG.getEntryNode();

Expand All @@ -4652,15 +4660,22 @@ void DAGTypeLegalizer::ExpandIntRes_ShiftThroughStack(SDNode *N, SDValue &Lo,
Init = DAG.getNode(ISD::BUILD_PAIR, dl, StackSlotVT, AllZeros, Shiftee);
}
// And spill it into the stack slot.
Ch = DAG.getStore(Ch, dl, Init, StackPtr, StackPtrInfo, StackSlotAlignment);
Ch = DAG.getStore(Ch, dl, Init, StackPtr, StackPtrInfo, StackAlign);

// Now, compute the full-byte offset into stack slot from where we can load.
// We have shift amount, which is in bits, but in multiples of byte.
// So just divide by CHAR_BIT.
// We have shift amount, which is in bits. Offset should point to an aligned
// address.
SDNodeFlags Flags;
if (ShiftByByteMultiple)
Flags.setExact(true);
SDValue ByteOffset = DAG.getNode(ISD::SRL, dl, ShAmtVT, ShAmt,
Flags.setExact(IsOneStepShift);
SDValue SrlTmp = DAG.getNode(
ISD::SRL, dl, ShAmtVT, ShAmt,
DAG.getConstant(Log2_32(ShiftUnitInBits), dl, ShAmtVT), Flags);
SDValue BitOffset =
DAG.getNode(ISD::SHL, dl, ShAmtVT, SrlTmp,
DAG.getConstant(Log2_32(ShiftUnitInBits), dl, ShAmtVT));

Flags.setExact(true);
SDValue ByteOffset = DAG.getNode(ISD::SRL, dl, ShAmtVT, BitOffset,
DAG.getConstant(3, dl, ShAmtVT), Flags);
// And clamp it, because OOB load is an immediate UB,
// while shift overflow would have *just* been poison.
Expand Down Expand Up @@ -4689,15 +4704,16 @@ void DAGTypeLegalizer::ExpandIntRes_ShiftThroughStack(SDNode *N, SDValue &Lo,
AdjStackPtr = DAG.getMemBasePlusOffset(AdjStackPtr, ByteOffset, dl);

// And load it! While the load is not legal, legalizing it is obvious.
SDValue Res = DAG.getLoad(
VT, dl, Ch, AdjStackPtr,
MachinePointerInfo::getUnknownStack(DAG.getMachineFunction()), Align(1));
// We've performed the shift by a CHAR_BIT * [_ShAmt / CHAR_BIT_]

// If we may still have a less-than-CHAR_BIT to shift by, do so now.
if (!ShiftByByteMultiple) {
SDValue ShAmtRem = DAG.getNode(ISD::AND, dl, ShAmtVT, ShAmt,
DAG.getConstant(7, dl, ShAmtVT));
SDValue Res =
DAG.getLoad(VT, dl, Ch, AdjStackPtr,
MachinePointerInfo::getUnknownStack(DAG.getMachineFunction()),
commonAlignment(StackAlign, LoadVT.getStoreSize()));

// If we may still have remaining bits to shift by, do so now.
if (!IsOneStepShift) {
SDValue ShAmtRem =
DAG.getNode(ISD::AND, dl, ShAmtVT, ShAmt,
DAG.getConstant(ShiftUnitInBits - 1, dl, ShAmtVT));
Res = DAG.getNode(N->getOpcode(), dl, VT, Res, ShAmtRem);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -186,23 +186,68 @@ define void @lshr_32bytes(ptr %src.ptr, ptr %byteOff.ptr, ptr %dst) nounwind {
; ALL-NEXT: ldr q1, [x0]
; ALL-NEXT: stp x9, x8, [sp, #16]
; ALL-NEXT: mov x8, sp
; ALL-NEXT: and x9, x10, #0x1f
; ALL-NEXT: and x9, x10, #0x18
; ALL-NEXT: str q1, [sp]
; ALL-NEXT: add x8, x8, x9
; ALL-NEXT: lsl x9, x10, #3
; ALL-NEXT: stp q0, q0, [sp, #32]
; ALL-NEXT: ldp x11, x10, [x8, #16]
; ALL-NEXT: mvn w13, w9
; ALL-NEXT: ldp x8, x12, [x8]
; ALL-NEXT: and x9, x9, #0x38
; ALL-NEXT: lsl x14, x10, #1
; ALL-NEXT: lsl x15, x11, #1
; ALL-NEXT: lsr x11, x11, x9
; ALL-NEXT: lsl x16, x12, #1
; ALL-NEXT: lsr x10, x10, x9
; ALL-NEXT: lsr x12, x12, x9
; ALL-NEXT: lsl x14, x14, x13
; ALL-NEXT: lsr x8, x8, x9
; ALL-NEXT: lsl x9, x16, x13
; ALL-NEXT: lsl x13, x15, x13
; ALL-NEXT: orr x11, x14, x11
; ALL-NEXT: orr x8, x9, x8
; ALL-NEXT: orr x9, x12, x13
; ALL-NEXT: stp x11, x10, [x2, #16]
; ALL-NEXT: stp x8, x9, [x2]
; ALL-NEXT: add sp, sp, #64
; ALL-NEXT: ret
%src = load i256, ptr %src.ptr, align 1
%byteOff = load i256, ptr %byteOff.ptr, align 1
%bitOff = shl i256 %byteOff, 3
%res = lshr i256 %src, %bitOff
store i256 %res, ptr %dst, align 1
ret void
}

define void @lshr_32bytes_dwordOff(ptr %src.ptr, ptr %dwordOff.ptr, ptr %dst) nounwind {
; ALL-LABEL: lshr_32bytes_dwordOff:
; ALL: // %bb.0:
; ALL-NEXT: sub sp, sp, #64
; ALL-NEXT: ldp x9, x8, [x0, #16]
; ALL-NEXT: movi v0.2d, #0000000000000000
; ALL-NEXT: ldr x10, [x1]
; ALL-NEXT: ldr q1, [x0]
; ALL-NEXT: stp x9, x8, [sp, #16]
; ALL-NEXT: ubfiz x8, x10, #3, #2
; ALL-NEXT: mov x9, sp
; ALL-NEXT: str q1, [sp]
; ALL-NEXT: stp q0, q0, [sp, #32]
; ALL-NEXT: add x8, x9, x8
; ALL-NEXT: ldp x10, x9, [x8, #16]
; ALL-NEXT: ldr q0, [x8]
; ALL-NEXT: str q0, [x2]
; ALL-NEXT: stp x10, x9, [x2, #16]
; ALL-NEXT: add sp, sp, #64
; ALL-NEXT: ret
%src = load i256, ptr %src.ptr, align 1
%byteOff = load i256, ptr %byteOff.ptr, align 1
%bitOff = shl i256 %byteOff, 3
%dwordOff = load i256, ptr %dwordOff.ptr, align 1
%bitOff = shl i256 %dwordOff, 6
%res = lshr i256 %src, %bitOff
store i256 %res, ptr %dst, align 1
ret void
}

define void @shl_32bytes(ptr %src.ptr, ptr %byteOff.ptr, ptr %dst) nounwind {
; ALL-LABEL: shl_32bytes:
; ALL: // %bb.0:
Expand All @@ -213,48 +258,139 @@ define void @shl_32bytes(ptr %src.ptr, ptr %byteOff.ptr, ptr %dst) nounwind {
; ALL-NEXT: ldr q1, [x0]
; ALL-NEXT: stp x9, x8, [sp, #48]
; ALL-NEXT: mov x8, sp
; ALL-NEXT: and x9, x10, #0x1f
; ALL-NEXT: and x9, x10, #0x18
; ALL-NEXT: add x8, x8, #32
; ALL-NEXT: stp q0, q0, [sp]
; ALL-NEXT: str q1, [sp, #32]
; ALL-NEXT: sub x8, x8, x9
; ALL-NEXT: lsl x9, x10, #3
; ALL-NEXT: ldp x10, x11, [x8]
; ALL-NEXT: ldp x12, x8, [x8, #16]
; ALL-NEXT: mvn w13, w9
; ALL-NEXT: and x9, x9, #0x38
; ALL-NEXT: lsr x14, x10, #1
; ALL-NEXT: lsr x15, x11, #1
; ALL-NEXT: lsl x11, x11, x9
; ALL-NEXT: lsr x16, x12, #1
; ALL-NEXT: lsl x10, x10, x9
; ALL-NEXT: lsl x12, x12, x9
; ALL-NEXT: lsr x14, x14, x13
; ALL-NEXT: lsl x8, x8, x9
; ALL-NEXT: lsr x9, x16, x13
; ALL-NEXT: lsr x13, x15, x13
; ALL-NEXT: orr x11, x11, x14
; ALL-NEXT: orr x8, x8, x9
; ALL-NEXT: orr x9, x12, x13
; ALL-NEXT: stp x10, x11, [x2]
; ALL-NEXT: stp x9, x8, [x2, #16]
; ALL-NEXT: add sp, sp, #64
; ALL-NEXT: ret
%src = load i256, ptr %src.ptr, align 1
%byteOff = load i256, ptr %byteOff.ptr, align 1
%bitOff = shl i256 %byteOff, 3
%res = shl i256 %src, %bitOff
store i256 %res, ptr %dst, align 1
ret void
}

define void @shl_32bytes_dwordOff(ptr %src.ptr, ptr %dwordOff.ptr, ptr %dst) nounwind {
; ALL-LABEL: shl_32bytes_dwordOff:
; ALL: // %bb.0:
; ALL-NEXT: sub sp, sp, #64
; ALL-NEXT: ldp x9, x8, [x0, #16]
; ALL-NEXT: movi v0.2d, #0000000000000000
; ALL-NEXT: ldr x10, [x1]
; ALL-NEXT: ldr q1, [x0]
; ALL-NEXT: stp x9, x8, [sp, #48]
; ALL-NEXT: mov x8, sp
; ALL-NEXT: ubfiz x9, x10, #3, #2
; ALL-NEXT: add x8, x8, #32
; ALL-NEXT: stp q0, q1, [sp, #16]
; ALL-NEXT: str q0, [sp]
; ALL-NEXT: sub x8, x8, x9
; ALL-NEXT: ldp x9, x10, [x8, #16]
; ALL-NEXT: ldr q0, [x8]
; ALL-NEXT: str q0, [x2]
; ALL-NEXT: stp x9, x10, [x2, #16]
; ALL-NEXT: add sp, sp, #64
; ALL-NEXT: ret
%src = load i256, ptr %src.ptr, align 1
%byteOff = load i256, ptr %byteOff.ptr, align 1
%bitOff = shl i256 %byteOff, 3
%dwordOff = load i256, ptr %dwordOff.ptr, align 1
%bitOff = shl i256 %dwordOff, 6
%res = shl i256 %src, %bitOff
store i256 %res, ptr %dst, align 1
ret void
}

define void @ashr_32bytes(ptr %src.ptr, ptr %byteOff.ptr, ptr %dst) nounwind {
; ALL-LABEL: ashr_32bytes:
; ALL: // %bb.0:
; ALL-NEXT: sub sp, sp, #64
; ALL-NEXT: ldp x9, x8, [x0, #16]
; ALL-NEXT: ldr x10, [x1]
; ALL-NEXT: ldr q0, [x0]
; ALL-NEXT: and x10, x10, #0x1f
; ALL-NEXT: and x11, x10, #0x18
; ALL-NEXT: stp x9, x8, [sp, #16]
; ALL-NEXT: asr x8, x8, #63
; ALL-NEXT: mov x9, sp
; ALL-NEXT: str q0, [sp]
; ALL-NEXT: add x9, x9, x11
; ALL-NEXT: stp x8, x8, [sp, #48]
; ALL-NEXT: stp x8, x8, [sp, #32]
; ALL-NEXT: lsl x8, x10, #3
; ALL-NEXT: ldp x11, x10, [x9, #16]
; ALL-NEXT: ldp x9, x12, [x9]
; ALL-NEXT: mvn w13, w8
; ALL-NEXT: and x8, x8, #0x38
; ALL-NEXT: lsl x14, x10, #1
; ALL-NEXT: lsl x15, x11, #1
; ALL-NEXT: lsr x11, x11, x8
; ALL-NEXT: lsl x16, x12, #1
; ALL-NEXT: asr x10, x10, x8
; ALL-NEXT: lsr x12, x12, x8
; ALL-NEXT: lsl x14, x14, x13
; ALL-NEXT: lsr x8, x9, x8
; ALL-NEXT: lsl x9, x16, x13
; ALL-NEXT: lsl x13, x15, x13
; ALL-NEXT: orr x11, x14, x11
; ALL-NEXT: orr x8, x9, x8
; ALL-NEXT: orr x9, x12, x13
; ALL-NEXT: stp x11, x10, [x2, #16]
; ALL-NEXT: stp x8, x9, [x2]
; ALL-NEXT: add sp, sp, #64
; ALL-NEXT: ret
%src = load i256, ptr %src.ptr, align 1
%byteOff = load i256, ptr %byteOff.ptr, align 1
%bitOff = shl i256 %byteOff, 3
%res = ashr i256 %src, %bitOff
store i256 %res, ptr %dst, align 1
ret void
}

define void @ashr_32bytes_dwordOff(ptr %src.ptr, ptr %dwordOff.ptr, ptr %dst) nounwind {
; ALL-LABEL: ashr_32bytes_dwordOff:
; ALL: // %bb.0:
; ALL-NEXT: sub sp, sp, #64
; ALL-NEXT: ldp x9, x8, [x0, #16]
; ALL-NEXT: ldr x10, [x1]
; ALL-NEXT: ldr q0, [x0]
; ALL-NEXT: stp x9, x8, [sp, #16]
; ALL-NEXT: asr x8, x8, #63
; ALL-NEXT: ubfiz x9, x10, #3, #2
; ALL-NEXT: mov x10, sp
; ALL-NEXT: str q0, [sp]
; ALL-NEXT: stp x8, x8, [sp, #48]
; ALL-NEXT: stp x8, x8, [sp, #32]
; ALL-NEXT: add x8, x9, x10
; ALL-NEXT: add x8, x10, x9
; ALL-NEXT: ldp x10, x9, [x8, #16]
; ALL-NEXT: ldr q0, [x8]
; ALL-NEXT: str q0, [x2]
; ALL-NEXT: stp x10, x9, [x2, #16]
; ALL-NEXT: add sp, sp, #64
; ALL-NEXT: ret
%src = load i256, ptr %src.ptr, align 1
%byteOff = load i256, ptr %byteOff.ptr, align 1
%bitOff = shl i256 %byteOff, 3
%dwordOff = load i256, ptr %dwordOff.ptr, align 1
%bitOff = shl i256 %dwordOff, 6
%res = ashr i256 %src, %bitOff
store i256 %res, ptr %dst, align 1
ret void
Expand Down
Loading
Loading