Skip to content

Commit 49bf6b5

Browse files
committed
Fix style
1 parent 389a8c9 commit 49bf6b5

File tree

7 files changed

+181
-125
lines changed

7 files changed

+181
-125
lines changed

enzyme/Enzyme/AdjointGenerator.h

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -328,12 +328,12 @@ class AdjointGenerator
328328
is_value_needed_in_reverse<ValueType::Primal>(
329329
TR, gutils, &I,
330330
/*toplevel*/ Mode == DerivativeMode::ReverseModeCombined,
331-
oldUnreachable) )) {
331+
oldUnreachable))) {
332332
if (!gutils->unnecessaryIntermediates.count(&I)) {
333333
IRBuilder<> BuilderZ(gutils->getNewFromOriginal(&I)->getNextNode());
334334
// auto tbaa = inst->getMetadata(LLVMContext::MD_tbaa);
335335
inst = gutils->cacheForReverse(BuilderZ, newi,
336-
getIndex(&I, CacheType::Self));
336+
getIndex(&I, CacheType::Self));
337337
assert(inst->getType() == type);
338338

339339
if (Mode == DerivativeMode::ReverseModeGradient) {
@@ -3850,10 +3850,11 @@ class AdjointGenerator
38503850
// NOTE THAT TOPLEVEL IS THERE SIMPLY BECAUSE THAT WAS PREVIOUS ATTITUTE
38513851
// TO FREE'ing
38523852
if (Mode != DerivativeMode::ReverseModeCombined) {
3853-
if ( (is_value_needed_in_reverse<ValueType::Primal>(
3854-
TR, gutils, orig,
3855-
/*topLevel*/ Mode == DerivativeMode::ReverseModeCombined,
3856-
oldUnreachable) && !gutils->unnecessaryIntermediates.count(orig)) ||
3853+
if ((is_value_needed_in_reverse<ValueType::Primal>(
3854+
TR, gutils, orig,
3855+
/*topLevel*/ Mode == DerivativeMode::ReverseModeCombined,
3856+
oldUnreachable) &&
3857+
!gutils->unnecessaryIntermediates.count(orig)) ||
38573858
hasMetadata(orig, "enzyme_fromstack")) {
38583859
Value *nop = gutils->cacheForReverse(BuilderZ, op,
38593860
getIndex(orig, CacheType::Self));
@@ -4056,8 +4057,10 @@ class AdjointGenerator
40564057
if (Mode != DerivativeMode::ReverseModeCombined && subretused &&
40574058
!orig->doesNotAccessMemory()) {
40584059
if (!gutils->unnecessaryIntermediates.count(orig)) {
4059-
CallInst *const op = cast<CallInst>(gutils->getNewFromOriginal(&call));
4060-
gutils->cacheForReverse(BuilderZ, op, getIndex(orig, CacheType::Self));
4060+
CallInst *const op =
4061+
cast<CallInst>(gutils->getNewFromOriginal(&call));
4062+
gutils->cacheForReverse(BuilderZ, op,
4063+
getIndex(orig, CacheType::Self));
40614064
}
40624065
return;
40634066
}
@@ -4391,7 +4394,8 @@ class AdjointGenerator
43914394
is_value_needed_in_reverse<ValueType::Primal>(
43924395
TR, gutils, orig,
43934396
/*topLevel*/ Mode == DerivativeMode::ReverseModeCombined,
4394-
oldUnreachable) && !gutils->unnecessaryIntermediates.count(orig)) {
4397+
oldUnreachable) &&
4398+
!gutils->unnecessaryIntermediates.count(orig)) {
43954399
gutils->cacheForReverse(BuilderZ, dcall,
43964400
getIndex(orig, CacheType::Self));
43974401
}
@@ -4424,7 +4428,8 @@ class AdjointGenerator
44244428
if (subretused) {
44254429
if (is_value_needed_in_reverse<ValueType::Primal>(
44264430
TR, gutils, orig, Mode == DerivativeMode::ReverseModeCombined,
4427-
oldUnreachable) && !gutils->unnecessaryIntermediates.count(orig)) {
4431+
oldUnreachable) &&
4432+
!gutils->unnecessaryIntermediates.count(orig)) {
44284433
cachereplace = BuilderZ.CreatePHI(orig->getType(), 1,
44294434
orig->getName() + "_tmpcacheB");
44304435
cachereplace = gutils->cacheForReverse(
@@ -4516,7 +4521,8 @@ class AdjointGenerator
45164521
subretused && !orig->doesNotAccessMemory()) {
45174522
if (is_value_needed_in_reverse<ValueType::Primal>(
45184523
TR, gutils, orig, Mode == DerivativeMode::ReverseModeCombined,
4519-
oldUnreachable) && !gutils->unnecessaryIntermediates.count(orig)) {
4524+
oldUnreachable) &&
4525+
!gutils->unnecessaryIntermediates.count(orig)) {
45204526
assert(!replaceFunction);
45214527
cachereplace = BuilderZ.CreatePHI(orig->getType(), 1,
45224528
orig->getName() + "_cachereplace2");

enzyme/Enzyme/DifferentialUseAnalysis.h

Lines changed: 29 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -375,18 +375,21 @@ static inline void bfs(const Graph &G,
375375
// Return 1 if next is better
376376
// 0 if equal
377377
// -1 if prev is better, or unknown
378-
static inline int cmpLoopNest(Loop* prev, Loop* next) {
379-
if (next == prev) return 0;
380-
if (next == nullptr) return 1;
381-
else if (prev == nullptr) return -1;
378+
static inline int cmpLoopNest(Loop *prev, Loop *next) {
379+
if (next == prev)
380+
return 0;
381+
if (next == nullptr)
382+
return 1;
383+
else if (prev == nullptr)
384+
return -1;
382385
for (Loop *L = prev; L != nullptr; L = L->getParentLoop()) {
383-
if (L == next) return 1;
386+
if (L == next)
387+
return 1;
384388
}
385389
return -1;
386390
}
387391

388-
static inline void minCut(const DataLayout &DL,
389-
LoopInfo &OrigLI,
392+
static inline void minCut(const DataLayout &DL, LoopInfo &OrigLI,
390393
const SmallPtrSetImpl<Value *> &Recomputes,
391394
const SmallPtrSetImpl<Value *> &Intermediates,
392395
SmallPtrSetImpl<Value *> &Required,
@@ -465,14 +468,25 @@ static inline void minCut(const DataLayout &DL,
465468
todo.pop_front();
466469
auto found = Orig.find(Node(V, true));
467470
if (found->second.size() == 1 && !Required.count(V)) {
468-
bool potentiallyRecursive = isa<PHINode>((*found->second.begin()).V) && OrigLI.isLoopHeader(cast<PHINode>((*found->second.begin()).V)->getParent());
469-
int moreOuterLoop = cmpLoopNest(OrigLI.getLoopFor(cast<Instruction>(V)->getParent()),
470-
OrigLI.getLoopFor(cast<Instruction>(((*found->second.begin()).V))->getParent()));
471-
// llvm::errs() << " considering cache " << *V << " vs " << " " << *(*found->second.begin()).V << " potentiallyRecursive: " << (int)potentiallyRecursive << " cmpLoopNest: " <<moreOuterLoop << "\n";
472-
if (potentiallyRecursive) continue;
473-
if (moreOuterLoop == -1) continue;
474-
if (moreOuterLoop == 1 || moreOuterLoop == 0 &&
475-
DL.getTypeSizeInBits(V->getType()) >= DL.getTypeSizeInBits((*found->second.begin()).V->getType())) {
471+
bool potentiallyRecursive =
472+
isa<PHINode>((*found->second.begin()).V) &&
473+
OrigLI.isLoopHeader(
474+
cast<PHINode>((*found->second.begin()).V)->getParent());
475+
int moreOuterLoop = cmpLoopNest(
476+
OrigLI.getLoopFor(cast<Instruction>(V)->getParent()),
477+
OrigLI.getLoopFor(
478+
cast<Instruction>(((*found->second.begin()).V))->getParent()));
479+
// llvm::errs() << " considering cache " << *V << " vs " << " " <<
480+
// *(*found->second.begin()).V << " potentiallyRecursive: " <<
481+
// (int)potentiallyRecursive << " cmpLoopNest: " <<moreOuterLoop << "\n";
482+
if (potentiallyRecursive)
483+
continue;
484+
if (moreOuterLoop == -1)
485+
continue;
486+
if (moreOuterLoop == 1 ||
487+
moreOuterLoop == 0 &&
488+
DL.getTypeSizeInBits(V->getType()) >=
489+
DL.getTypeSizeInBits((*found->second.begin()).V->getType())) {
476490
MinReq.erase(V);
477491
// llvm::errs() << " - moved!\n";
478492
MinReq.insert((*found->second.begin()).V);

enzyme/Enzyme/Enzyme.cpp

Lines changed: 72 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ class Enzyme : public ModulePass {
148148
Arch == Triple::amdgcn;
149149

150150
std::map<int, Type *> byVal;
151-
llvm::Value* tape = nullptr;
151+
llvm::Value *tape = nullptr;
152152
int allocatedTapeSize = -1;
153153
for (unsigned i = 1; i < CI->getNumArgOperands(); ++i) {
154154
Value *res = CI->getArgOperand(i);
@@ -211,7 +211,8 @@ class Enzyme : public ModulePass {
211211
res = CI->getArgOperand(i);
212212
} else if (MS == "enzyme_allocated") {
213213
++i;
214-
allocatedTapeSize = cast<ConstantInt>(CI->getArgOperand(i))->getSExtValue();
214+
allocatedTapeSize =
215+
cast<ConstantInt>(CI->getArgOperand(i))->getSExtValue();
215216
continue;
216217
} else {
217218
ty = whatType(PTy, mode == DerivativeMode::ForwardMode);
@@ -245,7 +246,8 @@ class Enzyme : public ModulePass {
245246
res = CI->getArgOperand(i);
246247
} else if (MS == "enzyme_allocated") {
247248
++i;
248-
allocatedTapeSize = cast<ConstantInt>(CI->getArgOperand(i))->getSExtValue();
249+
allocatedTapeSize =
250+
cast<ConstantInt>(CI->getArgOperand(i))->getSExtValue();
249251
continue;
250252
} else {
251253
ty = whatType(PTy, mode == DerivativeMode::ForwardMode);
@@ -435,9 +437,11 @@ class Enzyme : public ModulePass {
435437
}
436438

437439
bool differentialReturn =
438-
mode != DerivativeMode::ForwardMode && cast<Function>(fn)->getReturnType()->isFPOrFPVectorTy();
440+
mode != DerivativeMode::ForwardMode &&
441+
cast<Function>(fn)->getReturnType()->isFPOrFPVectorTy();
439442

440-
DIFFE_TYPE retType = whatType(cast<Function>(fn)->getReturnType(), mode == DerivativeMode::ForwardMode);
443+
DIFFE_TYPE retType = whatType(cast<Function>(fn)->getReturnType(),
444+
mode == DerivativeMode::ForwardMode);
441445

442446
std::map<Argument *, bool> volatile_args;
443447
FnTypeInfo type_args(cast<Function>(fn));
@@ -467,64 +471,72 @@ class Enzyme : public ModulePass {
467471
TypeAnalysis TA(TLI);
468472
type_args = TA.analyzeFunction(type_args).getAnalyzedTypeInfo();
469473

470-
Function * newFunc = nullptr;
471-
Type* tapeType = nullptr;
472-
switch(mode) {
473-
case DerivativeMode::ForwardMode:
474-
case DerivativeMode::ReverseModeCombined:
475-
newFunc = Logic.CreatePrimalAndGradient(
474+
Function *newFunc = nullptr;
475+
Type *tapeType = nullptr;
476+
switch (mode) {
477+
case DerivativeMode::ForwardMode:
478+
case DerivativeMode::ReverseModeCombined:
479+
newFunc = Logic.CreatePrimalAndGradient(
476480
cast<Function>(fn), retType, constants, TLI, TA,
477481
/*should return*/ false, /*dretPtr*/ false, /*topLevel*/ true,
478482
/*addedType*/ nullptr, type_args, volatile_args,
479-
/*index mapping*/ nullptr, AtomicAdd, mode == DerivativeMode::ForwardMode, PostOpt);
480-
break;
481-
case DerivativeMode::ReverseModePrimal:
482-
case DerivativeMode::ReverseModeGradient:{
483-
bool returnUsed = false;
484-
bool forceAnonymousTape = allocatedTapeSize == -1;
485-
auto &aug = Logic.CreateAugmentedPrimal(cast<Function>(fn),
486-
retType, constants, TLI, TA, /*returnUsed*/returnUsed, type_args,
487-
volatile_args, forceAnonymousTape, /*atomicAdd*/AtomicAdd, /*PostOpt*/PostOpt);
488-
auto &DL = cast<Function>(fn)->getParent()->getDataLayout();
489-
if (!forceAnonymousTape) {
490-
assert(!aug.tapeType);
491-
if (aug.returns.find(AugmentedStruct::Tape) != aug.returns.end()) {
492-
auto tapeIdx = aug.returns.find(AugmentedStruct::Tape)->second;
493-
tapeType = (tapeIdx == -1) ? aug.fn->getReturnType()
494-
: cast<StructType>(aug.fn->getReturnType())
495-
->getElementType(tapeIdx);
496-
}
497-
if (tapeType && DL.getTypeSizeInBits(tapeType) < 8 * allocatedTapeSize) {
498-
auto bytes = DL.getTypeSizeInBits(tapeType) / 8;
499-
EmitFailure("Insufficient tape allocation size", CI->getDebugLoc(), CI,
500-
"need ", bytes, " bytes have ", allocatedTapeSize, " bytes");
501-
}
502-
} else {
503-
tapeType = PointerType::getInt8PtrTy(fn->getContext());
483+
/*index mapping*/ nullptr, AtomicAdd,
484+
mode == DerivativeMode::ForwardMode, PostOpt);
485+
break;
486+
case DerivativeMode::ReverseModePrimal:
487+
case DerivativeMode::ReverseModeGradient: {
488+
bool returnUsed = false;
489+
bool forceAnonymousTape = allocatedTapeSize == -1;
490+
auto &aug = Logic.CreateAugmentedPrimal(
491+
cast<Function>(fn), retType, constants, TLI, TA,
492+
/*returnUsed*/ returnUsed, type_args, volatile_args,
493+
forceAnonymousTape, /*atomicAdd*/ AtomicAdd, /*PostOpt*/ PostOpt);
494+
auto &DL = cast<Function>(fn)->getParent()->getDataLayout();
495+
if (!forceAnonymousTape) {
496+
assert(!aug.tapeType);
497+
if (aug.returns.find(AugmentedStruct::Tape) != aug.returns.end()) {
498+
auto tapeIdx = aug.returns.find(AugmentedStruct::Tape)->second;
499+
tapeType = (tapeIdx == -1) ? aug.fn->getReturnType()
500+
: cast<StructType>(aug.fn->getReturnType())
501+
->getElementType(tapeIdx);
504502
}
505-
if (mode == DerivativeMode::ReverseModePrimal)
506-
newFunc = aug.fn;
507-
else
508-
newFunc = Logic.CreatePrimalAndGradient(cast<Function>(fn), retType, constants,
509-
TLI, TA, /*should return*/false, /*dretPtr*/ false, /*topLevel*/ false,
510-
tapeType, type_args, volatile_args,
511-
&aug, AtomicAdd, /*fwdMode*/false, PostOpt);
503+
if (tapeType &&
504+
DL.getTypeSizeInBits(tapeType) < 8 * allocatedTapeSize) {
505+
auto bytes = DL.getTypeSizeInBits(tapeType) / 8;
506+
EmitFailure("Insufficient tape allocation size", CI->getDebugLoc(),
507+
CI, "need ", bytes, " bytes have ", allocatedTapeSize,
508+
" bytes");
509+
}
510+
} else {
511+
tapeType = PointerType::getInt8PtrTy(fn->getContext());
512512
}
513+
if (mode == DerivativeMode::ReverseModePrimal)
514+
newFunc = aug.fn;
515+
else
516+
newFunc = Logic.CreatePrimalAndGradient(
517+
cast<Function>(fn), retType, constants, TLI, TA,
518+
/*should return*/ false, /*dretPtr*/ false, /*topLevel*/ false,
519+
tapeType, type_args, volatile_args, &aug, AtomicAdd,
520+
/*fwdMode*/ false, PostOpt);
521+
}
513522
}
514523

515524
if (!newFunc)
516525
return false;
517526

518527
if (differentialReturn)
519528
args.push_back(ConstantFP::get(cast<Function>(fn)->getReturnType(), 1.0));
520-
529+
521530
if (tape && tapeType) {
522531
auto &DL = cast<Function>(fn)->getParent()->getDataLayout();
523-
if (tapeType != tape->getType() && DL.getTypeSizeInBits(tapeType) <= DL.getTypeSizeInBits(tape->getType())) {
532+
if (tapeType != tape->getType() &&
533+
DL.getTypeSizeInBits(tapeType) <=
534+
DL.getTypeSizeInBits(tape->getType())) {
524535
IRBuilder<> EB(&CI->getParent()->getParent()->getEntryBlock().front());
525536
auto AL = EB.CreateAlloca(tape->getType());
526537
Builder.CreateStore(tape, AL);
527-
tape = Builder.CreateLoad(Builder.CreatePointerCast(AL, PointerType::getUnqual(tapeType)));
538+
tape = Builder.CreateLoad(
539+
Builder.CreatePointerCast(AL, PointerType::getUnqual(tapeType)));
528540
}
529541
llvm::errs() << *CI->getParent() << "\n";
530542
llvm::errs() << *CI->getParent() << "\n";
@@ -567,17 +579,21 @@ class Enzyme : public ModulePass {
567579
CI->replaceAllUsesWith(diffret);
568580
} else if (mode == DerivativeMode::ReverseModePrimal) {
569581
auto &DL = cast<Function>(fn)->getParent()->getDataLayout();
570-
if (DL.getTypeSizeInBits(CI->getType()) >= DL.getTypeSizeInBits(diffret->getType())) {
571-
IRBuilder<> EB(&CI->getParent()->getParent()->getEntryBlock().front());
582+
if (DL.getTypeSizeInBits(CI->getType()) >=
583+
DL.getTypeSizeInBits(diffret->getType())) {
584+
IRBuilder<> EB(
585+
&CI->getParent()->getParent()->getEntryBlock().front());
572586
auto AL = EB.CreateAlloca(CI->getType());
573-
Builder.CreateStore(diffret, Builder.CreatePointerCast(AL, PointerType::getUnqual(diffret->getType())));
587+
Builder.CreateStore(
588+
diffret, Builder.CreatePointerCast(
589+
AL, PointerType::getUnqual(diffret->getType())));
574590
CI->replaceAllUsesWith(Builder.CreateLoad(AL));
575591
} else {
576592
llvm::errs() << *CI << " - " << *diffret << "\n";
577593
assert(0 && " what");
578594
}
579595
} else {
580-
596+
581597
unsigned idxs[] = {0};
582598
auto diffreti = Builder.CreateExtractValue(diffret, idxs);
583599
if (diffreti->getType() == CI->getType()) {
@@ -756,7 +772,8 @@ class Enzyme : public ModulePass {
756772
Fn = fn;
757773
}
758774

759-
if (!Fn) continue;
775+
if (!Fn)
776+
continue;
760777

761778
if (Fn->getName() == "__enzyme_float") {
762779
CI->addAttribute(AttributeList::FunctionIndex, Attribute::ReadNone);
@@ -798,17 +815,16 @@ class Enzyme : public ModulePass {
798815
InactiveCalls.insert(CI);
799816
}
800817
if (Fn->getName() == "frexp" || Fn->getName() == "frexpf" ||
801-
Fn->getName() == "frexpl") {
818+
Fn->getName() == "frexpl") {
802819
CI->addAttribute(AttributeList::FunctionIndex, Attribute::ArgMemOnly);
803820
CI->addParamAttr(1, Attribute::WriteOnly);
804821
}
805-
if (Fn->getName() == "__fd_sincos_1" ||
806-
Fn->getName() == "__fd_cos_1" ||
807-
Fn->getName() == "__mth_i_ipowi") {
822+
if (Fn->getName() == "__fd_sincos_1" || Fn->getName() == "__fd_cos_1" ||
823+
Fn->getName() == "__mth_i_ipowi") {
808824
CI->addAttribute(AttributeList::FunctionIndex, Attribute::ReadNone);
809825
}
810826
if (Fn->getName() == "f90io_fmtw_end" ||
811-
Fn->getName() == "f90io_unf_end") {
827+
Fn->getName() == "f90io_unf_end") {
812828
Fn->addFnAttr(Attribute::InaccessibleMemOnly);
813829
CI->addAttribute(AttributeList::FunctionIndex,
814830
Attribute::InaccessibleMemOnly);

0 commit comments

Comments
 (0)