Skip to content

Commit bc04d6f

Browse files
authored
Forward mode select inst (rust-lang#218)
* implemented select inst for forward mode * fixed return activity check
1 parent 0bd5ab1 commit bc04d6f

File tree

6 files changed

+275
-2
lines changed

6 files changed

+275
-2
lines changed

enzyme/Enzyme/AdjointGenerator.h

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -677,14 +677,28 @@ class AdjointGenerator
677677

678678
// Visitor entry point for select instructions. Returns early when gutils
// reports the instruction as constant or when the select produces a pointer;
// otherwise the switch on Mode hands the instruction to the mode-specific
// helper (createSelectInstAdjoint or createSelectInstDual).
void visitSelectInst(llvm::SelectInst &SI) {
679679
eraseIfUnused(SI);
680+
680681
if (gutils->isConstantInstruction(&SI))
681682
return;
682683
if (SI.getType()->isPointerTy())
683684
return;
684685

685-
if (Mode == DerivativeMode::ReverseModePrimal)
686+
switch (Mode) {
687+
// Nothing is emitted for the select itself in this mode.
case DerivativeMode::ReverseModePrimal:
686688
return;
689+
// Both of these reverse modes share the adjoint helper.
case DerivativeMode::ReverseModeCombined:
690+
case DerivativeMode::ReverseModeGradient: {
691+
createSelectInstAdjoint(SI);
692+
return;
693+
}
694+
// Forward mode uses the dual helper.
case DerivativeMode::ForwardMode: {
695+
createSelectInstDual(SI);
696+
return;
697+
}
698+
}
699+
}
687700

701+
void createSelectInstAdjoint(llvm::SelectInst &SI) {
688702
Value *op0 = gutils->getNewFromOriginal(SI.getOperand(0));
689703
Value *orig_op1 = SI.getOperand(1);
690704
Value *op1 = gutils->getNewFromOriginal(orig_op1);
@@ -778,6 +792,38 @@ class AdjointGenerator
778792
addToDiffe(orig_op2, dif2, Builder2, TR.addingType(size, orig_op2));
779793
}
780794

795+
// Forward-mode handler for a select instruction (called from the ForwardMode
// case of visitSelectInst). Emits a second select on the same condition whose
// two arms are the derivative values of the original arms, and records that
// select as the derivative of SI via setDiffe.
//
// SI: the select instruction being differentiated; its operands are read
//     directly from it (operand 0 = condition, 1 = true arm, 2 = false arm).
void createSelectInstDual(llvm::SelectInst &SI) {
796+
Value *orig_cond = SI.getOperand(0);
797+
// Only the condition is translated through getNewFromOriginal; op1/op2
// below are kept as the operands taken straight from SI.
Value *cond = gutils->getNewFromOriginal(orig_cond);
798+
799+
Value *op1 = SI.getOperand(1);
800+
Value *op2 = SI.getOperand(2);
801+
802+
// NOTE: the numbering is shifted relative to the operands:
// constantval0 describes op1 and constantval1 describes op2.
bool constantval0 = gutils->isConstantValue(op1);
803+
bool constantval1 = gutils->isConstantValue(op2);
804+
805+
// NOTE(review): Builder2 starts at SI and is then handed to
// getForwardBuilder, which presumably repositions it - confirm.
IRBuilder<> Builder2(&SI);
806+
getForwardBuilder(Builder2);
807+
808+
// Both are assigned on every path through the if/else pairs below.
Value *dif1;
809+
Value *dif2;
810+
811+
// An arm that gutils reports as constant contributes a null (zero) value of
// the select's type; otherwise the arm's value comes from diffe(...).
if (!constantval0) {
812+
dif1 = diffe(op1, Builder2);
813+
} else {
814+
dif1 = Constant::getNullValue(SI.getType());
815+
}
816+
817+
if (!constantval1) {
818+
dif2 = diffe(op2, Builder2);
819+
} else {
820+
dif2 = Constant::getNullValue(SI.getType());
821+
}
822+
823+
// NOTE: from here to the end of the scope the local named `diffe` hides the
// diffe(...) function called above; any further call must come before it.
Value *diffe = Builder2.CreateSelect(cond, dif1, dif2);
824+
setDiffe(&SI, diffe, Builder2);
825+
}
826+
781827
void visitExtractElementInst(llvm::ExtractElementInst &EEI) {
782828
eraseIfUnused(EEI);
783829
if (gutils->isConstantInstruction(&EEI))

enzyme/Enzyme/EnzymeLogic.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2662,7 +2662,8 @@ Function *EnzymeLogic::CreatePrimalAndGradient(
26622662
ReturnType retVal;
26632663
if (mode == DerivativeMode::ForwardMode) {
26642664
auto TR = TA.analyzeFunction(oldTypeInfo);
2665-
bool retActive = TR.getReturnAnalysis().Inner0().isFloat();
2665+
bool retActive = TR.getReturnAnalysis().Inner0().isPossibleFloat() &&
2666+
!todiff->getReturnType()->isVoidTy();
26662667

26672668
retVal = returnValue
26682669
? (retActive ? ReturnType::TwoReturns : ReturnType::Return)
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -sroa -instsimplify -simplifycfg -S | FileCheck %s
2+
3+
; ModuleID = 'inp.c'
4+
source_filename = "inp.c"
5+
target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128"
6+
target triple = "x86_64-unknown-linux-gnu"
7+
8+
@.str = private unnamed_addr constant [20 x i8] c"dfun/dx = %f, x=%d\0A\00", align 1
9+
10+
; Function Attrs: norecurse nounwind readnone uwtable
11+
; Function under test: returns %x when %x >= 0.0 (ordered compare), and 0.0
; otherwise. The select has one constant arm.
define double @fun2(double %x) {
12+
entry:
13+
%cmp.inv = fcmp oge double %x, 0.000000e+00
14+
%.x = select i1 %cmp.inv, double %x, double 0.000000e+00
15+
ret double %.x
16+
}
17+
18+
; Function Attrs: nounwind uwtable
19+
; Driver: calls @__enzyme_fwddiff on @fun2 with the arguments 2.0 and 1.0
; (presumably the input value and its derivative seed; compare the
; (%x, %"x'") signature in the CHECK lines) and prints the returned double.
define i32 @main() {
20+
entry:
21+
%call3.4 = tail call double (i8*, ...) @__enzyme_fwddiff(i8* bitcast (double (double)* @fun2 to i8*), double 2.000000e+00, double 1.0)
22+
%call4.4 = tail call i32 (i8*, ...) @printf(i8* getelementptr inbounds ([20 x i8], [20 x i8]* @.str, i64 0, i64 0), double %call3.4, i32 2)
23+
ret i32 0
24+
}
25+
26+
; Function Attrs: nounwind
27+
declare dso_local i32 @printf(i8* nocapture readonly, ...)
28+
29+
declare double @__enzyme_fwddiff(i8*, ...)
30+
31+
attributes #0 = { norecurse nounwind readnone uwtable "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "less-precise-fpmad"="false" "no-frame-pointer-elim"="false" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="false" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+fxsr,+mmx,+sse,+sse2,+x87" "unsafe-fp-math"="false" "use-soft-float"="false" }
32+
attributes #1 = { nounwind uwtable "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "less-precise-fpmad"="false" "no-frame-pointer-elim"="false" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="false" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+fxsr,+mmx,+sse,+sse2,+x87" "unsafe-fp-math"="false" "use-soft-float"="false" }
33+
attributes #2 = { nounwind "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "less-precise-fpmad"="false" "no-frame-pointer-elim"="false" "no-infs-fp-math"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="false" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+fxsr,+mmx,+sse,+sse2,+x87" "unsafe-fp-math"="false" "use-soft-float"="false" }
34+
attributes #3 = { "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "less-precise-fpmad"="false" "no-frame-pointer-elim"="false" "no-infs-fp-math"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="false" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+fxsr,+mmx,+sse,+sse2,+x87" "unsafe-fp-math"="false" "use-soft-float"="false" }
35+
attributes #4 = { nounwind }
36+
37+
!llvm.module.flags = !{!0}
38+
!llvm.ident = !{!1}
39+
40+
!0 = !{i32 1, !"wchar_size", i32 4}
41+
!1 = !{!"clang version 7.1.0 "}
42+
43+
; CHECK: define internal { double } @diffefun2(double %x, double %"x'")
44+
; CHECK-NEXT: entry:
45+
; CHECK-NEXT: %cmp.inv = fcmp oge double %x, 0.000000e+00
46+
; CHECK-NEXT: %0 = select{{( fast)?}} i1 %cmp.inv, double %"x'", double 0.000000e+00
47+
; CHECK-NEXT: %1 = insertvalue { double } undef, double %0, 0
48+
; CHECK-NEXT: ret { double } %1
49+
; CHECK-NEXT: }
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -simplifycfg -early-cse-memssa -instsimplify -correlated-propagation -adce -S | FileCheck %s
2+
3+
; Function Attrs: norecurse nounwind readonly uwtable
4+
; Function under test: starting from %start, divides the running value by
; A[i] on each iteration (i = 0, 1, ...) and returns the last quotient. The
; loop exits once the incremented index %next equals %N.
define double @alldiv(double* nocapture readonly %A, i64 %N, double %start) {
5+
entry:
6+
br label %loop
7+
8+
loop: ; preds = %loop, %entry
9+
%i = phi i64 [ 0, %entry ], [ %next, %loop ]
10+
%reduce = phi double [ %start, %entry ], [ %div, %loop ]
11+
%gep = getelementptr inbounds double, double* %A, i64 %i
12+
%ld = load double, double* %gep, align 8, !tbaa !2
13+
%div = fdiv double %reduce, %ld
14+
%next = add nuw nsw i64 %i, 1
15+
%cmp = icmp eq i64 %next, %N
16+
br i1 %cmp, label %end, label %loop
17+
18+
end: ; preds = %loop
19+
ret double %div
20+
}
21+
22+
; Same loop as @alldiv, except the running value starts from the constant 2.0
; instead of a %start argument.
define double @alldiv2(double* nocapture readonly %A, i64 %N) {
23+
entry:
24+
br label %loop
25+
26+
loop: ; preds = %loop, %entry
27+
%i = phi i64 [ 0, %entry ], [ %next, %loop ]
28+
%reduce = phi double [ 2.000000e+00, %entry ], [ %div, %loop ]
29+
%gep = getelementptr inbounds double, double* %A, i64 %i
30+
%ld = load double, double* %gep, align 8, !tbaa !2
31+
%div = fdiv double %reduce, %ld
32+
%next = add nuw nsw i64 %i, 1
33+
%cmp = icmp eq i64 %next, %N
34+
br i1 %cmp, label %end, label %loop
35+
36+
end: ; preds = %loop
37+
ret double %div
38+
}
39+
40+
; Function Attrs: nounwind uwtable
41+
; Driver: requests forward-mode derivatives of @alldiv (via @__enzyme_fwddiff)
; and @alldiv2 (via @__enzyme_fwddiff2). Only %r is returned; %r2 is unused
; after the call.
define double @main(double* %A, double* %dA, i64 %N, double %start) {
42+
%r = call double @__enzyme_fwddiff(i8* bitcast (double (double*, i64, double)* @alldiv to i8*), double* %A, double* %dA, i64 %N, double %start, double 1.0)
43+
%r2 = call double @__enzyme_fwddiff2(i8* bitcast (double (double*, i64)* @alldiv2 to i8*), double* %A, double* %dA, i64 %N)
44+
ret double %r
45+
}
46+
47+
declare double @__enzyme_fwddiff(i8*, double*, double*, i64, double, double)
48+
declare double @__enzyme_fwddiff2(i8*, double*, double*, i64)
49+
50+
!llvm.module.flags = !{!0}
51+
!llvm.ident = !{!1}
52+
53+
!0 = !{i32 1, !"wchar_size", i32 4}
54+
!1 = !{!"Ubuntu clang version 10.0.1-++20200809072545+ef32c611aa2-1~exp1~20200809173142.193"}
55+
!2 = !{!3, !3, i64 0}
56+
!3 = !{!"double", !4, i64 0}
57+
!4 = !{!"omnipotent char", !5, i64 0}
58+
!5 = !{!"Simple C/C++ TBAA"}
59+
!6 = !{!7, !7, i64 0}
60+
!7 = !{!"any pointer", !4, i64 0}
61+
62+
63+
64+
; CHECK: define internal { double } @diffealldiv(double* nocapture readonly %A, double* nocapture %"A'", i64 %N, double %start, double %"start'")
65+
; CHECK-NEXT: entry:
66+
; CHECK-NEXT: br label %loop
67+
68+
; CHECK: loop: ; preds = %loop, %entry
69+
; CHECK-NEXT: %iv = phi i64 [ %iv.next, %loop ], [ 0, %entry ]
70+
; CHECK-NEXT: %reduce = phi double [ %start, %entry ], [ %div, %loop ]
71+
; CHECK-NEXT: %iv.next = add nuw nsw i64 %iv, 1
72+
; CHECK-NEXT: %"gep'ipg" = getelementptr inbounds double, double* %"A'", i64 %iv
73+
; CHECK-NEXT: %gep = getelementptr inbounds double, double* %A, i64 %iv
74+
; CHECK-NEXT: %ld = load double, double* %gep, align 8, !tbaa !2
75+
; CHECK-NEXT: %0 = load double, double* %"gep'ipg"
76+
; CHECK-NEXT: %div = fdiv double %reduce, %ld
77+
; CHECK-NEXT: %1 = fmul fast double %reduce, %0
78+
; CHECK-NEXT: %2 = fsub fast double 0.000000e+00, %1
79+
; CHECK-NEXT: %3 = fmul fast double %ld, %ld
80+
; CHECK-NEXT: %4 = fdiv fast double %2, %3
81+
; CHECK-NEXT: %cmp = icmp eq i64 %iv.next, %N
82+
; CHECK-NEXT: br i1 %cmp, label %end, label %loop
83+
84+
; CHECK: end: ; preds = %loop
85+
; CHECK-NEXT: %5 = insertvalue { double } undef, double %4, 0
86+
; CHECK-NEXT: ret { double } %5
87+
; CHECK-NEXT: }
88+
89+
90+
91+
92+
; CHECK: define internal { double } @diffealldiv2(double* nocapture readonly %A, double* nocapture %"A'", i64 %N)
93+
; CHECK-NEXT: entry:
94+
; CHECK-NEXT: br label %loop
95+
96+
; CHECK: loop: ; preds = %loop, %entry
97+
; CHECK-NEXT: %iv = phi i64 [ %iv.next, %loop ], [ 0, %entry ]
98+
; CHECK-NEXT: %reduce = phi double [ 2.000000e+00, %entry ], [ %div, %loop ]
99+
; CHECK-NEXT: %iv.next = add nuw nsw i64 %iv, 1
100+
; CHECK-NEXT: %"gep'ipg" = getelementptr inbounds double, double* %"A'", i64 %iv
101+
; CHECK-NEXT: %gep = getelementptr inbounds double, double* %A, i64 %iv
102+
; CHECK-NEXT: %ld = load double, double* %gep, align 8, !tbaa !2
103+
; CHECK-NEXT: %0 = load double, double* %"gep'ipg"
104+
; CHECK-NEXT: %div = fdiv double %reduce, %ld
105+
; CHECK-NEXT: %1 = fmul fast double %reduce, %0
106+
; CHECK-NEXT: %2 = fsub fast double 0.000000e+00, %1
107+
; CHECK-NEXT: %3 = fmul fast double %ld, %ld
108+
; CHECK-NEXT: %4 = fdiv fast double %2, %3
109+
; CHECK-NEXT: %cmp = icmp eq i64 %iv.next, %N
110+
; CHECK-NEXT: br i1 %cmp, label %end, label %loop
111+
112+
; CHECK: end: ; preds = %loop
113+
; CHECK-NEXT: %5 = insertvalue { double } undef, double %4, 0
114+
; CHECK-NEXT: ret { double } %5
115+
; CHECK-NEXT: }
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -early-cse -simplifycfg -S | FileCheck %s
2+
3+
; Function Attrs: noinline nounwind uwtable
4+
; Function under test: loads *%a and *%b, selects the POINTER %a when
; *%a > *%b (ordered compare) and %b otherwise, then loads and returns the
; float behind the selected pointer. The select here is pointer-typed.
define dso_local float @man_max(float* %a, float* %b) #0 {
5+
entry:
6+
%0 = load float, float* %a, align 4
7+
%1 = load float, float* %b, align 4
8+
%cmp = fcmp ogt float %0, %1
9+
%a.b = select i1 %cmp, float* %a, float* %b
10+
%retval.0 = load float, float* %a.b, align 4
11+
ret float %retval.0
12+
}
13+
14+
; Driver: calls @__enzyme_fwddiff.f64 on @man_max, passing each pointer
; argument followed by its companion pointer (%da, %db). The returned float is
; discarded.
define void @dman_max(float* %a, float* %da, float* %b, float* %db) {
15+
entry:
16+
call float (...) @__enzyme_fwddiff.f64(float (float*, float*)* @man_max, float* %a, float* %da, float* %b, float* %db)
17+
ret void
18+
}
19+
20+
declare float @__enzyme_fwddiff.f64(...)
21+
22+
attributes #0 = { noinline }
23+
24+
25+
; CHECK: define internal { float } @diffeman_max(float* %a, float* %"a'", float* %b, float* %"b'")
26+
; CHECK-NEXT: entry:
27+
; CHECK-NEXT: %0 = load float, float* %a, align 4
28+
; CHECK-NEXT: %1 = load float, float* %b, align 4
29+
; CHECK-NEXT: %cmp = fcmp ogt float %0, %1
30+
; CHECK-NEXT: %"a.b'ipse" = select i1 %cmp, float* %"a'", float* %"b'"
31+
; CHECK-NEXT: %2 = load float, float* %"a.b'ipse"
32+
; CHECK-NEXT: %3 = insertvalue { float } undef, float %2, 0
33+
; CHECK-NEXT: ret { float } %3
34+
; CHECK-NEXT: }

