Skip to content

Commit 9c10f97

Browse files
authored
Forward Mode CallInst (rust-lang#297)
* add test
* implement call inst
* remove calls to lookup
1 parent 9c979f6 commit 9c10f97

File tree

13 files changed

+1015
-8
lines changed

13 files changed

+1015
-8
lines changed

enzyme/Enzyme/AdjointGenerator.h

Lines changed: 107 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7093,10 +7093,115 @@ class AdjointGenerator
70937093
return;
70947094
}
70957095

7096-
bool modifyPrimal = shouldAugmentCall(orig, gutils, TR);
7097-
70987096
bool foreignFunction = called == nullptr || called->empty();
70997097

7098+
FnTypeInfo nextTypeInfo(called);
7099+
7100+
if (called) {
7101+
nextTypeInfo = TR.getCallInfo(*orig, *called);
7102+
}
7103+
7104+
if (Mode == DerivativeMode::ForwardMode) {
7105+
IRBuilder<> Builder2(&call);
7106+
getForwardBuilder(Builder2);
7107+
7108+
bool retUsed = subretused;
7109+
7110+
SmallVector<Value *, 8> args;
7111+
std::vector<DIFFE_TYPE> argsInverted;
7112+
std::map<int, Type *> gradByVal;
7113+
7114+
for (unsigned i = 0; i < orig->getNumArgOperands(); ++i) {
7115+
7116+
auto argi = gutils->getNewFromOriginal(orig->getArgOperand(i));
7117+
7118+
#if LLVM_VERSION_MAJOR >= 9
7119+
if (orig->isByValArgument(i)) {
7120+
gradByVal[args.size()] = orig->getParamByValType(i);
7121+
}
7122+
#endif
7123+
args.push_back(argi);
7124+
7125+
if (gutils->isConstantValue(orig->getArgOperand(i)) &&
7126+
!foreignFunction) {
7127+
argsInverted.push_back(DIFFE_TYPE::CONSTANT);
7128+
continue;
7129+
}
7130+
7131+
auto argType = argi->getType();
7132+
7133+
if (!argType->isFPOrFPVectorTy() &&
7134+
(TR.query(orig->getArgOperand(i)).Inner0().isPossiblePointer() ||
7135+
foreignFunction)) {
7136+
DIFFE_TYPE ty = DIFFE_TYPE::DUP_ARG;
7137+
if (argType->isPointerTy()) {
7138+
#if LLVM_VERSION_MAJOR >= 12
7139+
auto at = getUnderlyingObject(orig->getArgOperand(i), 100);
7140+
#else
7141+
auto at = GetUnderlyingObject(
7142+
orig->getArgOperand(i),
7143+
gutils->oldFunc->getParent()->getDataLayout(), 100);
7144+
#endif
7145+
if (auto arg = dyn_cast<Argument>(at)) {
7146+
if (constant_args[arg->getArgNo()] == DIFFE_TYPE::DUP_NONEED) {
7147+
ty = DIFFE_TYPE::DUP_NONEED;
7148+
}
7149+
}
7150+
}
7151+
args.push_back(
7152+
gutils->invertPointerM(orig->getArgOperand(i), Builder2));
7153+
argsInverted.push_back(ty);
7154+
7155+
// Note: whatType sometimes mistakenly says something should be
7156+
// constant [because it is composed of integer pointers alone]
7157+
assert(whatType(argType, Mode) == DIFFE_TYPE::DUP_ARG ||
7158+
whatType(argType, Mode) == DIFFE_TYPE::CONSTANT);
7159+
} else {
7160+
if (foreignFunction)
7161+
assert(!argType->isIntOrIntVectorTy());
7162+
7163+
args.push_back(diffe(orig->getArgOperand(i), Builder2));
7164+
argsInverted.push_back(DIFFE_TYPE::DUP_ARG);
7165+
}
7166+
}
7167+
7168+
auto newcalled = gutils->Logic.CreatePrimalAndGradient(
7169+
cast<Function>(called), subretType, argsInverted, gutils->TLI,
7170+
TR.analyzer.interprocedural, /*returnValue*/ retUsed,
7171+
/*subdretptr*/ false, DerivativeMode::ForwardMode, nullptr,
7172+
nextTypeInfo, uncacheable_args, nullptr,
7173+
/*AtomicAdd*/ gutils->AtomicAdd);
7174+
7175+
assert(newcalled);
7176+
FunctionType *FT = cast<FunctionType>(
7177+
cast<PointerType>(newcalled->getType())->getElementType());
7178+
7179+
CallInst *diffes = Builder2.CreateCall(FT, newcalled, args);
7180+
diffes->setCallingConv(orig->getCallingConv());
7181+
diffes->setDebugLoc(gutils->getNewFromOriginal(orig->getDebugLoc()));
7182+
#if LLVM_VERSION_MAJOR >= 9
7183+
for (auto pair : gradByVal) {
7184+
diffes->addParamAttr(
7185+
pair.first,
7186+
Attribute::getWithByValType(diffes->getContext(), pair.second));
7187+
}
7188+
#endif
7189+
7190+
if (!gutils->isConstantValue(&call)) {
7191+
unsigned structidx = retUsed ? 1 : 0;
7192+
Value *diffe = Builder2.CreateExtractValue(diffes, {structidx});
7193+
setDiffe(&call, diffe, Builder2);
7194+
}
7195+
7196+
if (!subretused) {
7197+
eraseIfUnused(*orig, /*erase*/ true, /*check*/ false);
7198+
}
7199+
7200+
return;
7201+
}
7202+
7203+
bool modifyPrimal = shouldAugmentCall(orig, gutils, TR);
7204+
71007205
SmallVector<Value *, 8> args;
71017206
SmallVector<Value *, 8> pre_args;
71027207
std::vector<DIFFE_TYPE> argsInverted;
@@ -7202,12 +7307,6 @@ class AdjointGenerator
72027307
CallInst *augmentcall = nullptr;
72037308
Value *cachereplace = nullptr;
72047309

