Skip to content

Commit 5a9a02f

Browse files
[SCEV] Compute SCEV for ashr(add(shl(x, n), c), m) instr triplet
%x = shl i64 %w, n
%y = add i64 %x, c
%z = ashr i64 %y, m

The instruction triplet above is seen many times in the generated LLVM IR, but the SCEV model is not able to compute the SCEV value of the AShr instruction in this case.

This patch models the two cases of the above instruction pattern using the following expression:

=> sext(add(mul(trunc(w), 2^(n-m)), c >> m))

1) When n = m, the expression reduces to sext(add(trunc(w), c >> n)), since n-m = 0 and multiplying by 2^0 gives the same result.

2) When n > m, the expression works as given above.

It also adds several unit tests to verify that SCEV is able to compute the value.

$ opt sext-add-inreg.ll -passes="print<scalar-evolution>"

Comparing the snippets of the result of SCEV analysis:

* SCEV of ashr before change
----------------------------
%idxprom = ashr exact i64 %sext, 32
--> %idxprom U: [-2147483648,2147483648) S: [-2147483648,2147483648) Exits: 8 LoopDispositions: { %for.body: Variant }

* SCEV of ashr after change
---------------------------
%idxprom = ashr exact i64 %sext, 32
--> {0,+,1}<nuw><nsw><%for.body> U: [0,9) S: [0,9) Exits: 8 LoopDispositions: { %for.body: Computable }

The LoopDisposition of the given SCEV was LoopVariant before; after adding the new way to model the instruction, the LoopDisposition becomes LoopComputable, as SCEV is now able to compute the value of the instruction.

Differential Revision: https://reviews.llvm.org/D152278
1 parent 7bd6328 commit 5a9a02f

File tree

5 files changed

+186
-31
lines changed

5 files changed

+186
-31
lines changed

llvm/lib/Analysis/ScalarEvolution.cpp

Lines changed: 55 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -7854,7 +7854,7 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) {
78547854
}
78557855
break;
78567856

