Skip to content

Commit 702a89f

Browse files
authored
Fix increment usage (rust-lang#468)
1 parent 13af0d9 commit 702a89f

File tree

5 files changed

+102
-36
lines changed

5 files changed

+102
-36
lines changed

enzyme/Enzyme/CacheUtility.cpp

Lines changed: 32 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ std::pair<PHINode *, Instruction *> FindCanonicalIV(Loop *L, Type *Ty) {
192192
// Attempt to rewrite all phinode's in the loop in terms of the
193193
// induction variable
194194
void RemoveRedundantIVs(BasicBlock *Header, PHINode *CanonicalIV,
195-
MustExitScalarEvolution &SE,
195+
Instruction *Increment, MustExitScalarEvolution &SE,
196196
std::function<void(Instruction *, Value *)> replacer,
197197
std::function<void(Instruction *)> eraser) {
198198
assert(Header);
@@ -246,6 +246,37 @@ void RemoveRedundantIVs(BasicBlock *Header, PHINode *CanonicalIV,
246246
replacer(Tmp, NewIV);
247247
eraser(Tmp);
248248
}
249+
250+
// Replace existing increments with canonical Increment
251+
Increment->moveAfter(CanonicalIV->getParent()->getFirstNonPHI());
252+
SmallVector<Instruction *, 1> toErase;
253+
for (auto use : CanonicalIV->users()) {
254+
auto BO = dyn_cast<BinaryOperator>(use);
255+
if (BO == nullptr)
256+
continue;
257+
if (BO->getOpcode() != BinaryOperator::Add)
258+
continue;
259+
if (use == Increment)
260+
continue;
261+
262+
Value *toadd = nullptr;
263+
if (BO->getOperand(0) == CanonicalIV) {
264+
toadd = BO->getOperand(1);
265+
} else {
266+
assert(BO->getOperand(1) == CanonicalIV);
267+
toadd = BO->getOperand(0);
268+
}
269+
if (auto CI = dyn_cast<ConstantInt>(toadd)) {
270+
if (!CI->isOne())
271+
continue;
272+
BO->replaceAllUsesWith(Increment);
273+
toErase.push_back(BO);
274+
} else {
275+
continue;
276+
}
277+
}
278+
for (auto BO : toErase)
279+
eraser(BO);
249280
}
250281

251282
void CanonicalizeLatches(const Loop *L, BasicBlock *Header,
@@ -332,37 +363,6 @@ void CanonicalizeLatches(const Loop *L, BasicBlock *Header,
332363
// Replace previous increment usage with new increment value
333364
if (Increment) {
334365
Increment->moveAfter(CanonicalIV->getParent()->getFirstNonPHI());
335-
std::vector<Instruction *> toerase;
336-
// Replace existing increments with canonical Increment
337-
for (auto use : CanonicalIV->users()) {
338-
auto BO = dyn_cast<BinaryOperator>(use);
339-
if (BO == nullptr)
340-
continue;
341-
if (BO->getOpcode() != BinaryOperator::Add)
342-
continue;
343-
if (use == Increment)
344-
continue;
345-
346-
Value *toadd = nullptr;
347-
if (BO->getOperand(0) == CanonicalIV) {
348-
toadd = BO->getOperand(1);
349-
} else {
350-
assert(BO->getOperand(1) == CanonicalIV);
351-
toadd = BO->getOperand(0);
352-
}
353-
if (auto CI = dyn_cast<ConstantInt>(toadd)) {
354-
if (!CI->isOne())
355-
continue;
356-
BO->replaceAllUsesWith(Increment);
357-
toerase.push_back(BO);
358-
} else {
359-
continue;
360-
}
361-
}
362-
for (auto inst : toerase) {
363-
gutils.erase(inst);
364-
}
365-
toerase.clear();
366366

367367
if (latches.size() == 1 && isa<BranchInst>(latches[0]->getTerminator()) &&
368368
cast<BranchInst>(latches[0]->getTerminator())->isConditional())

enzyme/Enzyme/CacheUtility.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -390,7 +390,7 @@ InsertNewCanonicalIV(llvm::Loop *L, llvm::Type *Ty, std::string name = "iv");
390390
// induction variable
391391
void RemoveRedundantIVs(
392392
llvm::BasicBlock *Header, llvm::PHINode *CanonicalIV,
393-
MustExitScalarEvolution &SE,
393+
llvm::Instruction *Increment, MustExitScalarEvolution &SE,
394394
std::function<void(llvm::Instruction *, llvm::Value *)> replacer,
395395
std::function<void(llvm::Instruction *)> eraser);
396396
#endif

enzyme/Enzyme/EnzymeLogic.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3476,11 +3476,11 @@ Function *EnzymeLogic::CreatePrimalAndGradient(
34763476
sucBB->removePredecessor(newBB);
34773477
}
34783478

3479-
std::vector<Instruction *> toerase;
3479+
SmallVector<Instruction *, 2> toerase;
34803480
for (auto &I : oBB) {
34813481
toerase.push_back(&I);
34823482
}
3483-
for (auto I : toerase) {
3483+
for (auto I : llvm::reverse(toerase)) {
34843484
maker.eraseIfUnused(*I, /*erase*/ true,
34853485
/*check*/ key.mode ==
34863486
DerivativeMode::ReverseModeCombined);

enzyme/Enzyme/FunctionUtils.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -808,7 +808,7 @@ void CanonicalizeLoops(Function *F, FunctionAnalysisManager &FAM) {
808808
PHINode *CanonicalIV = pair.first;
809809
assert(CanonicalIV);
810810
RemoveRedundantIVs(
811-
L->getHeader(), CanonicalIV, SE,
811+
L->getHeader(), CanonicalIV, pair.second, SE,
812812
[&](Instruction *I, Value *V) { I->replaceAllUsesWith(V); },
813813
[&](Instruction *I) { I->eraseFromParent(); });
814814
}
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -early-cse -instsimplify -simplifycfg -loop-deletion -simplifycfg -S | FileCheck %s
2+
3+
source_filename = "text"
4+
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128-ni:10:11:12:13"
5+
target triple = "x86_64-linux-gnu"
6+
7+
declare void @julia_throw_boundserror_1767(i64 %d)
8+
9+
define double @sub(double* %i58) {
10+
bb:
11+
br label %bb22
12+
13+
bb22: ; preds = %bb32, %bb
14+
%i23 = phi i64 [ %i78, %bb32 ], [ 0, %bb ]
15+
%i78 = add nuw nsw i64 %i23, 1
16+
%i26 = icmp sle i64 %i23, 10
17+
br i1 %i26, label %bb32, label %bb29
18+
19+
bb29: ; preds = %bb22
20+
%i30 = add nuw nsw i64 %i23, 1
21+
call void @julia_throw_boundserror_1767(i64 %i30)
22+
unreachable
23+
24+
bb32: ; preds = %bb22
25+
%c = icmp sle i64 %i23, 5
26+
br i1 %c, label %bb22, label %exit
27+
28+
exit:
29+
%v = load double, double* %i58, align 8
30+
ret double %v
31+
}
32+
33+
34+
declare dso_local double @__enzyme_autodiff(i8*, i64*, double*, double*)
35+
36+
define void @main(double* %arg, double* %arg1) {
37+
bb:
38+
%enzyme_dup = alloca i64, align 8
39+
%i = tail call double @__enzyme_autodiff(i8* bitcast (double (double*)* @julia_arsum2_1761 to i8*), i64* %enzyme_dup, double* %arg, double* %arg1)
40+
ret void
41+
}
42+
43+
define double @julia_arsum2_1761(double* %arg) {
44+
bb:
45+
%i23 = call double @sub(double* %arg)
46+
store double 0.000000e+00, double* %arg
47+
ret double %i23
48+
}
49+
50+
!llvm.module.flags = !{!0, !1}
51+
52+
!0 = !{i32 2, !"Dwarf Version", i32 4}
53+
!1 = !{i32 2, !"Debug Info Version", i32 3}
54+
!2 = !{!3, !3, i64 0}
55+
!3 = !{!"jtbaa_data", !4, i64 0}
56+
!4 = !{!"jtbaa", !5, i64 0}
57+
!5 = !{!"jtbaa"}
58+
59+
; Creating this should not segfault due to the use of the replaced iv in the unreachable block
60+
; CHECK: define internal void @diffesub(double* %i58, double* %"i58'", double %differeturn)
61+
; CHECK-NEXT: bb:
62+
; CHECK-NEXT: %0 = load double, double* %"i58'", align 8
63+
; CHECK-NEXT: %1 = fadd fast double %0, %differeturn
64+
; CHECK-NEXT: store double %1, double* %"i58'", align 8
65+
; CHECK-NEXT: ret void
66+
; CHECK-NEXT: }

0 commit comments

Comments
 (0)