Skip to content

Commit b690fac

Browse files
committed
Fix format
1 parent 3500823 commit b690fac

File tree

4 files changed

+58
-41
lines changed

4 files changed

+58
-41
lines changed

enzyme/Enzyme/AdjointGenerator.h

Lines changed: 52 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -3390,7 +3390,8 @@ class AdjointGenerator
33903390
if (called &&
33913391
(called->getName() == "asin" || called->getName() == "asinf" ||
33923392
called->getName() == "asinl")) {
3393-
if (gutils->knownRecomputeHeuristic.find(orig) != gutils->knownRecomputeHeuristic.end()) {
3393+
if (gutils->knownRecomputeHeuristic.find(orig) !=
3394+
gutils->knownRecomputeHeuristic.end()) {
33943395
if (!gutils->knownRecomputeHeuristic[orig]) {
33953396
gutils->cacheForReverse(BuilderZ, gutils->getNewFromOriginal(&call),
33963397
getIndex(orig, CacheType::Self));
@@ -3424,7 +3425,8 @@ class AdjointGenerator
34243425
(called->getName() == "atan" || called->getName() == "atanf" ||
34253426
called->getName() == "atanl" ||
34263427
called->getName() == "__fd_atan_1")) {
3427-
if (gutils->knownRecomputeHeuristic.find(orig) != gutils->knownRecomputeHeuristic.end()) {
3428+
if (gutils->knownRecomputeHeuristic.find(orig) !=
3429+
gutils->knownRecomputeHeuristic.end()) {
34283430
if (!gutils->knownRecomputeHeuristic[orig]) {
34293431
gutils->cacheForReverse(BuilderZ, gutils->getNewFromOriginal(&call),
34303432
getIndex(orig, CacheType::Self));
@@ -3448,7 +3450,8 @@ class AdjointGenerator
34483450

34493451
if (called &&
34503452
(called->getName() == "tanhf" || called->getName() == "tanh")) {
3451-
if (gutils->knownRecomputeHeuristic.find(orig) != gutils->knownRecomputeHeuristic.end()) {
3453+
if (gutils->knownRecomputeHeuristic.find(orig) !=
3454+
gutils->knownRecomputeHeuristic.end()) {
34523455
if (!gutils->knownRecomputeHeuristic[orig]) {
34533456
gutils->cacheForReverse(BuilderZ, gutils->getNewFromOriginal(&call),
34543457
getIndex(orig, CacheType::Self));
@@ -3477,7 +3480,8 @@ class AdjointGenerator
34773480
}
34783481

34793482
if (called->getName() == "coshf" || called->getName() == "cosh") {
3480-
if (gutils->knownRecomputeHeuristic.find(orig) != gutils->knownRecomputeHeuristic.end()) {
3483+
if (gutils->knownRecomputeHeuristic.find(orig) !=
3484+
gutils->knownRecomputeHeuristic.end()) {
34813485
if (!gutils->knownRecomputeHeuristic[orig]) {
34823486
gutils->cacheForReverse(BuilderZ, gutils->getNewFromOriginal(&call),
34833487
getIndex(orig, CacheType::Self));
@@ -3504,7 +3508,8 @@ class AdjointGenerator
35043508
return;
35053509
}
35063510
if (called->getName() == "sinhf" || called->getName() == "sinh") {
3507-
if (gutils->knownRecomputeHeuristic.find(orig) != gutils->knownRecomputeHeuristic.end()) {
3511+
if (gutils->knownRecomputeHeuristic.find(orig) !=
3512+
gutils->knownRecomputeHeuristic.end()) {
35083513
if (!gutils->knownRecomputeHeuristic[orig]) {
35093514
gutils->cacheForReverse(BuilderZ, gutils->getNewFromOriginal(&call),
35103515
getIndex(orig, CacheType::Self));
@@ -3533,9 +3538,11 @@ class AdjointGenerator
35333538

35343539
if (called) {
35353540
if (called->getName() == "erf") {
3536-
if (gutils->knownRecomputeHeuristic.find(orig) != gutils->knownRecomputeHeuristic.end()) {
3541+
if (gutils->knownRecomputeHeuristic.find(orig) !=
3542+
gutils->knownRecomputeHeuristic.end()) {
35373543
if (!gutils->knownRecomputeHeuristic[orig]) {
3538-
gutils->cacheForReverse(BuilderZ, gutils->getNewFromOriginal(&call),
3544+
gutils->cacheForReverse(BuilderZ,
3545+
gutils->getNewFromOriginal(&call),
35393546
getIndex(orig, CacheType::Self));
35403547
}
35413548
}
@@ -3565,9 +3572,11 @@ class AdjointGenerator
35653572
return;
35663573
}
35673574
if (called->getName() == "erfi") {
3568-
if (gutils->knownRecomputeHeuristic.find(orig) != gutils->knownRecomputeHeuristic.end()) {
3575+
if (gutils->knownRecomputeHeuristic.find(orig) !=
3576+
gutils->knownRecomputeHeuristic.end()) {
35693577
if (!gutils->knownRecomputeHeuristic[orig]) {
3570-
gutils->cacheForReverse(BuilderZ, gutils->getNewFromOriginal(&call),
3578+
gutils->cacheForReverse(BuilderZ,
3579+
gutils->getNewFromOriginal(&call),
35713580
getIndex(orig, CacheType::Self));
35723581
}
35733582
}
@@ -3597,9 +3606,11 @@ class AdjointGenerator
35973606
return;
35983607
}
35993608
if (called->getName() == "erfc") {
3600-
if (gutils->knownRecomputeHeuristic.find(orig) != gutils->knownRecomputeHeuristic.end()) {
3609+
if (gutils->knownRecomputeHeuristic.find(orig) !=
3610+
gutils->knownRecomputeHeuristic.end()) {
36013611
if (!gutils->knownRecomputeHeuristic[orig]) {
3602-
gutils->cacheForReverse(BuilderZ, gutils->getNewFromOriginal(&call),
3612+
gutils->cacheForReverse(BuilderZ,
3613+
gutils->getNewFromOriginal(&call),
36033614
getIndex(orig, CacheType::Self));
36043615
}
36053616
}
@@ -3631,9 +3642,11 @@ class AdjointGenerator
36313642

36323643
if (called->getName() == "j0" || called->getName() == "y0" ||
36333644
called->getName() == "j0f" || called->getName() == "y0f") {
3634-
if (gutils->knownRecomputeHeuristic.find(orig) != gutils->knownRecomputeHeuristic.end()) {
3645+
if (gutils->knownRecomputeHeuristic.find(orig) !=
3646+
gutils->knownRecomputeHeuristic.end()) {
36353647
if (!gutils->knownRecomputeHeuristic[orig]) {
3636-
gutils->cacheForReverse(BuilderZ, gutils->getNewFromOriginal(&call),
3648+
gutils->cacheForReverse(BuilderZ,
3649+
gutils->getNewFromOriginal(&call),
36373650
getIndex(orig, CacheType::Self));
36383651
}
36393652
}
@@ -3663,9 +3676,11 @@ class AdjointGenerator
36633676

36643677
if (called->getName() == "j1" || called->getName() == "y1" ||
36653678
called->getName() == "j1f" || called->getName() == "y1f") {
3666-
if (gutils->knownRecomputeHeuristic.find(orig) != gutils->knownRecomputeHeuristic.end()) {
3679+
if (gutils->knownRecomputeHeuristic.find(orig) !=
3680+
gutils->knownRecomputeHeuristic.end()) {
36673681
if (!gutils->knownRecomputeHeuristic[orig]) {
3668-
gutils->cacheForReverse(BuilderZ, gutils->getNewFromOriginal(&call),
3682+
gutils->cacheForReverse(BuilderZ,
3683+
gutils->getNewFromOriginal(&call),
36693684
getIndex(orig, CacheType::Self));
36703685
}
36713686
}
@@ -3708,9 +3723,11 @@ class AdjointGenerator
37083723

37093724
if (called->getName() == "jn" || called->getName() == "yn" ||
37103725
called->getName() == "jnf" || called->getName() == "ynf") {
3711-
if (gutils->knownRecomputeHeuristic.find(orig) != gutils->knownRecomputeHeuristic.end()) {
3726+
if (gutils->knownRecomputeHeuristic.find(orig) !=
3727+
gutils->knownRecomputeHeuristic.end()) {
37123728
if (!gutils->knownRecomputeHeuristic[orig]) {
3713-
gutils->cacheForReverse(BuilderZ, gutils->getNewFromOriginal(&call),
3729+
gutils->cacheForReverse(BuilderZ,
3730+
gutils->getNewFromOriginal(&call),
37143731
getIndex(orig, CacheType::Self));
37153732
}
37163733
}
@@ -3783,9 +3800,11 @@ class AdjointGenerator
37833800
}
37843801
}
37853802
if (called->getName() == "__fd_sincos_1") {
3786-
if (gutils->knownRecomputeHeuristic.find(orig) != gutils->knownRecomputeHeuristic.end()) {
3803+
if (gutils->knownRecomputeHeuristic.find(orig) !=
3804+
gutils->knownRecomputeHeuristic.end()) {
37873805
if (!gutils->knownRecomputeHeuristic[orig]) {
3788-
gutils->cacheForReverse(BuilderZ, gutils->getNewFromOriginal(&call),
3806+
gutils->cacheForReverse(BuilderZ,
3807+
gutils->getNewFromOriginal(&call),
37893808
getIndex(orig, CacheType::Self));
37903809
}
37913810
}
@@ -3825,9 +3844,11 @@ class AdjointGenerator
38253844
}
38263845
if (called->getName() == "cabs" || called->getName() == "cabsf" ||
38273846
called->getName() == "cabsl") {
3828-
if (gutils->knownRecomputeHeuristic.find(orig) != gutils->knownRecomputeHeuristic.end()) {
3847+
if (gutils->knownRecomputeHeuristic.find(orig) !=
3848+
gutils->knownRecomputeHeuristic.end()) {
38293849
if (!gutils->knownRecomputeHeuristic[orig]) {
3830-
gutils->cacheForReverse(BuilderZ, gutils->getNewFromOriginal(&call),
3850+
gutils->cacheForReverse(BuilderZ,
3851+
gutils->getNewFromOriginal(&call),
38313852
getIndex(orig, CacheType::Self));
38323853
}
38333854
}
@@ -3867,9 +3888,11 @@ class AdjointGenerator
38673888
}
38683889
if (called->getName() == "ldexp" || called->getName() == "ldexpf" ||
38693890
called->getName() == "ldexpl") {
3870-
if (gutils->knownRecomputeHeuristic.find(orig) != gutils->knownRecomputeHeuristic.end()) {
3891+
if (gutils->knownRecomputeHeuristic.find(orig) !=
3892+
gutils->knownRecomputeHeuristic.end()) {
38713893
if (!gutils->knownRecomputeHeuristic[orig]) {
3872-
gutils->cacheForReverse(BuilderZ, gutils->getNewFromOriginal(&call),
3894+
gutils->cacheForReverse(BuilderZ,
3895+
gutils->getNewFromOriginal(&call),
38733896
getIndex(orig, CacheType::Self));
38743897
}
38753898
}
@@ -3899,7 +3922,8 @@ class AdjointGenerator
38993922
n == "lgamma_r" || n == "lgammaf_r" || n == "lgammal_r" ||
39003923
n == "__lgamma_r_finite" || n == "__lgammaf_r_finite" ||
39013924
n == "__lgammal_r_finite") {
3902-
if (gutils->knownRecomputeHeuristic.find(orig) != gutils->knownRecomputeHeuristic.end()) {
3925+
if (gutils->knownRecomputeHeuristic.find(orig) !=
3926+
gutils->knownRecomputeHeuristic.end()) {
39033927
if (!gutils->knownRecomputeHeuristic[orig]) {
39043928
gutils->cacheForReverse(BuilderZ, gutils->getNewFromOriginal(&call),
39053929
getIndex(orig, CacheType::Self));
@@ -4139,14 +4163,16 @@ class AdjointGenerator
41394163
// gutils->isConstantValue(orig) << " subretused=" << subretused << " ivn:"
41404164
// << is_value_needed_in_reverse<Primal>(TR, gutils, &call, /*topLevel*/Mode
41414165
// == DerivativeMode::Both) << "\n";
4142-
if (gutils->knownRecomputeHeuristic.find(orig) != gutils->knownRecomputeHeuristic.end()) {
4166+
if (gutils->knownRecomputeHeuristic.find(orig) !=
4167+
gutils->knownRecomputeHeuristic.end()) {
41434168
if (!gutils->knownRecomputeHeuristic[orig]) {
41444169
subretused = true;
41454170
}
41464171
}
41474172

41484173
if (gutils->isConstantInstruction(orig) && gutils->isConstantValue(orig)) {
4149-
if (gutils->knownRecomputeHeuristic.find(orig) != gutils->knownRecomputeHeuristic.end()) {
4174+
if (gutils->knownRecomputeHeuristic.find(orig) !=
4175+
gutils->knownRecomputeHeuristic.end()) {
41504176
if (!gutils->knownRecomputeHeuristic[orig]) {
41514177
gutils->cacheForReverse(BuilderZ, gutils->getNewFromOriginal(&call),
41524178
getIndex(orig, CacheType::Self));

enzyme/Enzyme/DifferentialUseAnalysis.h

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -476,19 +476,15 @@ static inline void minCut(const DataLayout &DL, LoopInfo &OrigLI,
476476
OrigLI.getLoopFor(cast<Instruction>(V)->getParent()),
477477
OrigLI.getLoopFor(
478478
cast<Instruction>(((*found->second.begin()).V))->getParent()));
479-
// llvm::errs() << " considering cache " << *V << " vs " << " " <<
480-
// *(*found->second.begin()).V << " potentiallyRecursive: " <<
481-
// (int)potentiallyRecursive << " cmpLoopNest: " <<moreOuterLoop << "\n";
482479
if (potentiallyRecursive)
483480
continue;
484481
if (moreOuterLoop == -1)
485482
continue;
486483
if (moreOuterLoop == 1 ||
487-
moreOuterLoop == 0 &&
488-
DL.getTypeSizeInBits(V->getType()) >=
489-
DL.getTypeSizeInBits((*found->second.begin()).V->getType())) {
484+
(moreOuterLoop == 0 &&
485+
DL.getTypeSizeInBits(V->getType()) >=
486+
DL.getTypeSizeInBits((*found->second.begin()).V->getType()))) {
490487
MinReq.erase(V);
491-
// llvm::errs() << " - moved!\n";
492488
MinReq.insert((*found->second.begin()).V);
493489
todo.push_back((*found->second.begin()).V);
494490
}

enzyme/Enzyme/EnzymeLogic.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -818,7 +818,7 @@ void calculateUnusedValuesInFunction(
818818
}
819819
return UseReq::Recur;
820820
});
821-
#if 1
821+
#if 0
822822
llvm::errs() << "unnecessaryValues of " << func.getName() << ":\n";
823823
for (auto a : unnecessaryValues) {
824824
llvm::errs() << *a << "\n";

enzyme/Enzyme/GradientUtils.cpp

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,8 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
154154
Value *orig = isOriginal(val);
155155
if (orig &&
156156
knownRecomputeHeuristic.find(orig) != knownRecomputeHeuristic.end()) {
157-
if (!knownRecomputeHeuristic[orig] && !legalRecompute(orig, available, &BuilderM)) {
157+
if (!knownRecomputeHeuristic[orig] &&
158+
!legalRecompute(orig, available, &BuilderM)) {
158159
return nullptr;
159160
}
160161
}
@@ -4341,8 +4342,6 @@ void GradientUtils::computeMinCache(
43414342
TR, this, &I,
43424343
/*topLevel*/ mode == DerivativeMode::ReverseModeCombined,
43434344
OneLevelSeen, guaranteedUnreachable);
4344-
llvm::errs() << " not legal recompute: " << I << " oneneed: " <<
4345-
(int)oneneed << "\n";
43464345
if (oneneed)
43474346
knownRecomputeHeuristic[&I] = false;
43484347
else
@@ -4391,7 +4390,6 @@ void GradientUtils::computeMinCache(
43914390
TR, this, V,
43924391
/*topLevel*/ mode == DerivativeMode::ReverseModeCombined,
43934392
OneLevelSeen, guaranteedUnreachable)) {
4394-
llvm::errs() << " Required: " << *V << "\n";
43954393
Required.insert(V);
43964394
} else {
43974395
for (auto V2 : V->users()) {
@@ -4422,11 +4420,8 @@ void GradientUtils::computeMinCache(
44224420
}
44234421

44244422
for (auto V : Intermediates) {
4425-
llvm::errs() << " int: " << *V << " minreq: " << (int)MinReq.count(V)
4426-
<< "\n";
44274423
knownRecomputeHeuristic[V] = !MinReq.count(V);
44284424
if (!NeedGraph.count(V)) {
4429-
llvm::errs() << " ++ unnecessary\n";
44304425
unnecessaryIntermediates.insert(cast<Instruction>(V));
44314426
}
44324427
}

0 commit comments

Comments
 (0)