7857-
case Instruction::AShr: {
7857+
case Instruction::AShr:
78587858
// AShr X, C, where C is a constant.
78597859
ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS);
78607860
if (!CI)
@@ -7876,37 +7876,69 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) {
78767876
Type *TruncTy = IntegerType::get(getContext(), BitWidth - AShrAmt);
78777877

78787878
Operator *L = dyn_cast<Operator>(BO->LHS);
7879-
if (L && L->getOpcode() == Instruction::Shl) {
7879+
const SCEV *AddTruncateExpr = nullptr;
7880+
ConstantInt *ShlAmtCI = nullptr;
7881+
const SCEV *AddConstant = nullptr;
7882+
7883+
if (L && L->getOpcode() == Instruction::Add) {
7884+
// X = Shl A, n
7885+
// Y = Add X, c
7886+
// Z = AShr Y, m
7887+
// n, c and m are constants.
7888+
7889+
Operator *LShift = dyn_cast<Operator>(L->getOperand(0));
7890+
ConstantInt *AddOperandCI = dyn_cast<ConstantInt>(L->getOperand(1));
7891+
if (LShift && LShift->getOpcode() == Instruction::Shl) {
7892+
if (AddOperandCI) {
7893+
const SCEV *ShlOp0SCEV = getSCEV(LShift->getOperand(0));
7894+
ShlAmtCI = dyn_cast<ConstantInt>(LShift->getOperand(1));
7895+
// since we truncate to TruncTy, the AddConstant should be of the
7896+
// same type, so create a new Constant with type same as TruncTy.
7897+
// Also, the Add constant should be shifted right by AShr amount.
7898+
APInt AddOperand = AddOperandCI->getValue().ashr(AShrAmt);
7899+
AddConstant = getConstant(TruncTy, AddOperand.getZExtValue(),
7900+
AddOperand.isSignBitSet());
7901+
// The expression is sext(add(mul(trunc(A), 2^(n-m)), c >> m)). The
7902+
// sext, mul and add are built below; here we only create the
7903+
// trunc(A) operand that those later steps consume.
7904+
AddTruncateExpr = getTruncateExpr(ShlOp0SCEV, TruncTy);
7905+
}
7906+
}
7907+
} else if (L && L->getOpcode() == Instruction::Shl) {
78807908
// X = Shl A, n
78817909
// Y = AShr X, m
78827910
// Both n and m are constant.
78837911

78847912
const SCEV *ShlOp0SCEV = getSCEV(L->getOperand(0));
7885-
if (L->getOperand(1) == BO->RHS)
7886-
// For a two-shift sext-inreg, i.e. n = m,
7887-
// use sext(trunc(x)) as the SCEV expression.
7888-
return getSignExtendExpr(
7889-
getTruncateExpr(ShlOp0SCEV, TruncTy), OuterTy);
7890-
7891-
ConstantInt *ShlAmtCI = dyn_cast<ConstantInt>(L->getOperand(1));
7892-
if (ShlAmtCI && ShlAmtCI->getValue().ult(BitWidth)) {
7893-
uint64_t ShlAmt = ShlAmtCI->getZExtValue();
7894-
if (ShlAmt > AShrAmt) {
7895-
// When n > m, use sext(mul(trunc(x), 2^(n-m)))) as the SCEV
7896-
// expression. We already checked that ShlAmt < BitWidth, so
7897-
// the multiplier, 1 << (ShlAmt - AShrAmt), fits into TruncTy as
7898-
// ShlAmt - AShrAmt < Amt.
7899-
APInt Mul = APInt::getOneBitSet(BitWidth - AShrAmt,
7900-
ShlAmt - AShrAmt);
7901-
return getSignExtendExpr(
7902-
getMulExpr(getTruncateExpr(ShlOp0SCEV, TruncTy),
7903-
getConstant(Mul)), OuterTy);
7904-
}
7913+
ShlAmtCI = dyn_cast<ConstantInt>(L->getOperand(1));
7914+
AddTruncateExpr = getTruncateExpr(ShlOp0SCEV, TruncTy);
7915+
}
7916+
7917+
if (AddTruncateExpr && ShlAmtCI) {
7918+
// We can merge the two given cases into a single SCEV statement,
7919+
// in case n = m, the mul expression will be 2^0, so it gets resolved to
7920+
// a simpler case. The following code handles the two cases:
7921+
//
7922+
// 1) For a two-shift sext-inreg, i.e. n = m,
7923+
// use sext(trunc(x)) as the SCEV expression.
7924+
//
7925+
// 2) When n > m, use sext(mul(trunc(x), 2^(n-m)))) as the SCEV
7926+
// expression. The check below guarantees ShlAmt < BitWidth, so
7927+
// the multiplier, 1 << (ShlAmt - AShrAmt), fits into TruncTy as
7928+
// ShlAmt - AShrAmt < BitWidth - AShrAmt.
7929+
uint64_t ShlAmt = ShlAmtCI->getZExtValue();
7930+
if (ShlAmtCI->getValue().ult(BitWidth) && ShlAmt >= AShrAmt) {
7931+
APInt Mul = APInt::getOneBitSet(BitWidth - AShrAmt, ShlAmt - AShrAmt);
7932+
const SCEV *CompositeExpr =
7933+
getMulExpr(AddTruncateExpr, getConstant(Mul));
7934+
if (L->getOpcode() != Instruction::Shl)
7935+
CompositeExpr = getAddExpr(CompositeExpr, AddConstant);
7936+
7937+
return getSignExtendExpr(CompositeExpr, OuterTy);
79057938
}
79067939
}
79077940
break;
79087941
}
7909-
}
79107942
}
79117943