enzyme/test/Enzyme/ForwardMode/max.ll

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -simplifycfg -S | FileCheck %s
2+
3+
; Function Attrs: norecurse nounwind readnone uwtable
4+
; Function under test: returns %x when %x > %y (fast-math ordered compare),
; and %y otherwise. Both arms of the select are function arguments.
define dso_local double @max(double %x, double %y) #0 {
5+
entry:
6+
%cmp = fcmp fast ogt double %x, %y
7+
%cond = select i1 %cmp, double %x, double %y
8+
ret double %cond
9+
}
10+
11+
; Function Attrs: nounwind uwtable
12+
; Driver: calls @__enzyme_fwddiff on @max, following each of %x and %y with
; the constant 1.0 (presumably the derivative seeds; compare the
; (%x, %"x'", %y, %"y'") signature in the CHECK lines), and returns the result.
define dso_local double @test_derivative(double %x, double %y) local_unnamed_addr #1 {
13+
entry:
14+
%0 = tail call double (double (double, double)*, ...) @__enzyme_fwddiff(double (double, double)* nonnull @max, double %x, double 1.0, double %y, double 1.0)
15+
ret double %0
16+
}
17+
18+
; Function Attrs: nounwind
19+
declare double @__enzyme_fwddiff(double (double, double)*, ...)
20+
21+
22+
; CHECK: define internal { double } @diffemax(double %x, double %"x'", double %y, double %"y'")
23+
; CHECK-NEXT: entry:
24+
; CHECK-NEXT: %cmp = fcmp fast ogt double %x, %y
25+
; CHECK-NEXT: %0 = select {{(fast )?}}i1 %cmp, double %"x'", double %"y'"
26+
; CHECK-NEXT: %1 = insertvalue { double } undef, double %0, 0
27+
; CHECK-NEXT: ret { double } %1
28+
; CHECK-NEXT: }

0 commit comments

Comments
 (0)