7205-
FnTypeInfo nextTypeInfo(called);
7206-
7207-
if (called) {
7208-
nextTypeInfo = TR.getCallInfo(*orig, *called);
7209-
}
7210-
72117310
// llvm::Optional<std::map<std::pair<Instruction*, std::string>,
72127311
// unsigned>> sub_index_map;
72137312
Optional<int> tapeIdx;

enzyme/Enzyme/EnzymeLogic.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2201,6 +2201,8 @@ void createTerminator(DiffeGradientUtils *gutils,
22012201

22022202
if (gutils->newFunc->getReturnType()->isVoidTy()) {
22032203
assert(retargs.size() == 0);
2204+
gutils->erase(gutils->getNewFromOriginal(inst));
2205+
nBuilder.CreateRetVoid();
22042206
return;
22052207
}
22062208

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -adce -correlated-propagation -simplifycfg -S | FileCheck %s
2+
3+
; Function Attrs: noinline norecurse nounwind uwtable
4+
define dso_local zeroext i1 @metasubf(double* nocapture %x) local_unnamed_addr #0 {
5+
entry:
6+
%arrayidx = getelementptr inbounds double, double* %x, i64 1
7+
store double 3.000000e+00, double* %arrayidx, align 8
8+
%0 = load double, double* %x, align 8
9+
%cmp = fcmp fast oeq double %0, 2.000000e+00
10+
ret i1 %cmp
11+
}
12+
13+
; Function Attrs: noinline norecurse nounwind uwtable
14+
define dso_local zeroext i1 @subf(double* nocapture %x) local_unnamed_addr #0 {
15+
entry:
16+
%0 = load double, double* %x, align 8
17+
%mul = fmul fast double %0, 2.000000e+00
18+
store double %mul, double* %x, align 8
19+
%call = tail call zeroext i1 @metasubf(double* %x)
20+
ret i1 %call
21+
}
22+
23+
; Function Attrs: noinline norecurse nounwind uwtable
24+
define dso_local void @f(double* nocapture %x) #0 {
25+
entry:
26+
%call = tail call zeroext i1 @subf(double* %x)
27+
store double 2.000000e+00, double* %x, align 8
28+
ret void
29+
}
30+
31+
; Function Attrs: noinline nounwind uwtable
32+
define dso_local double @dsumsquare(double* %x, double* %xp) local_unnamed_addr #1 {
33+
entry:
34+
%call = tail call fast double @__enzyme_fwddiff(i8* bitcast (void (double*)* @f to i8*), double* %x, double* %xp)
35+
ret double %call
36+
}
37+
38+
declare dso_local double @__enzyme_fwddiff(i8*, double*, double*) local_unnamed_addr
39+
40+
attributes #0 = { noinline norecurse nounwind uwtable }
41+
attributes #1 = { noinline nounwind uwtable }
42+
43+
; CHECK: define internal {{(dso_local )?}}void @diffef(double* nocapture %x, double* nocapture %"x'")
44+
; CHECK-NEXT: entry:
45+
; CHECK-NEXT: call void @diffesubf(double* %x, double* %"x'")
46+
; CHECK-NEXT: store double 2.000000e+00, double* %x
47+
; CHECK-NEXT: store double 0.000000e+00, double* %"x'"
48+
; CHECK-NEXT: ret void
49+
; CHECK-NEXT: }
50+
51+
; CHECK: define internal {{(dso_local )?}}void @diffesubf(double* nocapture %x, double* nocapture %"x'")
52+
; CHECK-NEXT: entry:
53+
; CHECK-NEXT: %0 = load double, double* %x
54+
; CHECK-NEXT: %1 = load double, double* %"x'"
55+
; CHECK-NEXT: %mul = fmul fast double %0, 2.000000e+00
56+
; CHECK-NEXT: %2 = fmul fast double %1, 2.000000e+00
57+
; CHECK-NEXT: store double %mul, double* %x
58+
; CHECK-NEXT: store double %2, double* %"x'"
59+
; CHECK-NEXT: call void @diffemetasubf(double* %x, double* %"x'")
60+
; CHECK-NEXT: ret void
61+
; CHECK-NEXT: }
62+
63+
; CHECK: define internal {{(dso_local )?}}void @diffemetasubf(double* nocapture %x, double* nocapture %"x'")
64+
; CHECK-NEXT: entry:
65+
; CHECK-NEXT: %"arrayidx'ipg" = getelementptr inbounds double, double* %"x'", i64 1
66+
; CHECK-NEXT: %arrayidx = getelementptr inbounds double, double* %x, i64 1
67+
; CHECK-NEXT: store double 3.000000e+00, double* %arrayidx
68+
; CHECK-NEXT: store double 0.000000e+00, double* %"arrayidx'ipg"
69+
; CHECK-NEXT: ret void
70+
; CHECK-NEXT: }
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -adce -correlated-propagation -simplifycfg -S | FileCheck %s
2+
3+
; Function Attrs: noinline norecurse nounwind uwtable
4+
define dso_local zeroext i1 @metasubf(double* nocapture %x) local_unnamed_addr #0 {
5+
entry:
6+
%arrayidx = getelementptr inbounds double, double* %x, i64 1
7+
store double 3.000000e+00, double* %arrayidx, align 8
8+
%0 = load double, double* %x, align 8
9+
%cmp = fcmp fast oeq double %0, 2.000000e+00
10+
ret i1 %cmp
11+
}
12+
13+
; Function Attrs: noinline norecurse nounwind uwtable
14+
define dso_local zeroext i1 @othermetasubf(double* nocapture %x) local_unnamed_addr #0 {
15+
entry:
16+
%arrayidx = getelementptr inbounds double, double* %x, i64 1
17+
store double 4.000000e+00, double* %arrayidx, align 8
18+
%0 = load double, double* %x, align 8
19+
%cmp = fcmp fast oeq double %0, 3.000000e+00
20+
ret i1 %cmp
21+
}
22+
23+
; Function Attrs: noinline norecurse nounwind uwtable
24+
define dso_local zeroext i1 @subf(double* nocapture %x) local_unnamed_addr #0 {
25+
entry:
26+
%0 = load double, double* %x, align 8
27+
%mul = fmul fast double %0, 2.000000e+00
28+
store double %mul, double* %x, align 8
29+
%call = tail call zeroext i1 @metasubf(double* %x)
30+
%call1 = tail call zeroext i1 @othermetasubf(double* %x)
31+
ret i1 %call1
32+
}
33+
34+
; Function Attrs: noinline norecurse nounwind uwtable
35+
define dso_local void @f(double* nocapture %x) #0 {
36+
entry:
37+
%call = tail call zeroext i1 @subf(double* %x)
38+
store double 2.000000e+00, double* %x, align 8
39+
ret void
40+
}
41+
42+
; Function Attrs: noinline nounwind uwtable
43+
define dso_local double @dsumsquare(double* %x, double* %xp) local_unnamed_addr #1 {
44+
entry:
45+
%call = tail call fast double @__enzyme_fwddiff(i8* bitcast (void (double*)* @f to i8*), double* %x, double* %xp)
46+
ret double %call
47+
}
48+
49+
declare dso_local double @__enzyme_fwddiff(i8*, double*, double*) local_unnamed_addr
50+
51+
52+
; CHECK: define internal {{(dso_local )?}}void @diffef(double* nocapture %x, double* nocapture %"x'")
53+
; CHECK-NEXT: entry:
54+
; CHECK-NEXT: call void @diffesubf(double* %x, double* %"x'")
55+
; CHECK-NEXT: store double 2.000000e+00, double* %x, align 8
56+
; CHECK-NEXT: store double 0.000000e+00, double* %"x'", align 8
57+
; CHECK-NEXT: ret void
58+
; CHECK-NEXT: }
59+
60+
; CHECK: define internal {{(dso_local )?}}void @diffesubf(double* nocapture %x, double* nocapture %"x'")
61+
; CHECK-NEXT: entry:
62+
; CHECK-NEXT: %0 = load double, double* %x, align 8
63+
; CHECK-NEXT: %1 = load double, double* %"x'"
64+
; CHECK-NEXT: %mul = fmul fast double %0, 2.000000e+00
65+
; CHECK-NEXT: %2 = fmul fast double %1, 2.000000e+00
66+
; CHECK-NEXT: store double %mul, double* %x, align 8
67+
; CHECK-NEXT: store double %2, double* %"x'", align 8
68+
; CHECK-NEXT: call void @diffemetasubf(double* %x, double* %"x'")
69+
; CHECK-NEXT: call void @diffeothermetasubf(double* %x, double* %"x'")
70+
; CHECK-NEXT: ret void
71+
; CHECK-NEXT: }
72+
73+
; CHECK: define internal {{(dso_local )?}}void @diffemetasubf(double* nocapture %x, double* nocapture %"x'")
74+
; CHECK-NEXT: entry:
75+
; CHECK-NEXT: %"arrayidx'ipg" = getelementptr inbounds double, double* %"x'", i64 1
76+
; CHECK-NEXT: %arrayidx = getelementptr inbounds double, double* %x, i64 1
77+
; CHECK-NEXT: store double 3.000000e+00, double* %arrayidx, align 8
78+
; CHECK-NEXT: store double 0.000000e+00, double* %"arrayidx'ipg", align 8
79+
; CHECK-NEXT: ret void
80+
; CHECK-NEXT: }
81+
82+
; CHECK: define internal {{(dso_local )?}}void @diffeothermetasubf(double* nocapture %x, double* nocapture %"x'")
83+
; CHECK-NEXT: entry:
84+
; CHECK-NEXT: %"arrayidx'ipg" = getelementptr inbounds double, double* %"x'", i64 1
85+
; CHECK-NEXT: %arrayidx = getelementptr inbounds double, double* %x, i64 1
86+
; CHECK-NEXT: store double 4.000000e+00, double* %arrayidx, align 8
87+
; CHECK-NEXT: store double 0.000000e+00, double* %"arrayidx'ipg", align 8
88+
; CHECK-NEXT: ret void
89+
; CHECK-NEXT: }
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -adce -correlated-propagation -simplifycfg -S | FileCheck %s
2+
3+
; Function Attrs: noinline norecurse nounwind uwtable
4+
define dso_local zeroext i1 @metasubf(double* nocapture %x) local_unnamed_addr #0 {
5+
entry:
6+
%arrayidx = getelementptr inbounds double, double* %x, i64 1
7+
store double 3.000000e+00, double* %arrayidx, align 8
8+
%0 = load double, double* %x, align 8
9+
%cmp = fcmp fast oeq double %0, 2.000000e+00
10+
ret i1 %cmp
11+
}
12+
13+
; Function Attrs: noinline norecurse nounwind uwtable
14+
define dso_local zeroext i1 @othermetasubf(double* nocapture %x) local_unnamed_addr #0 {
15+
entry:
16+
%arrayidx = getelementptr inbounds double, double* %x, i64 1
17+
store double 4.000000e+00, double* %arrayidx, align 8
18+
%0 = load double, double* %x, align 8
19+
%cmp = fcmp fast oeq double %0, 3.000000e+00
20+
ret i1 %cmp
21+
}
22+
23+
; Function Attrs: noinline norecurse nounwind uwtable
24+
define dso_local void @subf(double* nocapture %x) local_unnamed_addr #0 {
25+
entry:
26+
%0 = load double, double* %x, align 8
27+
%mul = fmul fast double %0, 2.000000e+00
28+
store double %mul, double* %x, align 8
29+
%call = tail call zeroext i1 @metasubf(double* %x)
30+
%call1 = tail call zeroext i1 @othermetasubf(double* %x)
31+
ret void
32+
}
33+
34+
; Function Attrs: noinline norecurse nounwind uwtable
35+
define dso_local void @f(double* nocapture %x) #0 {
36+
entry:
37+
tail call void @subf(double* %x)
38+
store double 2.000000e+00, double* %x, align 8
39+
ret void
40+
}
41+
42+
; Function Attrs: noinline nounwind uwtable
43+
define dso_local double @dsumsquare(double* %x, double* %xp) local_unnamed_addr #1 {
44+
entry:
45+
%call = tail call fast double @__enzyme_fwddiff(i8* bitcast (void (double*)* @f to i8*), double* %x, double* %xp)
46+
ret double %call
47+
}
48+
49+
declare dso_local double @__enzyme_fwddiff(i8*, double*, double*) local_unnamed_addr
50+
51+
; CHECK: define internal {{(dso_local )?}}void @diffef(double* nocapture %x, double* nocapture %"x'")
52+
; CHECK-NEXT: entry:
53+
; CHECK-NEXT: call void @diffesubf(double* %x, double* %"x'")
54+
; CHECK-NEXT: store double 2.000000e+00, double* %x, align 8
55+
; CHECK-NEXT: store double 0.000000e+00, double* %"x'", align 8
56+
; CHECK-NEXT: ret void
57+
; CHECK-NEXT: }
58+
59+
60+
; CHECK: define internal {{(dso_local )?}}void @diffesubf(double* nocapture %x, double* nocapture %"x'")
61+
; CHECK-NEXT: entry:
62+
; CHECK-NEXT: %0 = load double, double* %x, align 8
63+
; CHECK-NEXT: %1 = load double, double* %"x'"
64+
; CHECK-NEXT: %mul = fmul fast double %0, 2.000000e+00
65+
; CHECK-NEXT: %2 = fmul fast double %1, 2.000000e+00
66+
; CHECK-NEXT: store double %mul, double* %x, align 8
67+
; CHECK-NEXT: store double %2, double* %"x'", align 8
68+
; CHECK-NEXT: call void @diffemetasubf(double* %x, double* %"x'")
69+
; CHECK-NEXT: call void @diffeothermetasubf(double* %x, double* %"x'")
70+
; CHECK-NEXT: ret void
71+
; CHECK-NEXT: }
72+
73+
; CHECK: define internal {{(dso_local )?}}void @diffemetasubf(double* nocapture %x, double* nocapture %"x'")
74+
; CHECK-NEXT: entry:
75+
; CHECK-NEXT: %"arrayidx'ipg" = getelementptr inbounds double, double* %"x'", i64 1
76+
; CHECK-NEXT: %arrayidx = getelementptr inbounds double, double* %x, i64 1
77+
; CHECK-NEXT: store double 3.000000e+00, double* %arrayidx, align 8
78+
; CHECK-NEXT: store double 0.000000e+00, double* %"arrayidx'ipg", align 8
79+
; CHECK-NEXT: ret void
80+
; CHECK-NEXT: }
81+
82+
; CHECK: define internal {{(dso_local )?}}void @diffeothermetasubf(double* nocapture %x, double* nocapture %"x'")
83+
; CHECK-NEXT: entry:
84+
; CHECK-NEXT: %"arrayidx'ipg" = getelementptr inbounds double, double* %"x'", i64 1
85+
; CHECK-NEXT: %arrayidx = getelementptr inbounds double, double* %x, i64 1
86+
; CHECK-NEXT: store double 4.000000e+00, double* %arrayidx, align 8
87+
; CHECK-NEXT: store double 0.000000e+00, double* %"arrayidx'ipg", align 8
88+
; CHECK-NEXT: ret void
89+
; CHECK-NEXT: }
90+

0 commit comments

Comments
 (0)