Skip to content

Commit 692cc07

Browse files
committed
Automatic derivative registration
1 parent 4a035cd commit 692cc07

File tree

2 files changed

+133
-131
lines changed

2 files changed

+133
-131
lines changed

enzyme/Enzyme/AdjointGenerator.h

Lines changed: 66 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -2814,11 +2814,19 @@ class AdjointGenerator
28142814

28152815
Function *called = orig->getCalledFunction();
28162816

2817+
StringRef funcName = "";
2818+
if (called) {
2819+
if (called->hasFnAttribute("enzyme_math"))
2820+
funcName = called->getFnAttribute("enzyme_math").getValueAsString();
2821+
else
2822+
funcName = called->getName();
2823+
}
2824+
28172825
if (Mode != DerivativeMode::ReverseModePrimal && called) {
2818-
if (called->getName() == "__kmpc_for_static_init_4" ||
2819-
called->getName() == "__kmpc_for_static_init_4u" ||
2820-
called->getName() == "__kmpc_for_static_init_8" ||
2821-
called->getName() == "__kmpc_for_static_init_8u") {
2826+
if (funcName == "__kmpc_for_static_init_4" ||
2827+
funcName == "__kmpc_for_static_init_4u" ||
2828+
funcName == "__kmpc_for_static_init_8" ||
2829+
funcName == "__kmpc_for_static_init_8u") {
28222830
IRBuilder<> Builder2(call.getParent());
28232831
getReverseBuilder(Builder2);
28242832
auto fini = called->getParent()->getFunction("__kmpc_for_static_fini");
@@ -2834,8 +2842,7 @@ class AdjointGenerator
28342842
}
28352843

28362844
// MPI send / recv can only send float/integers
2837-
if (called && (called->getName() == "MPI_Isend" ||
2838-
called->getName() == "MPI_Irecv")) {
2845+
if (funcName == "MPI_Isend" || funcName == "MPI_Irecv") {
28392846
Value *firstallocation = nullptr;
28402847
if (Mode == DerivativeMode::ReverseModePrimal ||
28412848
Mode == DerivativeMode::ReverseModeCombined) {
@@ -2853,7 +2860,7 @@ class AdjointGenerator
28532860
/*5 */ Type::getInt8PtrTy(call.getContext()),
28542861
/*6 */ Type::getInt8Ty(call.getContext()),
28552862
};
2856-
auto impi = StructType::get(called->getContext(), types, false);
2863+
auto impi = StructType::get(call.getContext(), types, false);
28572864

28582865
Value *impialloc = CallInst::CreateMalloc(
28592866
gutils->getNewFromOriginal(&call), i64, impi,
@@ -2869,7 +2876,7 @@ class AdjointGenerator
28692876
d_req, PointerType::getUnqual(impialloc->getType()));
28702877
BuilderZ.CreateStore(impialloc, d_req);
28712878

2872-
if (called->getName() == "MPI_Isend") {
2879+
if (funcName == "MPI_Isend") {
28732880
Value *tysize = MPI_TYPE_SIZE(
28742881
gutils->getNewFromOriginal(call.getOperand(2)), BuilderZ);
28752882

@@ -2935,7 +2942,7 @@ class AdjointGenerator
29352942

29362943
BuilderZ.CreateStore(
29372944
ConstantInt::get(Type::getInt8Ty(impialloc->getContext()),
2938-
(called->getName() == "MPI_Isend")
2945+
(funcName == "MPI_Isend")
29392946
? (int)MPI_CallType::ISEND
29402947
: (int)MPI_CallType::IRECV),
29412948
BuilderZ.CreateInBoundsGEP(impialloc,
@@ -2969,7 +2976,7 @@ class AdjointGenerator
29692976
Type::getInt64Ty(Builder2.getContext())),
29702977
"", true, true);
29712978

2972-
if (called->getName() == "MPI_Irecv") {
2979+
if (funcName == "MPI_Irecv") {
29732980
auto val_arg =
29742981
ConstantInt::get(Type::getInt8Ty(Builder2.getContext()), 0);
29752982
auto volatile_arg = ConstantInt::getFalse(Builder2.getContext());
@@ -2989,7 +2996,7 @@ class AdjointGenerator
29892996
tys),
29902997
nargs));
29912998
memset->addParamAttr(0, Attribute::NonNull);
2992-
} else if (called->getName() == "MPI_Isend") {
2999+
} else if (funcName == "MPI_Isend") {
29933000
Value *shadow = gutils->invertPointerM(call.getOperand(0), Builder2);
29943001
if (Mode == DerivativeMode::ReverseModeCombined)
29953002
firstallocation = lookup(firstallocation, Builder2);
@@ -3103,7 +3110,7 @@ class AdjointGenerator
31033110
return;
31043111
}
31053112

3106-
if (called && called->getName() == "MPI_Wait") {
3113+
if (funcName == "MPI_Wait") {
31073114
if (Mode == DerivativeMode::ReverseModeGradient ||
31083115
Mode == DerivativeMode::ReverseModeCombined) {
31093116
IRBuilder<> Builder2(call.getParent());
@@ -3121,7 +3128,7 @@ class AdjointGenerator
31213128
/*5 */ Type::getInt8PtrTy(call.getContext()),
31223129
/*6 */ Type::getInt8Ty(call.getContext()),
31233130
};
3124-
auto impi = StructType::get(called->getContext(), types, false);
3131+
auto impi = StructType::get(call.getContext(), types, false);
31253132

31263133
Value *d_reqp = Builder2.CreateLoad(Builder2.CreatePointerCast(
31273134
d_req, PointerType::getUnqual(PointerType::getUnqual(impi))));
@@ -3150,8 +3157,7 @@ class AdjointGenerator
31503157
return;
31513158
}
31523159

3153-
if (called &&
3154-
(called->getName() == "MPI_Send" || called->getName() == "MPI_Ssend")) {
3160+
if (funcName == "MPI_Send" || funcName == "MPI_Ssend") {
31553161
if (Mode == DerivativeMode::ReverseModeGradient ||
31563162
Mode == DerivativeMode::ReverseModeCombined) {
31573163
IRBuilder<> Builder2(call.getParent());
@@ -3305,7 +3311,7 @@ class AdjointGenerator
33053311
return;
33063312
}
33073313

3308-
if (called && called->getName() == "MPI_Recv") {
3314+
if (funcName == "MPI_Recv") {
33093315
if (Mode == DerivativeMode::ReverseModeGradient ||
33103316
Mode == DerivativeMode::ReverseModeCombined) {
33113317
IRBuilder<> Builder2(call.getParent());
@@ -3368,10 +3374,9 @@ class AdjointGenerator
33683374
}
33693375
}
33703376

3371-
if (called &&
3372-
(called->getName() == "printf" || called->getName() == "puts" ||
3373-
called->getName().startswith("_ZN3std2io5stdio6_print") ||
3374-
called->getName().startswith("_ZN4core3fmt"))) {
3377+
if (funcName == "printf" || funcName == "puts" ||
3378+
funcName.startswith("_ZN3std2io5stdio6_print") ||
3379+
funcName.startswith("_ZN4core3fmt")) {
33753380
if (Mode == DerivativeMode::ReverseModeGradient) {
33763381
eraseIfUnused(*orig, /*erase*/ true, /*check*/ false);
33773382
}
@@ -3387,14 +3392,11 @@ class AdjointGenerator
33873392

33883393
// Handle lgamma, safe to recompute so no store/change to forward
33893394
if (called) {
3390-
auto n = called->getName();
3391-
if (called->getName() == "__kmpc_fork_call") {
3395+
if (funcName == "__kmpc_fork_call") {
33923396
visitOMPCall(call);
33933397
return;
33943398
}
3395-
if (called &&
3396-
(called->getName() == "asin" || called->getName() == "asinf" ||
3397-
called->getName() == "asinl")) {
3399+
if (funcName == "asin" || funcName == "asinf" || funcName == "asinl") {
33983400
if (gutils->knownRecomputeHeuristic.find(orig) !=
33993401
gutils->knownRecomputeHeuristic.end()) {
34003402
if (!gutils->knownRecomputeHeuristic[orig]) {
@@ -3426,10 +3428,8 @@ class AdjointGenerator
34263428
return;
34273429
}
34283430

3429-
if (called &&
3430-
(called->getName() == "atan" || called->getName() == "atanf" ||
3431-
called->getName() == "atanl" ||
3432-
called->getName() == "__fd_atan_1")) {
3431+
if (funcName == "atan" || funcName == "atanf" || funcName == "atanl" ||
3432+
funcName == "__fd_atan_1") {
34333433
if (gutils->knownRecomputeHeuristic.find(orig) !=
34343434
gutils->knownRecomputeHeuristic.end()) {
34353435
if (!gutils->knownRecomputeHeuristic[orig]) {
@@ -3453,8 +3453,7 @@ class AdjointGenerator
34533453
return;
34543454
}
34553455

3456-
if (called &&
3457-
(called->getName() == "tanhf" || called->getName() == "tanh")) {
3456+
if (funcName == "tanhf" || funcName == "tanh") {
34583457
if (gutils->knownRecomputeHeuristic.find(orig) !=
34593458
gutils->knownRecomputeHeuristic.end()) {
34603459
if (!gutils->knownRecomputeHeuristic[orig]) {
@@ -3474,8 +3473,8 @@ class AdjointGenerator
34743473

34753474
SmallVector<Value *, 1> args = {x};
34763475
auto coshf = gutils->oldFunc->getParent()->getOrInsertFunction(
3477-
(called->getName() == "tanh") ? "cosh" : "coshf",
3478-
called->getFunctionType(), called->getAttributes());
3476+
(funcName == "tanh") ? "cosh" : "coshf", called->getFunctionType(),
3477+
called->getAttributes());
34793478
auto cal = cast<CallInst>(Builder2.CreateCall(coshf, args));
34803479
Value *dif0 = Builder2.CreateFDiv(diffe(orig, Builder2),
34813480
Builder2.CreateFMul(cal, cal));
@@ -3484,7 +3483,7 @@ class AdjointGenerator
34843483
return;
34853484
}
34863485

3487-
if (called->getName() == "coshf" || called->getName() == "cosh") {
3486+
if (funcName == "coshf" || funcName == "cosh") {
34883487
if (gutils->knownRecomputeHeuristic.find(orig) !=
34893488
gutils->knownRecomputeHeuristic.end()) {
34903489
if (!gutils->knownRecomputeHeuristic[orig]) {
@@ -3504,15 +3503,15 @@ class AdjointGenerator
35043503

35053504
SmallVector<Value *, 1> args = {x};
35063505
auto sinhf = gutils->oldFunc->getParent()->getOrInsertFunction(
3507-
(called->getName() == "cosh") ? "sinh" : "sinhf",
3508-
called->getFunctionType(), called->getAttributes());
3506+
(funcName == "cosh") ? "sinh" : "sinhf", called->getFunctionType(),
3507+
called->getAttributes());
35093508
auto cal = cast<CallInst>(Builder2.CreateCall(sinhf, args));
35103509
Value *dif0 = Builder2.CreateFMul(diffe(orig, Builder2), cal);
35113510
setDiffe(orig, Constant::getNullValue(orig->getType()), Builder2);
35123511
addToDiffe(orig->getArgOperand(0), dif0, Builder2, x->getType());
35133512
return;
35143513
}
3515-
if (called->getName() == "sinhf" || called->getName() == "sinh") {
3514+
if (funcName == "sinhf" || funcName == "sinh") {
35163515
if (gutils->knownRecomputeHeuristic.find(orig) !=
35173516
gutils->knownRecomputeHeuristic.end()) {
35183517
if (!gutils->knownRecomputeHeuristic[orig]) {
@@ -3532,8 +3531,8 @@ class AdjointGenerator
35323531

35333532
SmallVector<Value *, 1> args = {x};
35343533
auto sinhf = gutils->oldFunc->getParent()->getOrInsertFunction(
3535-
(called->getName() == "sinh") ? "cosh" : "coshf",
3536-
called->getFunctionType(), called->getAttributes());
3534+
(funcName == "sinh") ? "cosh" : "coshf", called->getFunctionType(),
3535+
called->getAttributes());
35373536
auto cal = cast<CallInst>(Builder2.CreateCall(sinhf, args));
35383537
Value *dif0 = Builder2.CreateFMul(diffe(orig, Builder2), cal);
35393538
setDiffe(orig, Constant::getNullValue(orig->getType()), Builder2);
@@ -3542,7 +3541,7 @@ class AdjointGenerator
35423541
}
35433542

35443543
if (called) {
3545-
if (called->getName() == "erf") {
3544+
if (funcName == "erf") {
35463545
if (gutils->knownRecomputeHeuristic.find(orig) !=
35473546
gutils->knownRecomputeHeuristic.end()) {
35483547
if (!gutils->knownRecomputeHeuristic[orig]) {
@@ -3576,7 +3575,7 @@ class AdjointGenerator
35763575
addToDiffe(orig->getArgOperand(0), cal, Builder2, x->getType());
35773576
return;
35783577
}
3579-
if (called->getName() == "erfi") {
3578+
if (funcName == "erfi") {
35803579
if (gutils->knownRecomputeHeuristic.find(orig) !=
35813580
gutils->knownRecomputeHeuristic.end()) {
35823581
if (!gutils->knownRecomputeHeuristic[orig]) {
@@ -3610,7 +3609,7 @@ class AdjointGenerator
36103609
addToDiffe(orig->getArgOperand(0), cal, Builder2, x->getType());
36113610
return;
36123611
}
3613-
if (called->getName() == "erfc") {
3612+
if (funcName == "erfc") {
36143613
if (gutils->knownRecomputeHeuristic.find(orig) !=
36153614
gutils->knownRecomputeHeuristic.end()) {
36163615
if (!gutils->knownRecomputeHeuristic[orig]) {
@@ -3645,8 +3644,8 @@ class AdjointGenerator
36453644
return;
36463645
}
36473646

3648-
if (called->getName() == "j0" || called->getName() == "y0" ||
3649-
called->getName() == "j0f" || called->getName() == "y0f") {
3647+
if (funcName == "j0" || funcName == "y0" || funcName == "j0f" ||
3648+
funcName == "y0f") {
36503649
if (gutils->knownRecomputeHeuristic.find(orig) !=
36513650
gutils->knownRecomputeHeuristic.end()) {
36523651
if (!gutils->knownRecomputeHeuristic[orig]) {
@@ -3667,9 +3666,8 @@ class AdjointGenerator
36673666

36683667
Value *dx = Builder2.CreateCall(
36693668
gutils->oldFunc->getParent()->getOrInsertFunction(
3670-
(called->getName()[0] == 'j')
3671-
? ((called->getName() == "j0") ? "j1" : "j1f")
3672-
: ((called->getName() == "y0") ? "y1" : "y1f"),
3669+
(funcName[0] == 'j') ? ((funcName == "j0") ? "j1" : "j1f")
3670+
: ((funcName == "y0") ? "y1" : "y1f"),
36733671
called->getFunctionType()),
36743672
std::vector<Value *>({x}));
36753673
dx = Builder2.CreateFNeg(dx);
@@ -3679,8 +3677,8 @@ class AdjointGenerator
36793677
return;
36803678
}
36813679

3682-
if (called->getName() == "j1" || called->getName() == "y1" ||
3683-
called->getName() == "j1f" || called->getName() == "y1f") {
3680+
if (funcName == "j1" || funcName == "y1" || funcName == "j1f" ||
3681+
funcName == "y1f") {
36843682
if (gutils->knownRecomputeHeuristic.find(orig) !=
36853683
gutils->knownRecomputeHeuristic.end()) {
36863684
if (!gutils->knownRecomputeHeuristic[orig]) {
@@ -3701,9 +3699,8 @@ class AdjointGenerator
37013699

37023700
Value *d0 = Builder2.CreateCall(
37033701
gutils->oldFunc->getParent()->getOrInsertFunction(
3704-
(called->getName()[0] == 'j')
3705-
? ((called->getName() == "j1") ? "j0" : "j0f")
3706-
: ((called->getName() == "y1") ? "y0" : "y0f"),
3702+
(funcName[0] == 'j') ? ((funcName == "j1") ? "j0" : "j0f")
3703+
: ((funcName == "y1") ? "y0" : "y0f"),
37073704
called->getFunctionType()),
37083705
std::vector<Value *>({x}));
37093706

@@ -3713,9 +3710,8 @@ class AdjointGenerator
37133710
auto FT2 = FunctionType::get(x->getType(), pargs, false);
37143711
Value *d2 = Builder2.CreateCall(
37153712
gutils->oldFunc->getParent()->getOrInsertFunction(
3716-
(called->getName()[0] == 'j')
3717-
? ((called->getName() == "j1") ? "jn" : "jnf")
3718-
: ((called->getName() == "y1") ? "yn" : "ynf"),
3713+
(funcName[0] == 'j') ? ((funcName == "j1") ? "jn" : "jnf")
3714+
: ((funcName == "y1") ? "yn" : "ynf"),
37193715
FT2),
37203716
std::vector<Value *>({ConstantInt::get(intType, 2), x}));
37213717
Value *dx = Builder2.CreateFSub(d0, d2);
@@ -3726,8 +3722,8 @@ class AdjointGenerator
37263722
return;
37273723
}
37283724

3729-
if (called->getName() == "jn" || called->getName() == "yn" ||
3730-
called->getName() == "jnf" || called->getName() == "ynf") {
3725+
if (funcName == "jn" || funcName == "yn" || funcName == "jnf" ||
3726+
funcName == "ynf") {
37313727
if (gutils->knownRecomputeHeuristic.find(orig) !=
37323728
gutils->knownRecomputeHeuristic.end()) {
37333729
if (!gutils->knownRecomputeHeuristic[orig]) {
@@ -3768,7 +3764,7 @@ class AdjointGenerator
37683764
return;
37693765
}
37703766

3771-
if (called->getName() == "julia.write_barrier") {
3767+
if (funcName == "julia.write_barrier") {
37723768
if (Mode == DerivativeMode::ReverseModeGradient) {
37733769
eraseIfUnused(*orig, /*erase*/ true, /*check*/ false);
37743770
return;
@@ -3788,7 +3784,7 @@ class AdjointGenerator
37883784
return;
37893785
}
37903786
Intrinsic::ID ID = Intrinsic::not_intrinsic;
3791-
if (isMemFreeLibMFunction(called->getName(), &ID)) {
3787+
if (isMemFreeLibMFunction(funcName, &ID)) {
37923788
if (Mode == DerivativeMode::ReverseModePrimal ||
37933789
gutils->isConstantInstruction(orig)) {
37943790
eraseIfUnused(*orig);
@@ -3804,7 +3800,7 @@ class AdjointGenerator
38043800
return;
38053801
}
38063802
}
3807-
if (called->getName() == "__fd_sincos_1") {
3803+
if (funcName == "__fd_sincos_1") {
38083804
if (gutils->knownRecomputeHeuristic.find(orig) !=
38093805
gutils->knownRecomputeHeuristic.end()) {
38103806
if (!gutils->knownRecomputeHeuristic[orig]) {
@@ -3847,8 +3843,7 @@ class AdjointGenerator
38473843
addToDiffe(orig->getArgOperand(0), dif0, Builder2, x->getType());
38483844
return;
38493845
}
3850-
if (called->getName() == "cabs" || called->getName() == "cabsf" ||
3851-
called->getName() == "cabsl") {
3846+
if (funcName == "cabs" || funcName == "cabsf" || funcName == "cabsl") {
38523847
if (gutils->knownRecomputeHeuristic.find(orig) !=
38533848
gutils->knownRecomputeHeuristic.end()) {
38543849
if (!gutils->knownRecomputeHeuristic[orig]) {
@@ -3891,8 +3886,8 @@ class AdjointGenerator
38913886
llvm_unreachable("unknown calling convention found for cabs");
38923887
}
38933888
}
3894-
if (called->getName() == "ldexp" || called->getName() == "ldexpf" ||
3895-
called->getName() == "ldexpl") {
3889+
if (funcName == "ldexp" || funcName == "ldexpf" ||
3890+
funcName == "ldexpl") {
38963891
if (gutils->knownRecomputeHeuristic.find(orig) !=
38973892
gutils->knownRecomputeHeuristic.end()) {
38983893
if (!gutils->knownRecomputeHeuristic[orig]) {
@@ -3923,10 +3918,11 @@ class AdjointGenerator
39233918
}
39243919
}
39253920

3926-
if (n == "lgamma" || n == "lgammaf" || n == "lgammal" ||
3927-
n == "lgamma_r" || n == "lgammaf_r" || n == "lgammal_r" ||
3928-
n == "__lgamma_r_finite" || n == "__lgammaf_r_finite" ||
3929-
n == "__lgammal_r_finite") {
3921+
if (funcName == "lgamma" || funcName == "lgammaf" ||
3922+
funcName == "lgammal" || funcName == "lgamma_r" ||
3923+
funcName == "lgammaf_r" || funcName == "lgammal_r" ||
3924+
funcName == "__lgamma_r_finite" || funcName == "__lgammaf_r_finite" ||
3925+
funcName == "__lgammal_r_finite") {
39303926
if (gutils->knownRecomputeHeuristic.find(orig) !=
39313927
gutils->knownRecomputeHeuristic.end()) {
39323928
if (!gutils->knownRecomputeHeuristic[orig]) {
@@ -4004,7 +4000,7 @@ class AdjointGenerator
40044000
return;
40054001
}
40064002

4007-
if (called && called->getName() == "julia.pointer_from_objref") {
4003+
if (funcName == "julia.pointer_from_objref") {
40084004
eraseIfUnused(*orig);
40094005
if (gutils->isConstantValue(orig))
40104006
return;
@@ -4024,7 +4020,7 @@ class AdjointGenerator
40244020
return;
40254021
}
40264022

4027-
if (called && called->getName() == "posix_memalign") {
4023+
if (funcName == "posix_memalign") {
40284024
if (gutils->invertedPointers.count(orig)) {
40294025
auto placeholder = cast<PHINode>(gutils->invertedPointers[orig]);
40304026
gutils->invertedPointers.erase(orig);

0 commit comments

Comments
 (0)