@@ -3390,6 +3390,12 @@ class AdjointGenerator
3390
3390
if (called &&
3391
3391
(called->getName () == " asin" || called->getName () == " asinf" ||
3392
3392
called->getName () == " asinl" )) {
3393
+ if (gutils->knownRecomputeHeuristic .find (orig) != gutils->knownRecomputeHeuristic .end ()) {
3394
+ if (!gutils->knownRecomputeHeuristic [orig]) {
3395
+ gutils->cacheForReverse (BuilderZ, gutils->getNewFromOriginal (&call),
3396
+ getIndex (orig, CacheType::Self));
3397
+ }
3398
+ }
3393
3399
eraseIfUnused (*orig);
3394
3400
if (Mode == DerivativeMode::ReverseModePrimal ||
3395
3401
gutils->isConstantInstruction (orig))
@@ -3418,6 +3424,12 @@ class AdjointGenerator
3418
3424
(called->getName () == " atan" || called->getName () == " atanf" ||
3419
3425
called->getName () == " atanl" ||
3420
3426
called->getName () == " __fd_atan_1" )) {
3427
+ if (gutils->knownRecomputeHeuristic .find (orig) != gutils->knownRecomputeHeuristic .end ()) {
3428
+ if (!gutils->knownRecomputeHeuristic [orig]) {
3429
+ gutils->cacheForReverse (BuilderZ, gutils->getNewFromOriginal (&call),
3430
+ getIndex (orig, CacheType::Self));
3431
+ }
3432
+ }
3421
3433
eraseIfUnused (*orig);
3422
3434
if (Mode == DerivativeMode::ReverseModePrimal ||
3423
3435
gutils->isConstantInstruction (orig))
@@ -3436,6 +3448,12 @@ class AdjointGenerator
3436
3448
3437
3449
if (called &&
3438
3450
(called->getName () == " tanhf" || called->getName () == " tanh" )) {
3451
+ if (gutils->knownRecomputeHeuristic .find (orig) != gutils->knownRecomputeHeuristic .end ()) {
3452
+ if (!gutils->knownRecomputeHeuristic [orig]) {
3453
+ gutils->cacheForReverse (BuilderZ, gutils->getNewFromOriginal (&call),
3454
+ getIndex (orig, CacheType::Self));
3455
+ }
3456
+ }
3439
3457
eraseIfUnused (*orig);
3440
3458
if (Mode == DerivativeMode::ReverseModePrimal ||
3441
3459
gutils->isConstantInstruction (orig))
@@ -3459,6 +3477,12 @@ class AdjointGenerator
3459
3477
}
3460
3478
3461
3479
if (called->getName () == " coshf" || called->getName () == " cosh" ) {
3480
+ if (gutils->knownRecomputeHeuristic .find (orig) != gutils->knownRecomputeHeuristic .end ()) {
3481
+ if (!gutils->knownRecomputeHeuristic [orig]) {
3482
+ gutils->cacheForReverse (BuilderZ, gutils->getNewFromOriginal (&call),
3483
+ getIndex (orig, CacheType::Self));
3484
+ }
3485
+ }
3462
3486
eraseIfUnused (*orig);
3463
3487
if (Mode == DerivativeMode::ReverseModePrimal ||
3464
3488
gutils->isConstantInstruction (orig))
@@ -3480,6 +3504,12 @@ class AdjointGenerator
3480
3504
return ;
3481
3505
}
3482
3506
if (called->getName () == " sinhf" || called->getName () == " sinh" ) {
3507
+ if (gutils->knownRecomputeHeuristic .find (orig) != gutils->knownRecomputeHeuristic .end ()) {
3508
+ if (!gutils->knownRecomputeHeuristic [orig]) {
3509
+ gutils->cacheForReverse (BuilderZ, gutils->getNewFromOriginal (&call),
3510
+ getIndex (orig, CacheType::Self));
3511
+ }
3512
+ }
3483
3513
eraseIfUnused (*orig);
3484
3514
if (Mode == DerivativeMode::ReverseModePrimal ||
3485
3515
gutils->isConstantInstruction (orig))
@@ -3503,6 +3533,12 @@ class AdjointGenerator
3503
3533
3504
3534
if (called) {
3505
3535
if (called->getName () == " erf" ) {
3536
+ if (gutils->knownRecomputeHeuristic .find (orig) != gutils->knownRecomputeHeuristic .end ()) {
3537
+ if (!gutils->knownRecomputeHeuristic [orig]) {
3538
+ gutils->cacheForReverse (BuilderZ, gutils->getNewFromOriginal (&call),
3539
+ getIndex (orig, CacheType::Self));
3540
+ }
3541
+ }
3506
3542
eraseIfUnused (*orig);
3507
3543
if (Mode == DerivativeMode::ReverseModePrimal ||
3508
3544
gutils->isConstantInstruction (orig))
@@ -3529,6 +3565,12 @@ class AdjointGenerator
3529
3565
return ;
3530
3566
}
3531
3567
if (called->getName () == " erfi" ) {
3568
+ if (gutils->knownRecomputeHeuristic .find (orig) != gutils->knownRecomputeHeuristic .end ()) {
3569
+ if (!gutils->knownRecomputeHeuristic [orig]) {
3570
+ gutils->cacheForReverse (BuilderZ, gutils->getNewFromOriginal (&call),
3571
+ getIndex (orig, CacheType::Self));
3572
+ }
3573
+ }
3532
3574
eraseIfUnused (*orig);
3533
3575
if (Mode == DerivativeMode::ReverseModePrimal ||
3534
3576
gutils->isConstantInstruction (orig))
@@ -3555,6 +3597,12 @@ class AdjointGenerator
3555
3597
return ;
3556
3598
}
3557
3599
if (called->getName () == " erfc" ) {
3600
+ if (gutils->knownRecomputeHeuristic .find (orig) != gutils->knownRecomputeHeuristic .end ()) {
3601
+ if (!gutils->knownRecomputeHeuristic [orig]) {
3602
+ gutils->cacheForReverse (BuilderZ, gutils->getNewFromOriginal (&call),
3603
+ getIndex (orig, CacheType::Self));
3604
+ }
3605
+ }
3558
3606
eraseIfUnused (*orig);
3559
3607
if (Mode == DerivativeMode::ReverseModePrimal ||
3560
3608
gutils->isConstantInstruction (orig))
@@ -3583,6 +3631,12 @@ class AdjointGenerator
3583
3631
3584
3632
if (called->getName () == " j0" || called->getName () == " y0" ||
3585
3633
called->getName () == " j0f" || called->getName () == " y0f" ) {
3634
+ if (gutils->knownRecomputeHeuristic .find (orig) != gutils->knownRecomputeHeuristic .end ()) {
3635
+ if (!gutils->knownRecomputeHeuristic [orig]) {
3636
+ gutils->cacheForReverse (BuilderZ, gutils->getNewFromOriginal (&call),
3637
+ getIndex (orig, CacheType::Self));
3638
+ }
3639
+ }
3586
3640
eraseIfUnused (*orig);
3587
3641
if (Mode == DerivativeMode::ReverseModePrimal ||
3588
3642
gutils->isConstantInstruction (orig))
@@ -3609,6 +3663,12 @@ class AdjointGenerator
3609
3663
3610
3664
if (called->getName () == " j1" || called->getName () == " y1" ||
3611
3665
called->getName () == " j1f" || called->getName () == " y1f" ) {
3666
+ if (gutils->knownRecomputeHeuristic .find (orig) != gutils->knownRecomputeHeuristic .end ()) {
3667
+ if (!gutils->knownRecomputeHeuristic [orig]) {
3668
+ gutils->cacheForReverse (BuilderZ, gutils->getNewFromOriginal (&call),
3669
+ getIndex (orig, CacheType::Self));
3670
+ }
3671
+ }
3612
3672
eraseIfUnused (*orig);
3613
3673
if (Mode == DerivativeMode::ReverseModePrimal ||
3614
3674
gutils->isConstantInstruction (orig))
@@ -3648,6 +3708,12 @@ class AdjointGenerator
3648
3708
3649
3709
if (called->getName () == " jn" || called->getName () == " yn" ||
3650
3710
called->getName () == " jnf" || called->getName () == " ynf" ) {
3711
+ if (gutils->knownRecomputeHeuristic .find (orig) != gutils->knownRecomputeHeuristic .end ()) {
3712
+ if (!gutils->knownRecomputeHeuristic [orig]) {
3713
+ gutils->cacheForReverse (BuilderZ, gutils->getNewFromOriginal (&call),
3714
+ getIndex (orig, CacheType::Self));
3715
+ }
3716
+ }
3651
3717
eraseIfUnused (*orig);
3652
3718
if (Mode == DerivativeMode::ReverseModePrimal ||
3653
3719
gutils->isConstantInstruction (orig))
@@ -3717,6 +3783,12 @@ class AdjointGenerator
3717
3783
}
3718
3784
}
3719
3785
if (called->getName () == " __fd_sincos_1" ) {
3786
+ if (gutils->knownRecomputeHeuristic .find (orig) != gutils->knownRecomputeHeuristic .end ()) {
3787
+ if (!gutils->knownRecomputeHeuristic [orig]) {
3788
+ gutils->cacheForReverse (BuilderZ, gutils->getNewFromOriginal (&call),
3789
+ getIndex (orig, CacheType::Self));
3790
+ }
3791
+ }
3720
3792
if (Mode == DerivativeMode::ReverseModePrimal ||
3721
3793
gutils->isConstantInstruction (orig)) {
3722
3794
eraseIfUnused (*orig);
@@ -3753,6 +3825,12 @@ class AdjointGenerator
3753
3825
}
3754
3826
if (called->getName () == " cabs" || called->getName () == " cabsf" ||
3755
3827
called->getName () == " cabsl" ) {
3828
+ if (gutils->knownRecomputeHeuristic .find (orig) != gutils->knownRecomputeHeuristic .end ()) {
3829
+ if (!gutils->knownRecomputeHeuristic [orig]) {
3830
+ gutils->cacheForReverse (BuilderZ, gutils->getNewFromOriginal (&call),
3831
+ getIndex (orig, CacheType::Self));
3832
+ }
3833
+ }
3756
3834
if (Mode == DerivativeMode::ReverseModePrimal ||
3757
3835
gutils->isConstantInstruction (orig)) {
3758
3836
eraseIfUnused (*orig);
@@ -3789,6 +3867,12 @@ class AdjointGenerator
3789
3867
}
3790
3868
if (called->getName () == " ldexp" || called->getName () == " ldexpf" ||
3791
3869
called->getName () == " ldexpl" ) {
3870
+ if (gutils->knownRecomputeHeuristic .find (orig) != gutils->knownRecomputeHeuristic .end ()) {
3871
+ if (!gutils->knownRecomputeHeuristic [orig]) {
3872
+ gutils->cacheForReverse (BuilderZ, gutils->getNewFromOriginal (&call),
3873
+ getIndex (orig, CacheType::Self));
3874
+ }
3875
+ }
3792
3876
if (Mode == DerivativeMode::ReverseModePrimal ||
3793
3877
gutils->isConstantInstruction (orig)) {
3794
3878
eraseIfUnused (*orig);
@@ -3815,6 +3899,12 @@ class AdjointGenerator
3815
3899
n == " lgamma_r" || n == " lgammaf_r" || n == " lgammal_r" ||
3816
3900
n == " __lgamma_r_finite" || n == " __lgammaf_r_finite" ||
3817
3901
n == " __lgammal_r_finite" ) {
3902
+ if (gutils->knownRecomputeHeuristic .find (orig) != gutils->knownRecomputeHeuristic .end ()) {
3903
+ if (!gutils->knownRecomputeHeuristic [orig]) {
3904
+ gutils->cacheForReverse (BuilderZ, gutils->getNewFromOriginal (&call),
3905
+ getIndex (orig, CacheType::Self));
3906
+ }
3907
+ }
3818
3908
if (Mode == DerivativeMode::ReverseModePrimal ||
3819
3909
gutils->isConstantInstruction (orig)) {
3820
3910
return ;
@@ -4049,8 +4139,20 @@ class AdjointGenerator
4049
4139
// gutils->isConstantValue(orig) << " subretused=" << subretused << " ivn:"
4050
4140
// << is_value_needed_in_reverse<Primal>(TR, gutils, &call, /*topLevel*/Mode
4051
4141
// == DerivativeMode::Both) << "\n";
4142
+ if (gutils->knownRecomputeHeuristic .find (orig) != gutils->knownRecomputeHeuristic .end ()) {
4143
+ if (!gutils->knownRecomputeHeuristic [orig]) {
4144
+ subretused = true ;
4145
+ }
4146
+ }
4052
4147
4053
4148
if (gutils->isConstantInstruction (orig) && gutils->isConstantValue (orig)) {
4149
+ if (gutils->knownRecomputeHeuristic .find (orig) != gutils->knownRecomputeHeuristic .end ()) {
4150
+ if (!gutils->knownRecomputeHeuristic [orig]) {
4151
+ gutils->cacheForReverse (BuilderZ, gutils->getNewFromOriginal (&call),
4152
+ getIndex (orig, CacheType::Self));
4153
+ return ;
4154
+ }
4155
+ }
4054
4156
// If we need this value and it is illegal to recompute it (it writes or
4055
4157
// may load uncacheable data)
4056
4158
// Store and reload it
0 commit comments