Skip to content

Commit 2f01442

Browse files
authored
Fix constant instruction binop in forward mode (rust-lang#680)
* Fix constant instruction binop in forward mode * Fix llvm7
1 parent 55e050e commit 2f01442

File tree

2 files changed

+29
-2
lines changed

2 files changed

+29
-2
lines changed

enzyme/Enzyme/AdjointGenerator.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1760,8 +1760,6 @@ class AdjointGenerator
17601760

17611761
void visitBinaryOperator(llvm::BinaryOperator &BO) {
17621762
eraseIfUnused(BO);
1763-
if (gutils->isConstantInstruction(&BO))
1764-
return;
17651763

17661764
size_t size = 1;
17671765
if (BO.getType()->isSized())
@@ -1778,6 +1776,8 @@ class AdjointGenerator
17781776
switch (Mode) {
17791777
case DerivativeMode::ReverseModeGradient:
17801778
case DerivativeMode::ReverseModeCombined:
1779+
if (gutils->isConstantInstruction(&BO))
1780+
return;
17811781
createBinaryOperatorAdjoint(BO);
17821782
break;
17831783
case DerivativeMode::ForwardMode:
@@ -2255,6 +2255,11 @@ class AdjointGenerator
22552255
}
22562256

22572257
void createBinaryOperatorDual(llvm::BinaryOperator &BO) {
2258+
if (gutils->isConstantInstruction(&BO)) {
2259+
forwardModeInvertedPointerFallback(BO);
2260+
return;
2261+
}
2262+
22582263
IRBuilder<> Builder2(&BO);
22592264
getForwardBuilder(Builder2);
22602265

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -simplifycfg -S | FileCheck %s
2+
3+
declare void @_Z16__enzyme_fwddiff(...)
4+
5+
define void @_Z34testFwdDerivativesRosenbrockEnzymev(i8* %a, i8* %b) {
6+
call void (...) @_Z16__enzyme_fwddiff(double (i64*)* @f, metadata !"enzyme_dup", i8* %a, i8* %b)
7+
ret void
8+
}
9+
10+
define double @f(i64* %i10) {
11+
bb:
12+
%i13 = load i64, i64* %i10, align 8
13+
%i14 = sub i64 2, %i13
14+
%i15 = sdiv exact i64 %i14, 8
15+
%a5 = uitofp i64 %i15 to double
16+
ret double %a5
17+
}
18+
19+
; CHECK: define internal double @fwddiffef(i64* %i10, i64* %"i10'")
20+
; CHECK-NEXT: bb:
21+
; CHECK-NEXT: ret double 0.000000e+00
22+
; CHECK-NEXT: }

0 commit comments

Comments
 (0)