Skip to content

Commit ed443d8

Browse files
committed
[AggressiveInstCombine] Only fold consecutive shifts of loads with constant shift amounts
This is what the code assumed but never actually checked. Fixes #62509. Reviewed By: nikic Differential Revision: https://reviews.llvm.org/D149896
1 parent dc39d98 commit ed443d8

File tree

2 files changed

+60
-11
lines changed

2 files changed

+60
-11
lines changed

llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -611,7 +611,7 @@ struct LoadOps {
611611
LoadInst *RootInsert = nullptr;
612612
bool FoundRoot = false;
613613
uint64_t LoadSize = 0;
614-
Value *Shift = nullptr;
614+
const APInt *Shift = nullptr;
615615
Type *ZextType;
616616
AAMDNodes AATags;
617617
};
@@ -621,15 +621,15 @@ struct LoadOps {
621621
// (ZExt(L1) << shift1) | ZExt(L2) -> ZExt(L3)
622622
static bool foldLoadsRecursive(Value *V, LoadOps &LOps, const DataLayout &DL,
623623
AliasAnalysis &AA) {
624-
Value *ShAmt2 = nullptr;
624+
const APInt *ShAmt2 = nullptr;
625625
Value *X;
626626
Instruction *L1, *L2;
627627

628628
// Go to the last node with loads.
629629
if (match(V, m_OneUse(m_c_Or(
630630
m_Value(X),
631631
m_OneUse(m_Shl(m_OneUse(m_ZExt(m_OneUse(m_Instruction(L2)))),
632-
m_Value(ShAmt2)))))) ||
632+
m_APInt(ShAmt2)))))) ||
633633
match(V, m_OneUse(m_Or(m_Value(X),
634634
m_OneUse(m_ZExt(m_OneUse(m_Instruction(L2)))))))) {
635635
if (!foldLoadsRecursive(X, LOps, DL, AA) && LOps.FoundRoot)
@@ -640,11 +640,11 @@ static bool foldLoadsRecursive(Value *V, LoadOps &LOps, const DataLayout &DL,
640640

641641
// Check if the pattern has loads
642642
LoadInst *LI1 = LOps.Root;
643-
Value *ShAmt1 = LOps.Shift;
643+
const APInt *ShAmt1 = LOps.Shift;
644644
if (LOps.FoundRoot == false &&
645645
(match(X, m_OneUse(m_ZExt(m_Instruction(L1)))) ||
646646
match(X, m_OneUse(m_Shl(m_OneUse(m_ZExt(m_OneUse(m_Instruction(L1)))),
647-
m_Value(ShAmt1)))))) {
647+
m_APInt(ShAmt1)))))) {
648648
LI1 = dyn_cast<LoadInst>(L1);
649649
}
650650
LoadInst *LI2 = dyn_cast<LoadInst>(L2);
@@ -719,12 +719,11 @@ static bool foldLoadsRecursive(Value *V, LoadOps &LOps, const DataLayout &DL,
719719
std::swap(ShAmt1, ShAmt2);
720720

721721
// Find Shifts values.
722-
const APInt *Temp;
723722
uint64_t Shift1 = 0, Shift2 = 0;
724-
if (ShAmt1 && match(ShAmt1, m_APInt(Temp)))
725-
Shift1 = Temp->getZExtValue();
726-
if (ShAmt2 && match(ShAmt2, m_APInt(Temp)))
727-
Shift2 = Temp->getZExtValue();
723+
if (ShAmt1)
724+
Shift1 = ShAmt1->getZExtValue();
725+
if (ShAmt2)
726+
Shift2 = ShAmt2->getZExtValue();
728727

729728
// First load is always LI1. This is where we put the new load.
730729
// Use the merged load size available from LI1 for forward loads.
@@ -816,7 +815,7 @@ static bool foldConsecutiveLoads(Instruction &I, const DataLayout &DL,
816815
// Check if shift needed. We need to shift with the amount of load1
817816
// shift if not zero.
818817
if (LOps.Shift)
819-
NewOp = Builder.CreateShl(NewOp, LOps.Shift);
818+
NewOp = Builder.CreateShl(NewOp, ConstantInt::get(I.getContext(), *LOps.Shift));
820819
I.replaceAllUsesWith(NewOp);
821820

822821
return true;

llvm/test/Transforms/AggressiveInstCombine/X86/or-load.ll

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2253,3 +2253,53 @@ define i32 @loadCombine_4consecutive_badinsert6(ptr %p) {
22532253
%o3 = or i32 %o2, %e1
22542254
ret i32 %o3
22552255
}
2256+
2257+
define i64 @loadCombine_nonConstShift1(ptr %arg, i8 %b) {
2258+
; ALL-LABEL: @loadCombine_nonConstShift1(
2259+
; ALL-NEXT: [[G1:%.*]] = getelementptr i8, ptr [[ARG:%.*]], i64 1
2260+
; ALL-NEXT: [[LD0:%.*]] = load i8, ptr [[ARG]], align 1
2261+
; ALL-NEXT: [[LD1:%.*]] = load i8, ptr [[G1]], align 1
2262+
; ALL-NEXT: [[Z0:%.*]] = zext i8 [[LD0]] to i64
2263+
; ALL-NEXT: [[Z1:%.*]] = zext i8 [[LD1]] to i64
2264+
; ALL-NEXT: [[Z6:%.*]] = zext i8 [[B:%.*]] to i64
2265+
; ALL-NEXT: [[S0:%.*]] = shl i64 [[Z0]], [[Z6]]
2266+
; ALL-NEXT: [[S1:%.*]] = shl i64 [[Z1]], 8
2267+
; ALL-NEXT: [[O7:%.*]] = or i64 [[S0]], [[S1]]
2268+
; ALL-NEXT: ret i64 [[O7]]
2269+
;
2270+
%g1 = getelementptr i8, ptr %arg, i64 1
2271+
%ld0 = load i8, ptr %arg, align 1
2272+
%ld1 = load i8, ptr %g1, align 1
2273+
%z0 = zext i8 %ld0 to i64
2274+
%z1 = zext i8 %ld1 to i64
2275+
%z6 = zext i8 %b to i64
2276+
%s0 = shl i64 %z0, %z6
2277+
%s1 = shl i64 %z1, 8
2278+
%o7 = or i64 %s0, %s1
2279+
ret i64 %o7
2280+
}
2281+
2282+
define i64 @loadCombine_nonConstShift2(ptr %arg, i8 %b) {
2283+
; ALL-LABEL: @loadCombine_nonConstShift2(
2284+
; ALL-NEXT: [[G1:%.*]] = getelementptr i8, ptr [[ARG:%.*]], i64 1
2285+
; ALL-NEXT: [[LD0:%.*]] = load i8, ptr [[ARG]], align 1
2286+
; ALL-NEXT: [[LD1:%.*]] = load i8, ptr [[G1]], align 1
2287+
; ALL-NEXT: [[Z0:%.*]] = zext i8 [[LD0]] to i64
2288+
; ALL-NEXT: [[Z1:%.*]] = zext i8 [[LD1]] to i64
2289+
; ALL-NEXT: [[Z6:%.*]] = zext i8 [[B:%.*]] to i64
2290+
; ALL-NEXT: [[S0:%.*]] = shl i64 [[Z0]], [[Z6]]
2291+
; ALL-NEXT: [[S1:%.*]] = shl i64 [[Z1]], 8
2292+
; ALL-NEXT: [[O7:%.*]] = or i64 [[S1]], [[S0]]
2293+
; ALL-NEXT: ret i64 [[O7]]
2294+
;
2295+
%g1 = getelementptr i8, ptr %arg, i64 1
2296+
%ld0 = load i8, ptr %arg, align 1
2297+
%ld1 = load i8, ptr %g1, align 1
2298+
%z0 = zext i8 %ld0 to i64
2299+
%z1 = zext i8 %ld1 to i64
2300+
%z6 = zext i8 %b to i64
2301+
%s0 = shl i64 %z0, %z6
2302+
%s1 = shl i64 %z1, 8
2303+
%o7 = or i64 %s1, %s0
2304+
ret i64 %o7
2305+
}

0 commit comments

Comments
 (0)