79127944
switch (U->getOpcode()) {
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_analyze_test_checks.py UTC_ARGS: --version 2
2+
; RUN: opt < %s -disable-output "-passes=print<scalar-evolution>" 2>&1 | FileCheck %s
3+
4+
@.str = private unnamed_addr constant [3 x i8] c"%x\00", align 1
5+
6+
define dso_local i32 @test_loop(ptr nocapture noundef readonly %x) {
7+
; CHECK-LABEL: 'test_loop'
8+
; CHECK-NEXT: Classifying expressions for: @test_loop
9+
; CHECK-NEXT: %i.03 = phi i64 [ 1, %entry ], [ %inc, %for.body ]
10+
; CHECK-NEXT: --> {1,+,1}<nuw><nsw><%for.body> U: [1,10) S: [1,10) Exits: 9 LoopDispositions: { %for.body: Computable }
11+
; CHECK-NEXT: %conv = shl nuw nsw i64 %i.03, 32
12+
; CHECK-NEXT: --> {4294967296,+,4294967296}<nuw><nsw><%for.body> U: [4294967296,38654705665) S: [4294967296,38654705665) Exits: 38654705664 LoopDispositions: { %for.body: Computable }
13+
; CHECK-NEXT: %sext = add nsw i64 %conv, -4294967296
14+
; CHECK-NEXT: --> {0,+,4294967296}<nuw><nsw><%for.body> U: [0,34359738369) S: [0,34359738369) Exits: 34359738368 LoopDispositions: { %for.body: Computable }
15+
; CHECK-NEXT: %idxprom = ashr exact i64 %sext, 32
16+
; CHECK-NEXT: --> {0,+,1}<nuw><nsw><%for.body> U: [0,9) S: [0,9) Exits: 8 LoopDispositions: { %for.body: Computable }
17+
; CHECK-NEXT: %arrayidx = getelementptr inbounds i32, ptr %x, i64 %idxprom
18+
; CHECK-NEXT: --> {%x,+,4}<nuw><%for.body> U: full-set S: full-set Exits: (32 + %x) LoopDispositions: { %for.body: Computable }
19+
; CHECK-NEXT: %0 = load i32, ptr %arrayidx, align 4
20+
; CHECK-NEXT: --> %0 U: full-set S: full-set Exits: <<Unknown>> LoopDispositions: { %for.body: Variant }
21+
; CHECK-NEXT: %call = tail call i32 (ptr, ...) @printf(ptr noundef nonnull dereferenceable(1) @.str, i32 noundef %0)
22+
; CHECK-NEXT: --> %call U: full-set S: full-set Exits: <<Unknown>> LoopDispositions: { %for.body: Variant }
23+
; CHECK-NEXT: %inc = add nuw nsw i64 %i.03, 1
24+
; CHECK-NEXT: --> {2,+,1}<nuw><nsw><%for.body> U: [2,11) S: [2,11) Exits: 10 LoopDispositions: { %for.body: Computable }
25+
; CHECK-NEXT: Determining loop execution counts for: @test_loop
26+
; CHECK-NEXT: Loop %for.body: backedge-taken count is 8
27+
; CHECK-NEXT: Loop %for.body: constant max backedge-taken count is 8
28+
; CHECK-NEXT: Loop %for.body: symbolic max backedge-taken count is 8
29+
; CHECK-NEXT: Loop %for.body: Predicated backedge-taken count is 8
30+
; CHECK-NEXT: Predicates:
31+
; CHECK: Loop %for.body: Trip multiple is 9
32+
;
33+
entry:
34+
br label %for.body
35+
36+
for.cond.cleanup: ; preds = %for.body
37+
ret i32 0
38+
39+
for.body: ; preds = %entry, %for.body
40+
%i.03 = phi i64 [ 1, %entry ], [ %inc, %for.body ]
41+
%conv = shl nuw nsw i64 %i.03, 32
42+
%sext = add nsw i64 %conv, -4294967296
43+
%idxprom = ashr exact i64 %sext, 32
44+
%arrayidx = getelementptr inbounds i32, ptr %x, i64 %idxprom
45+
%0 = load i32, ptr %arrayidx, align 4
46+
%call = tail call i32 (ptr, ...) @printf(ptr noundef nonnull dereferenceable(1) @.str, i32 noundef %0)
47+
%inc = add nuw nsw i64 %i.03, 1
48+
%exitcond.not = icmp eq i64 %inc, 10
49+
br i1 %exitcond.not, label %for.cond.cleanup, label %for.body
50+
}
51+
52+
declare noundef i32 @printf(ptr nocapture noundef readonly, ...)
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_analyze_test_checks.py UTC_ARGS: --version 2
2+
; RUN: opt < %s -disable-output "-passes=print<scalar-evolution>" 2>&1 | FileCheck %s
3+
4+
define i64 @test00(i64 %a) {
5+
; CHECK-LABEL: 'test00'
6+
; CHECK-NEXT: Classifying expressions for: @test00
7+
; CHECK-NEXT: %add = shl i64 %a, 10
8+
; CHECK-NEXT: --> (1024 * %a) U: [0,-1023) S: [-9223372036854775808,9223372036854774785)
9+
; CHECK-NEXT: %shl = add i64 %add, 256
10+
; CHECK-NEXT: --> (256 + (1024 * %a))<nuw><nsw> U: [256,-767) S: [-9223372036854775552,9223372036854775041)
11+
; CHECK-NEXT: %ashr = ashr exact i64 %shl, 8
12+
; CHECK-NEXT: --> (1 + (sext i56 (4 * (trunc i64 %a to i56)) to i64))<nuw><nsw> U: [1,-2) S: [-36028797018963967,36028797018963966)
13+
; CHECK-NEXT: Determining loop execution counts for: @test00
14+
;
15+
%add = shl i64 %a, 10
16+
%shl = add i64 %add, 256
17+
%ashr = ashr exact i64 %shl, 8
18+
ret i64 %ashr
19+
}
20+
21+
define i64 @test01(i64 %a) {
22+
; CHECK-LABEL: 'test01'
23+
; CHECK-NEXT: Classifying expressions for: @test01
24+
; CHECK-NEXT: %add = shl i64 %a, 6
25+
; CHECK-NEXT: --> (64 * %a) U: [0,-63) S: [-9223372036854775808,9223372036854775745)
26+
; CHECK-NEXT: %shl = add i64 %add, 256
27+
; CHECK-NEXT: --> (256 + (64 * %a)) U: [0,-63) S: [-9223372036854775808,9223372036854775745)
28+
; CHECK-NEXT: %ashr = ashr exact i64 %shl, 8
29+
; CHECK-NEXT: --> %ashr U: [-36028797018963968,36028797018963968) S: [-36028797018963968,36028797018963968)
30+
; CHECK-NEXT: Determining loop execution counts for: @test01
31+
;
32+
%add = shl i64 %a, 6
33+
%shl = add i64 %add, 256
34+
%ashr = ashr exact i64 %shl, 8
35+
ret i64 %ashr
36+
}
37+
38+
define i64 @test02(i64 %a) {
39+
; CHECK-LABEL: 'test02'
40+
; CHECK-NEXT: Classifying expressions for: @test02
41+
; CHECK-NEXT: %add = shl i64 %a, 12
42+
; CHECK-NEXT: --> (4096 * %a) U: [0,-4095) S: [-9223372036854775808,9223372036854771713)
43+
; CHECK-NEXT: %shl = add i64 %add, 4096
44+
; CHECK-NEXT: --> (4096 + (4096 * %a)) U: [0,-4095) S: [-9223372036854775808,9223372036854771713)
45+
; CHECK-NEXT: %ashr = ashr exact i64 %shl, 8
46+
; CHECK-NEXT: --> (sext i56 (16 + (16 * (trunc i64 %a to i56))) to i64) U: [0,-15) S: [-36028797018963968,36028797018963953)
47+
; CHECK-NEXT: Determining loop execution counts for: @test02
48+
;
49+
%add = shl i64 %a, 12
50+
%shl = add i64 %add, 4096
51+
%ashr = ashr exact i64 %shl, 8
52+
ret i64 %ashr
53+
}
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_analyze_test_checks.py UTC_ARGS: --version 2
2+
; RUN: opt < %s -disable-output "-passes=print<scalar-evolution>" 2>&1 | FileCheck %s
3+
4+
define i64 @test(i64 %a) {
5+
; CHECK-LABEL: 'test'
6+
; CHECK-NEXT: Classifying expressions for: @test
7+
; CHECK-NEXT: %add = shl i64 %a, 8
8+
; CHECK-NEXT: --> (256 * %a) U: [0,-255) S: [-9223372036854775808,9223372036854775553)
9+
; CHECK-NEXT: %shl = add i64 %add, 256
10+
; CHECK-NEXT: --> (256 + (256 * %a)) U: [0,-255) S: [-9223372036854775808,9223372036854775553)
11+
; CHECK-NEXT: %ashr = ashr exact i64 %shl, 8
12+
; CHECK-NEXT: --> (sext i56 (1 + (trunc i64 %a to i56)) to i64) U: [-36028797018963968,36028797018963968) S: [-36028797018963968,36028797018963968)
13+
; CHECK-NEXT: Determining loop execution counts for: @test
14+
;
15+
%add = shl i64 %a, 8
16+
%shl = add i64 %add, 256
17+
%ashr = ashr exact i64 %shl, 8
18+
ret i64 %ashr
19+
}

llvm/test/Transforms/LoopStrengthReduce/scaling-factor-incompat-type.ll

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,31 +4,30 @@
44
; see pr42770
55
; REQUIRES: asserts
66
; RUN: opt < %s -loop-reduce -S | FileCheck %s
7-
87
target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128-ni:1"
98

109
define void @foo() {
1110
; CHECK-LABEL: @foo(
1211
; CHECK-NEXT: bb:
1312
; CHECK-NEXT: br label [[BB4:%.*]]
1413
; CHECK: bb1:
15-
; CHECK-NEXT: [[T3:%.*]] = ashr i64 [[LSR_IV_NEXT:%.*]], 32
14+
; CHECK-NEXT: [[T:%.*]] = shl i64 [[T14:%.*]], 32
15+
; CHECK-NEXT: [[T2:%.*]] = add i64 [[T]], 1
16+
; CHECK-NEXT: [[T3:%.*]] = ashr i64 [[T2]], 32
1617
; CHECK-NEXT: ret void
1718
; CHECK: bb4:
18-
; CHECK-NEXT: [[LSR_IV1:%.*]] = phi i16 [ [[LSR_IV_NEXT2:%.*]], [[BB13:%.*]] ], [ 6, [[BB:%.*]] ]
19-
; CHECK-NEXT: [[LSR_IV:%.*]] = phi i64 [ [[LSR_IV_NEXT]], [[BB13]] ], [ 8589934593, [[BB]] ]
20-
; CHECK-NEXT: [[T5:%.*]] = phi i64 [ 2, [[BB]] ], [ [[T14:%.*]], [[BB13]] ]
19+
; CHECK-NEXT: [[LSR_IV:%.*]] = phi i16 [ [[LSR_IV_NEXT:%.*]], [[BB13:%.*]] ], [ 6, [[BB:%.*]] ]
20+
; CHECK-NEXT: [[T5:%.*]] = phi i64 [ 2, [[BB]] ], [ [[T14]], [[BB13]] ]
2121
; CHECK-NEXT: [[T6:%.*]] = add i64 [[T5]], 4
2222
; CHECK-NEXT: [[T7:%.*]] = trunc i64 [[T6]] to i16
2323
; CHECK-NEXT: [[T8:%.*]] = urem i16 [[T7]], 3
2424
; CHECK-NEXT: [[T9:%.*]] = mul i16 [[T8]], 2
25-
; CHECK-NEXT: [[LSR_IV_NEXT]] = add nuw nsw i64 [[LSR_IV]], 25769803776
26-
; CHECK-NEXT: [[LSR_IV_NEXT2]] = add nuw nsw i16 [[LSR_IV1]], 6
25+
; CHECK-NEXT: [[LSR_IV_NEXT]] = add nuw nsw i16 [[LSR_IV]], 6
2726
; CHECK-NEXT: [[T14]] = add nuw nsw i64 [[T5]], 6
2827
; CHECK-NEXT: [[T10:%.*]] = icmp eq i16 [[T9]], 1
2928
; CHECK-NEXT: br i1 [[T10]], label [[BB11:%.*]], label [[BB13]]
3029
; CHECK: bb11:
31-
; CHECK-NEXT: [[T12:%.*]] = udiv i16 1, [[LSR_IV1]]
30+
; CHECK-NEXT: [[T12:%.*]] = udiv i16 1, [[LSR_IV]]
3231
; CHECK-NEXT: unreachable
3332
; CHECK: bb13:
3433
; CHECK-NEXT: br i1 true, label [[BB1:%.*]], label [[BB4]]

0 commit comments

Comments
 (0)