@@ -2814,11 +2814,19 @@ class AdjointGenerator
2814
2814
2815
2815
Function *called = orig->getCalledFunction ();
2816
2816
2817
+ StringRef funcName = " " ;
2818
+ if (called) {
2819
+ if (called->hasFnAttribute (" enzyme_math" ))
2820
+ funcName = called->getFnAttribute (" enzyme_math" ).getValueAsString ();
2821
+ else
2822
+ funcName = called->getName ();
2823
+ }
2824
+
2817
2825
if (Mode != DerivativeMode::ReverseModePrimal && called) {
2818
- if (called-> getName () == " __kmpc_for_static_init_4" ||
2819
- called-> getName () == " __kmpc_for_static_init_4u" ||
2820
- called-> getName () == " __kmpc_for_static_init_8" ||
2821
- called-> getName () == " __kmpc_for_static_init_8u" ) {
2826
+ if (funcName == " __kmpc_for_static_init_4" ||
2827
+ funcName == " __kmpc_for_static_init_4u" ||
2828
+ funcName == " __kmpc_for_static_init_8" ||
2829
+ funcName == " __kmpc_for_static_init_8u" ) {
2822
2830
IRBuilder<> Builder2 (call.getParent ());
2823
2831
getReverseBuilder (Builder2);
2824
2832
auto fini = called->getParent ()->getFunction (" __kmpc_for_static_fini" );
@@ -2834,8 +2842,7 @@ class AdjointGenerator
2834
2842
}
2835
2843
2836
2844
// MPI send / recv can only send float/integers
2837
- if (called && (called->getName () == " MPI_Isend" ||
2838
- called->getName () == " MPI_Irecv" )) {
2845
+ if (funcName == " MPI_Isend" || funcName == " MPI_Irecv" ) {
2839
2846
Value *firstallocation = nullptr ;
2840
2847
if (Mode == DerivativeMode::ReverseModePrimal ||
2841
2848
Mode == DerivativeMode::ReverseModeCombined) {
@@ -2853,7 +2860,7 @@ class AdjointGenerator
2853
2860
/* 5 */ Type::getInt8PtrTy (call.getContext ()),
2854
2861
/* 6 */ Type::getInt8Ty (call.getContext ()),
2855
2862
};
2856
- auto impi = StructType::get (called-> getContext (), types, false );
2863
+ auto impi = StructType::get (call. getContext (), types, false );
2857
2864
2858
2865
Value *impialloc = CallInst::CreateMalloc (
2859
2866
gutils->getNewFromOriginal (&call), i64 , impi,
@@ -2869,7 +2876,7 @@ class AdjointGenerator
2869
2876
d_req, PointerType::getUnqual (impialloc->getType ()));
2870
2877
BuilderZ.CreateStore (impialloc, d_req);
2871
2878
2872
- if (called-> getName () == " MPI_Isend" ) {
2879
+ if (funcName == " MPI_Isend" ) {
2873
2880
Value *tysize = MPI_TYPE_SIZE (
2874
2881
gutils->getNewFromOriginal (call.getOperand (2 )), BuilderZ);
2875
2882
@@ -2935,7 +2942,7 @@ class AdjointGenerator
2935
2942
2936
2943
BuilderZ.CreateStore (
2937
2944
ConstantInt::get (Type::getInt8Ty (impialloc->getContext ()),
2938
- (called-> getName () == " MPI_Isend" )
2945
+ (funcName == " MPI_Isend" )
2939
2946
? (int )MPI_CallType::ISEND
2940
2947
: (int )MPI_CallType::IRECV),
2941
2948
BuilderZ.CreateInBoundsGEP (impialloc,
@@ -2969,7 +2976,7 @@ class AdjointGenerator
2969
2976
Type::getInt64Ty (Builder2.getContext ())),
2970
2977
" " , true , true );
2971
2978
2972
- if (called-> getName () == " MPI_Irecv" ) {
2979
+ if (funcName == " MPI_Irecv" ) {
2973
2980
auto val_arg =
2974
2981
ConstantInt::get (Type::getInt8Ty (Builder2.getContext ()), 0 );
2975
2982
auto volatile_arg = ConstantInt::getFalse (Builder2.getContext ());
@@ -2989,7 +2996,7 @@ class AdjointGenerator
2989
2996
tys),
2990
2997
nargs));
2991
2998
memset->addParamAttr (0 , Attribute::NonNull);
2992
- } else if (called-> getName () == " MPI_Isend" ) {
2999
+ } else if (funcName == " MPI_Isend" ) {
2993
3000
Value *shadow = gutils->invertPointerM (call.getOperand (0 ), Builder2);
2994
3001
if (Mode == DerivativeMode::ReverseModeCombined)
2995
3002
firstallocation = lookup (firstallocation, Builder2);
@@ -3103,7 +3110,7 @@ class AdjointGenerator
3103
3110
return ;
3104
3111
}
3105
3112
3106
- if (called && called-> getName () == " MPI_Wait" ) {
3113
+ if (funcName == " MPI_Wait" ) {
3107
3114
if (Mode == DerivativeMode::ReverseModeGradient ||
3108
3115
Mode == DerivativeMode::ReverseModeCombined) {
3109
3116
IRBuilder<> Builder2 (call.getParent ());
@@ -3121,7 +3128,7 @@ class AdjointGenerator
3121
3128
/* 5 */ Type::getInt8PtrTy (call.getContext ()),
3122
3129
/* 6 */ Type::getInt8Ty (call.getContext ()),
3123
3130
};
3124
- auto impi = StructType::get (called-> getContext (), types, false );
3131
+ auto impi = StructType::get (call. getContext (), types, false );
3125
3132
3126
3133
Value *d_reqp = Builder2.CreateLoad (Builder2.CreatePointerCast (
3127
3134
d_req, PointerType::getUnqual (PointerType::getUnqual (impi))));
@@ -3150,8 +3157,7 @@ class AdjointGenerator
3150
3157
return ;
3151
3158
}
3152
3159
3153
- if (called &&
3154
- (called->getName () == " MPI_Send" || called->getName () == " MPI_Ssend" )) {
3160
+ if (funcName == " MPI_Send" || funcName == " MPI_Ssend" ) {
3155
3161
if (Mode == DerivativeMode::ReverseModeGradient ||
3156
3162
Mode == DerivativeMode::ReverseModeCombined) {
3157
3163
IRBuilder<> Builder2 (call.getParent ());
@@ -3305,7 +3311,7 @@ class AdjointGenerator
3305
3311
return ;
3306
3312
}
3307
3313
3308
- if (called && called-> getName () == " MPI_Recv" ) {
3314
+ if (funcName == " MPI_Recv" ) {
3309
3315
if (Mode == DerivativeMode::ReverseModeGradient ||
3310
3316
Mode == DerivativeMode::ReverseModeCombined) {
3311
3317
IRBuilder<> Builder2 (call.getParent ());
@@ -3368,10 +3374,9 @@ class AdjointGenerator
3368
3374
}
3369
3375
}
3370
3376
3371
- if (called &&
3372
- (called->getName () == " printf" || called->getName () == " puts" ||
3373
- called->getName ().startswith (" _ZN3std2io5stdio6_print" ) ||
3374
- called->getName ().startswith (" _ZN4core3fmt" ))) {
3377
+ if (funcName == " printf" || funcName == " puts" ||
3378
+ funcName.startswith (" _ZN3std2io5stdio6_print" ) ||
3379
+ funcName.startswith (" _ZN4core3fmt" )) {
3375
3380
if (Mode == DerivativeMode::ReverseModeGradient) {
3376
3381
eraseIfUnused (*orig, /* erase*/ true , /* check*/ false );
3377
3382
}
@@ -3387,14 +3392,11 @@ class AdjointGenerator
3387
3392
3388
3393
// Handle lgamma, safe to recompute so no store/change to forward
3389
3394
if (called) {
3390
- auto n = called->getName ();
3391
- if (called->getName () == " __kmpc_fork_call" ) {
3395
+ if (funcName == " __kmpc_fork_call" ) {
3392
3396
visitOMPCall (call);
3393
3397
return ;
3394
3398
}
3395
- if (called &&
3396
- (called->getName () == " asin" || called->getName () == " asinf" ||
3397
- called->getName () == " asinl" )) {
3399
+ if (funcName == " asin" || funcName == " asinf" || funcName == " asinl" ) {
3398
3400
if (gutils->knownRecomputeHeuristic .find (orig) !=
3399
3401
gutils->knownRecomputeHeuristic .end ()) {
3400
3402
if (!gutils->knownRecomputeHeuristic [orig]) {
@@ -3426,10 +3428,8 @@ class AdjointGenerator
3426
3428
return ;
3427
3429
}
3428
3430
3429
- if (called &&
3430
- (called->getName () == " atan" || called->getName () == " atanf" ||
3431
- called->getName () == " atanl" ||
3432
- called->getName () == " __fd_atan_1" )) {
3431
+ if (funcName == " atan" || funcName == " atanf" || funcName == " atanl" ||
3432
+ funcName == " __fd_atan_1" ) {
3433
3433
if (gutils->knownRecomputeHeuristic .find (orig) !=
3434
3434
gutils->knownRecomputeHeuristic .end ()) {
3435
3435
if (!gutils->knownRecomputeHeuristic [orig]) {
@@ -3453,8 +3453,7 @@ class AdjointGenerator
3453
3453
return ;
3454
3454
}
3455
3455
3456
- if (called &&
3457
- (called->getName () == " tanhf" || called->getName () == " tanh" )) {
3456
+ if (funcName == " tanhf" || funcName == " tanh" ) {
3458
3457
if (gutils->knownRecomputeHeuristic .find (orig) !=
3459
3458
gutils->knownRecomputeHeuristic .end ()) {
3460
3459
if (!gutils->knownRecomputeHeuristic [orig]) {
@@ -3474,8 +3473,8 @@ class AdjointGenerator
3474
3473
3475
3474
SmallVector<Value *, 1 > args = {x};
3476
3475
auto coshf = gutils->oldFunc ->getParent ()->getOrInsertFunction (
3477
- (called-> getName () == " tanh" ) ? " cosh" : " coshf" ,
3478
- called->getFunctionType (), called-> getAttributes ());
3476
+ (funcName == " tanh" ) ? " cosh" : " coshf" , called-> getFunctionType () ,
3477
+ called->getAttributes ());
3479
3478
auto cal = cast<CallInst>(Builder2.CreateCall (coshf, args));
3480
3479
Value *dif0 = Builder2.CreateFDiv (diffe (orig, Builder2),
3481
3480
Builder2.CreateFMul (cal, cal));
@@ -3484,7 +3483,7 @@ class AdjointGenerator
3484
3483
return ;
3485
3484
}
3486
3485
3487
- if (called-> getName () == " coshf" || called-> getName () == " cosh" ) {
3486
+ if (funcName == " coshf" || funcName == " cosh" ) {
3488
3487
if (gutils->knownRecomputeHeuristic .find (orig) !=
3489
3488
gutils->knownRecomputeHeuristic .end ()) {
3490
3489
if (!gutils->knownRecomputeHeuristic [orig]) {
@@ -3504,15 +3503,15 @@ class AdjointGenerator
3504
3503
3505
3504
SmallVector<Value *, 1 > args = {x};
3506
3505
auto sinhf = gutils->oldFunc ->getParent ()->getOrInsertFunction (
3507
- (called-> getName () == " cosh" ) ? " sinh" : " sinhf" ,
3508
- called->getFunctionType (), called-> getAttributes ());
3506
+ (funcName == " cosh" ) ? " sinh" : " sinhf" , called-> getFunctionType () ,
3507
+ called->getAttributes ());
3509
3508
auto cal = cast<CallInst>(Builder2.CreateCall (sinhf, args));
3510
3509
Value *dif0 = Builder2.CreateFMul (diffe (orig, Builder2), cal);
3511
3510
setDiffe (orig, Constant::getNullValue (orig->getType ()), Builder2);
3512
3511
addToDiffe (orig->getArgOperand (0 ), dif0, Builder2, x->getType ());
3513
3512
return ;
3514
3513
}
3515
- if (called-> getName () == " sinhf" || called-> getName () == " sinh" ) {
3514
+ if (funcName == " sinhf" || funcName == " sinh" ) {
3516
3515
if (gutils->knownRecomputeHeuristic .find (orig) !=
3517
3516
gutils->knownRecomputeHeuristic .end ()) {
3518
3517
if (!gutils->knownRecomputeHeuristic [orig]) {
@@ -3532,8 +3531,8 @@ class AdjointGenerator
3532
3531
3533
3532
SmallVector<Value *, 1 > args = {x};
3534
3533
auto sinhf = gutils->oldFunc ->getParent ()->getOrInsertFunction (
3535
- (called-> getName () == " sinh" ) ? " cosh" : " coshf" ,
3536
- called->getFunctionType (), called-> getAttributes ());
3534
+ (funcName == " sinh" ) ? " cosh" : " coshf" , called-> getFunctionType () ,
3535
+ called->getAttributes ());
3537
3536
auto cal = cast<CallInst>(Builder2.CreateCall (sinhf, args));
3538
3537
Value *dif0 = Builder2.CreateFMul (diffe (orig, Builder2), cal);
3539
3538
setDiffe (orig, Constant::getNullValue (orig->getType ()), Builder2);
@@ -3542,7 +3541,7 @@ class AdjointGenerator
3542
3541
}
3543
3542
3544
3543
if (called) {
3545
- if (called-> getName () == " erf" ) {
3544
+ if (funcName == " erf" ) {
3546
3545
if (gutils->knownRecomputeHeuristic .find (orig) !=
3547
3546
gutils->knownRecomputeHeuristic .end ()) {
3548
3547
if (!gutils->knownRecomputeHeuristic [orig]) {
@@ -3576,7 +3575,7 @@ class AdjointGenerator
3576
3575
addToDiffe (orig->getArgOperand (0 ), cal, Builder2, x->getType ());
3577
3576
return ;
3578
3577
}
3579
- if (called-> getName () == " erfi" ) {
3578
+ if (funcName == " erfi" ) {
3580
3579
if (gutils->knownRecomputeHeuristic .find (orig) !=
3581
3580
gutils->knownRecomputeHeuristic .end ()) {
3582
3581
if (!gutils->knownRecomputeHeuristic [orig]) {
@@ -3610,7 +3609,7 @@ class AdjointGenerator
3610
3609
addToDiffe (orig->getArgOperand (0 ), cal, Builder2, x->getType ());
3611
3610
return ;
3612
3611
}
3613
- if (called-> getName () == " erfc" ) {
3612
+ if (funcName == " erfc" ) {
3614
3613
if (gutils->knownRecomputeHeuristic .find (orig) !=
3615
3614
gutils->knownRecomputeHeuristic .end ()) {
3616
3615
if (!gutils->knownRecomputeHeuristic [orig]) {
@@ -3645,8 +3644,8 @@ class AdjointGenerator
3645
3644
return ;
3646
3645
}
3647
3646
3648
- if (called-> getName () == " j0" || called-> getName () == " y0" ||
3649
- called-> getName () == " j0f " || called-> getName () == " y0f" ) {
3647
+ if (funcName == " j0" || funcName == " y0" || funcName == " j0f " ||
3648
+ funcName == " y0f" ) {
3650
3649
if (gutils->knownRecomputeHeuristic .find (orig) !=
3651
3650
gutils->knownRecomputeHeuristic .end ()) {
3652
3651
if (!gutils->knownRecomputeHeuristic [orig]) {
@@ -3667,9 +3666,8 @@ class AdjointGenerator
3667
3666
3668
3667
Value *dx = Builder2.CreateCall (
3669
3668
gutils->oldFunc ->getParent ()->getOrInsertFunction (
3670
- (called->getName ()[0 ] == ' j' )
3671
- ? ((called->getName () == " j0" ) ? " j1" : " j1f" )
3672
- : ((called->getName () == " y0" ) ? " y1" : " y1f" ),
3669
+ (funcName[0 ] == ' j' ) ? ((funcName == " j0" ) ? " j1" : " j1f" )
3670
+ : ((funcName == " y0" ) ? " y1" : " y1f" ),
3673
3671
called->getFunctionType ()),
3674
3672
std::vector<Value *>({x}));
3675
3673
dx = Builder2.CreateFNeg (dx);
@@ -3679,8 +3677,8 @@ class AdjointGenerator
3679
3677
return ;
3680
3678
}
3681
3679
3682
- if (called-> getName () == " j1" || called-> getName () == " y1" ||
3683
- called-> getName () == " j1f " || called-> getName () == " y1f" ) {
3680
+ if (funcName == " j1" || funcName == " y1" || funcName == " j1f " ||
3681
+ funcName == " y1f" ) {
3684
3682
if (gutils->knownRecomputeHeuristic .find (orig) !=
3685
3683
gutils->knownRecomputeHeuristic .end ()) {
3686
3684
if (!gutils->knownRecomputeHeuristic [orig]) {
@@ -3701,9 +3699,8 @@ class AdjointGenerator
3701
3699
3702
3700
Value *d0 = Builder2.CreateCall (
3703
3701
gutils->oldFunc ->getParent ()->getOrInsertFunction (
3704
- (called->getName ()[0 ] == ' j' )
3705
- ? ((called->getName () == " j1" ) ? " j0" : " j0f" )
3706
- : ((called->getName () == " y1" ) ? " y0" : " y0f" ),
3702
+ (funcName[0 ] == ' j' ) ? ((funcName == " j1" ) ? " j0" : " j0f" )
3703
+ : ((funcName == " y1" ) ? " y0" : " y0f" ),
3707
3704
called->getFunctionType ()),
3708
3705
std::vector<Value *>({x}));
3709
3706
@@ -3713,9 +3710,8 @@ class AdjointGenerator
3713
3710
auto FT2 = FunctionType::get (x->getType (), pargs, false );
3714
3711
Value *d2 = Builder2.CreateCall (
3715
3712
gutils->oldFunc ->getParent ()->getOrInsertFunction (
3716
- (called->getName ()[0 ] == ' j' )
3717
- ? ((called->getName () == " j1" ) ? " jn" : " jnf" )
3718
- : ((called->getName () == " y1" ) ? " yn" : " ynf" ),
3713
+ (funcName[0 ] == ' j' ) ? ((funcName == " j1" ) ? " jn" : " jnf" )
3714
+ : ((funcName == " y1" ) ? " yn" : " ynf" ),
3719
3715
FT2),
3720
3716
std::vector<Value *>({ConstantInt::get (intType, 2 ), x}));
3721
3717
Value *dx = Builder2.CreateFSub (d0, d2);
@@ -3726,8 +3722,8 @@ class AdjointGenerator
3726
3722
return ;
3727
3723
}
3728
3724
3729
- if (called-> getName () == " jn" || called-> getName () == " yn" ||
3730
- called-> getName () == " jnf " || called-> getName () == " ynf" ) {
3725
+ if (funcName == " jn" || funcName == " yn" || funcName == " jnf " ||
3726
+ funcName == " ynf" ) {
3731
3727
if (gutils->knownRecomputeHeuristic .find (orig) !=
3732
3728
gutils->knownRecomputeHeuristic .end ()) {
3733
3729
if (!gutils->knownRecomputeHeuristic [orig]) {
@@ -3768,7 +3764,7 @@ class AdjointGenerator
3768
3764
return ;
3769
3765
}
3770
3766
3771
- if (called-> getName () == " julia.write_barrier" ) {
3767
+ if (funcName == " julia.write_barrier" ) {
3772
3768
if (Mode == DerivativeMode::ReverseModeGradient) {
3773
3769
eraseIfUnused (*orig, /* erase*/ true , /* check*/ false );
3774
3770
return ;
@@ -3788,7 +3784,7 @@ class AdjointGenerator
3788
3784
return ;
3789
3785
}
3790
3786
Intrinsic::ID ID = Intrinsic::not_intrinsic;
3791
- if (isMemFreeLibMFunction (called-> getName () , &ID)) {
3787
+ if (isMemFreeLibMFunction (funcName , &ID)) {
3792
3788
if (Mode == DerivativeMode::ReverseModePrimal ||
3793
3789
gutils->isConstantInstruction (orig)) {
3794
3790
eraseIfUnused (*orig);
@@ -3804,7 +3800,7 @@ class AdjointGenerator
3804
3800
return ;
3805
3801
}
3806
3802
}
3807
- if (called-> getName () == " __fd_sincos_1" ) {
3803
+ if (funcName == " __fd_sincos_1" ) {
3808
3804
if (gutils->knownRecomputeHeuristic .find (orig) !=
3809
3805
gutils->knownRecomputeHeuristic .end ()) {
3810
3806
if (!gutils->knownRecomputeHeuristic [orig]) {
@@ -3847,8 +3843,7 @@ class AdjointGenerator
3847
3843
addToDiffe (orig->getArgOperand (0 ), dif0, Builder2, x->getType ());
3848
3844
return ;
3849
3845
}
3850
- if (called->getName () == " cabs" || called->getName () == " cabsf" ||
3851
- called->getName () == " cabsl" ) {
3846
+ if (funcName == " cabs" || funcName == " cabsf" || funcName == " cabsl" ) {
3852
3847
if (gutils->knownRecomputeHeuristic .find (orig) !=
3853
3848
gutils->knownRecomputeHeuristic .end ()) {
3854
3849
if (!gutils->knownRecomputeHeuristic [orig]) {
@@ -3891,8 +3886,8 @@ class AdjointGenerator
3891
3886
llvm_unreachable (" unknown calling convention found for cabs" );
3892
3887
}
3893
3888
}
3894
- if (called-> getName () == " ldexp" || called-> getName () == " ldexpf" ||
3895
- called-> getName () == " ldexpl" ) {
3889
+ if (funcName == " ldexp" || funcName == " ldexpf" ||
3890
+ funcName == " ldexpl" ) {
3896
3891
if (gutils->knownRecomputeHeuristic .find (orig) !=
3897
3892
gutils->knownRecomputeHeuristic .end ()) {
3898
3893
if (!gutils->knownRecomputeHeuristic [orig]) {
@@ -3923,10 +3918,11 @@ class AdjointGenerator
3923
3918
}
3924
3919
}
3925
3920
3926
- if (n == " lgamma" || n == " lgammaf" || n == " lgammal" ||
3927
- n == " lgamma_r" || n == " lgammaf_r" || n == " lgammal_r" ||
3928
- n == " __lgamma_r_finite" || n == " __lgammaf_r_finite" ||
3929
- n == " __lgammal_r_finite" ) {
3921
+ if (funcName == " lgamma" || funcName == " lgammaf" ||
3922
+ funcName == " lgammal" || funcName == " lgamma_r" ||
3923
+ funcName == " lgammaf_r" || funcName == " lgammal_r" ||
3924
+ funcName == " __lgamma_r_finite" || funcName == " __lgammaf_r_finite" ||
3925
+ funcName == " __lgammal_r_finite" ) {
3930
3926
if (gutils->knownRecomputeHeuristic .find (orig) !=
3931
3927
gutils->knownRecomputeHeuristic .end ()) {
3932
3928
if (!gutils->knownRecomputeHeuristic [orig]) {
@@ -4004,7 +4000,7 @@ class AdjointGenerator
4004
4000
return ;
4005
4001
}
4006
4002
4007
- if (called && called-> getName () == " julia.pointer_from_objref" ) {
4003
+ if (funcName == " julia.pointer_from_objref" ) {
4008
4004
eraseIfUnused (*orig);
4009
4005
if (gutils->isConstantValue (orig))
4010
4006
return ;
@@ -4024,7 +4020,7 @@ class AdjointGenerator
4024
4020
return ;
4025
4021
}
4026
4022
4027
- if (called && called-> getName () == " posix_memalign" ) {
4023
+ if (funcName == " posix_memalign" ) {
4028
4024
if (gutils->invertedPointers .count (orig)) {
4029
4025
auto placeholder = cast<PHINode>(gutils->invertedPointers [orig]);
4030
4026
gutils->invertedPointers .erase (orig);
0 commit comments