@@ -148,7 +148,7 @@ class Enzyme : public ModulePass {
148
148
Arch == Triple::amdgcn;
149
149
150
150
std::map<int , Type *> byVal;
151
- llvm::Value* tape = nullptr ;
151
+ llvm::Value * tape = nullptr ;
152
152
int allocatedTapeSize = -1 ;
153
153
for (unsigned i = 1 ; i < CI->getNumArgOperands (); ++i) {
154
154
Value *res = CI->getArgOperand (i);
@@ -211,7 +211,8 @@ class Enzyme : public ModulePass {
211
211
res = CI->getArgOperand (i);
212
212
} else if (MS == " enzyme_allocated" ) {
213
213
++i;
214
- allocatedTapeSize = cast<ConstantInt>(CI->getArgOperand (i))->getSExtValue ();
214
+ allocatedTapeSize =
215
+ cast<ConstantInt>(CI->getArgOperand (i))->getSExtValue ();
215
216
continue ;
216
217
} else {
217
218
ty = whatType (PTy, mode == DerivativeMode::ForwardMode);
@@ -245,7 +246,8 @@ class Enzyme : public ModulePass {
245
246
res = CI->getArgOperand (i);
246
247
} else if (MS == " enzyme_allocated" ) {
247
248
++i;
248
- allocatedTapeSize = cast<ConstantInt>(CI->getArgOperand (i))->getSExtValue ();
249
+ allocatedTapeSize =
250
+ cast<ConstantInt>(CI->getArgOperand (i))->getSExtValue ();
249
251
continue ;
250
252
} else {
251
253
ty = whatType (PTy, mode == DerivativeMode::ForwardMode);
@@ -435,9 +437,11 @@ class Enzyme : public ModulePass {
435
437
}
436
438
437
439
bool differentialReturn =
438
- mode != DerivativeMode::ForwardMode && cast<Function>(fn)->getReturnType ()->isFPOrFPVectorTy ();
440
+ mode != DerivativeMode::ForwardMode &&
441
+ cast<Function>(fn)->getReturnType ()->isFPOrFPVectorTy ();
439
442
440
- DIFFE_TYPE retType = whatType (cast<Function>(fn)->getReturnType (), mode == DerivativeMode::ForwardMode);
443
+ DIFFE_TYPE retType = whatType (cast<Function>(fn)->getReturnType (),
444
+ mode == DerivativeMode::ForwardMode);
441
445
442
446
std::map<Argument *, bool > volatile_args;
443
447
FnTypeInfo type_args (cast<Function>(fn));
@@ -467,64 +471,72 @@ class Enzyme : public ModulePass {
467
471
TypeAnalysis TA (TLI);
468
472
type_args = TA.analyzeFunction (type_args).getAnalyzedTypeInfo ();
469
473
470
- Function * newFunc = nullptr ;
471
- Type* tapeType = nullptr ;
472
- switch (mode) {
473
- case DerivativeMode::ForwardMode:
474
- case DerivativeMode::ReverseModeCombined:
475
- newFunc = Logic.CreatePrimalAndGradient (
474
+ Function *newFunc = nullptr ;
475
+ Type * tapeType = nullptr ;
476
+ switch (mode) {
477
+ case DerivativeMode::ForwardMode:
478
+ case DerivativeMode::ReverseModeCombined:
479
+ newFunc = Logic.CreatePrimalAndGradient (
476
480
cast<Function>(fn), retType, constants, TLI, TA,
477
481
/* should return*/ false , /* dretPtr*/ false , /* topLevel*/ true ,
478
482
/* addedType*/ nullptr , type_args, volatile_args,
479
- /* index mapping*/ nullptr , AtomicAdd, mode == DerivativeMode::ForwardMode, PostOpt);
480
- break ;
481
- case DerivativeMode::ReverseModePrimal:
482
- case DerivativeMode::ReverseModeGradient:{
483
- bool returnUsed = false ;
484
- bool forceAnonymousTape = allocatedTapeSize == -1 ;
485
- auto &aug = Logic.CreateAugmentedPrimal (cast<Function>(fn),
486
- retType, constants, TLI, TA, /* returnUsed*/ returnUsed, type_args,
487
- volatile_args, forceAnonymousTape, /* atomicAdd*/ AtomicAdd, /* PostOpt*/ PostOpt);
488
- auto &DL = cast<Function>(fn)->getParent ()->getDataLayout ();
489
- if (!forceAnonymousTape) {
490
- assert (!aug.tapeType );
491
- if (aug.returns .find (AugmentedStruct::Tape) != aug.returns .end ()) {
492
- auto tapeIdx = aug.returns .find (AugmentedStruct::Tape)->second ;
493
- tapeType = (tapeIdx == -1 ) ? aug.fn ->getReturnType ()
494
- : cast<StructType>(aug.fn ->getReturnType ())
495
- ->getElementType (tapeIdx);
496
- }
497
- if (tapeType && DL.getTypeSizeInBits (tapeType) < 8 * allocatedTapeSize) {
498
- auto bytes = DL.getTypeSizeInBits (tapeType) / 8 ;
499
- EmitFailure (" Insufficient tape allocation size" , CI->getDebugLoc (), CI,
500
- " need " , bytes, " bytes have " , allocatedTapeSize, " bytes" );
501
- }
502
- } else {
503
- tapeType = PointerType::getInt8PtrTy (fn->getContext ());
483
+ /* index mapping*/ nullptr , AtomicAdd,
484
+ mode == DerivativeMode::ForwardMode, PostOpt);
485
+ break ;
486
+ case DerivativeMode::ReverseModePrimal:
487
+ case DerivativeMode::ReverseModeGradient: {
488
+ bool returnUsed = false ;
489
+ bool forceAnonymousTape = allocatedTapeSize == -1 ;
490
+ auto &aug = Logic.CreateAugmentedPrimal (
491
+ cast<Function>(fn), retType, constants, TLI, TA,
492
+ /* returnUsed*/ returnUsed, type_args, volatile_args,
493
+ forceAnonymousTape, /* atomicAdd*/ AtomicAdd, /* PostOpt*/ PostOpt);
494
+ auto &DL = cast<Function>(fn)->getParent ()->getDataLayout ();
495
+ if (!forceAnonymousTape) {
496
+ assert (!aug.tapeType );
497
+ if (aug.returns .find (AugmentedStruct::Tape) != aug.returns .end ()) {
498
+ auto tapeIdx = aug.returns .find (AugmentedStruct::Tape)->second ;
499
+ tapeType = (tapeIdx == -1 ) ? aug.fn ->getReturnType ()
500
+ : cast<StructType>(aug.fn ->getReturnType ())
501
+ ->getElementType (tapeIdx);
504
502
}
505
- if (mode == DerivativeMode::ReverseModePrimal)
506
- newFunc = aug.fn ;
507
- else
508
- newFunc = Logic.CreatePrimalAndGradient (cast<Function>(fn), retType, constants,
509
- TLI, TA, /* should return*/ false , /* dretPtr*/ false , /* topLevel*/ false ,
510
- tapeType, type_args, volatile_args,
511
- &aug, AtomicAdd, /* fwdMode*/ false , PostOpt);
503
+ if (tapeType &&
504
+ DL.getTypeSizeInBits (tapeType) < 8 * allocatedTapeSize) {
505
+ auto bytes = DL.getTypeSizeInBits (tapeType) / 8 ;
506
+ EmitFailure (" Insufficient tape allocation size" , CI->getDebugLoc (),
507
+ CI, " need " , bytes, " bytes have " , allocatedTapeSize,
508
+ " bytes" );
509
+ }
510
+ } else {
511
+ tapeType = PointerType::getInt8PtrTy (fn->getContext ());
512
512
}
513
+ if (mode == DerivativeMode::ReverseModePrimal)
514
+ newFunc = aug.fn ;
515
+ else
516
+ newFunc = Logic.CreatePrimalAndGradient (
517
+ cast<Function>(fn), retType, constants, TLI, TA,
518
+ /* should return*/ false , /* dretPtr*/ false , /* topLevel*/ false ,
519
+ tapeType, type_args, volatile_args, &aug, AtomicAdd,
520
+ /* fwdMode*/ false , PostOpt);
521
+ }
513
522
}
514
523
515
524
if (!newFunc)
516
525
return false ;
517
526
518
527
if (differentialReturn)
519
528
args.push_back (ConstantFP::get (cast<Function>(fn)->getReturnType (), 1.0 ));
520
-
529
+
521
530
if (tape && tapeType) {
522
531
auto &DL = cast<Function>(fn)->getParent ()->getDataLayout ();
523
- if (tapeType != tape->getType () && DL.getTypeSizeInBits (tapeType) <= DL.getTypeSizeInBits (tape->getType ())) {
532
+ if (tapeType != tape->getType () &&
533
+ DL.getTypeSizeInBits (tapeType) <=
534
+ DL.getTypeSizeInBits (tape->getType ())) {
524
535
IRBuilder<> EB (&CI->getParent ()->getParent ()->getEntryBlock ().front ());
525
536
auto AL = EB.CreateAlloca (tape->getType ());
526
537
Builder.CreateStore (tape, AL);
527
- tape = Builder.CreateLoad (Builder.CreatePointerCast (AL, PointerType::getUnqual (tapeType)));
538
+ tape = Builder.CreateLoad (
539
+ Builder.CreatePointerCast (AL, PointerType::getUnqual (tapeType)));
528
540
}
529
541
llvm::errs () << *CI->getParent () << " \n " ;
530
542
llvm::errs () << *CI->getParent () << " \n " ;
@@ -567,17 +579,21 @@ class Enzyme : public ModulePass {
567
579
CI->replaceAllUsesWith (diffret);
568
580
} else if (mode == DerivativeMode::ReverseModePrimal) {
569
581
auto &DL = cast<Function>(fn)->getParent ()->getDataLayout ();
570
- if (DL.getTypeSizeInBits (CI->getType ()) >= DL.getTypeSizeInBits (diffret->getType ())) {
571
- IRBuilder<> EB (&CI->getParent ()->getParent ()->getEntryBlock ().front ());
582
+ if (DL.getTypeSizeInBits (CI->getType ()) >=
583
+ DL.getTypeSizeInBits (diffret->getType ())) {
584
+ IRBuilder<> EB (
585
+ &CI->getParent ()->getParent ()->getEntryBlock ().front ());
572
586
auto AL = EB.CreateAlloca (CI->getType ());
573
- Builder.CreateStore (diffret, Builder.CreatePointerCast (AL, PointerType::getUnqual (diffret->getType ())));
587
+ Builder.CreateStore (
588
+ diffret, Builder.CreatePointerCast (
589
+ AL, PointerType::getUnqual (diffret->getType ())));
574
590
CI->replaceAllUsesWith (Builder.CreateLoad (AL));
575
591
} else {
576
592
llvm::errs () << *CI << " - " << *diffret << " \n " ;
577
593
assert (0 && " what" );
578
594
}
579
595
} else {
580
-
596
+
581
597
unsigned idxs[] = {0 };
582
598
auto diffreti = Builder.CreateExtractValue (diffret, idxs);
583
599
if (diffreti->getType () == CI->getType ()) {
@@ -756,7 +772,8 @@ class Enzyme : public ModulePass {
756
772
Fn = fn;
757
773
}
758
774
759
- if (!Fn) continue ;
775
+ if (!Fn)
776
+ continue ;
760
777
761
778
if (Fn->getName () == " __enzyme_float" ) {
762
779
CI->addAttribute (AttributeList::FunctionIndex, Attribute::ReadNone);
@@ -798,17 +815,16 @@ class Enzyme : public ModulePass {
798
815
InactiveCalls.insert (CI);
799
816
}
800
817
if (Fn->getName () == " frexp" || Fn->getName () == " frexpf" ||
801
- Fn->getName () == " frexpl" ) {
818
+ Fn->getName () == " frexpl" ) {
802
819
CI->addAttribute (AttributeList::FunctionIndex, Attribute::ArgMemOnly);
803
820
CI->addParamAttr (1 , Attribute::WriteOnly);
804
821
}
805
- if (Fn->getName () == " __fd_sincos_1" ||
806
- Fn->getName () == " __fd_cos_1" ||
807
- Fn->getName () == " __mth_i_ipowi" ) {
822
+ if (Fn->getName () == " __fd_sincos_1" || Fn->getName () == " __fd_cos_1" ||
823
+ Fn->getName () == " __mth_i_ipowi" ) {
808
824
CI->addAttribute (AttributeList::FunctionIndex, Attribute::ReadNone);
809
825
}
810
826
if (Fn->getName () == " f90io_fmtw_end" ||
811
- Fn->getName () == " f90io_unf_end" ) {
827
+ Fn->getName () == " f90io_unf_end" ) {
812
828
Fn->addFnAttr (Attribute::InaccessibleMemOnly);
813
829
CI->addAttribute (AttributeList::FunctionIndex,
814
830
Attribute::InaccessibleMemOnly);
0 commit comments