@@ -2689,8 +2689,91 @@ Function *EnzymeLogic::CreatePrimalAndGradient(
2689
2689
}
2690
2690
}
2691
2691
2692
- if (!hasconstant && mode != DerivativeMode::ReverseModeCombined &&
2693
- !returnValue && hasMetadata (todiff, " enzyme_gradient" )) {
2692
+ if (!hasconstant && !returnValue && hasMetadata (todiff, " enzyme_gradient" )) {
2693
+
2694
+ DIFFE_TYPE subretType = todiff->getReturnType ()->isFPOrFPVectorTy ()
2695
+ ? DIFFE_TYPE::OUT_DIFF
2696
+ : DIFFE_TYPE::DUP_ARG;
2697
+ if (todiff->getReturnType ()->isVoidTy () ||
2698
+ todiff->getReturnType ()->isEmptyTy ())
2699
+ subretType = DIFFE_TYPE::CONSTANT;
2700
+ assert (subretType == retType);
2701
+
2702
+ auto res = getDefaultFunctionTypeForGradient (todiff->getFunctionType (),
2703
+ /* retType*/ retType);
2704
+
2705
+ if (mode == DerivativeMode::ReverseModeCombined) {
2706
+
2707
+ FunctionType *FTy =
2708
+ FunctionType::get (StructType::get (todiff->getContext (), {res.second }),
2709
+ res.first , todiff->getFunctionType ()->isVarArg ());
2710
+ Function *NewF = Function::Create (
2711
+ FTy, Function::LinkageTypes::InternalLinkage,
2712
+ " fixgradient_" + todiff->getName (), todiff->getParent ());
2713
+
2714
+ BasicBlock *BB = BasicBlock::Create (NewF->getContext (), " entry" , NewF);
2715
+ IRBuilder<> bb (BB);
2716
+
2717
+ auto &aug = CreateAugmentedPrimal (
2718
+ todiff, retType, constant_args, TLI, TA, returnUsed, oldTypeInfo_,
2719
+ _uncacheable_args, /* forceAnonymousTape*/ false , AtomicAdd, PostOpt,
2720
+ omp);
2721
+
2722
+ SmallVector<Value *, 4 > fwdargs;
2723
+ for (auto &a : NewF->args ())
2724
+ fwdargs.push_back (&a);
2725
+ if (retType == DIFFE_TYPE::OUT_DIFF)
2726
+ fwdargs.pop_back ();
2727
+ auto cal = bb.CreateCall (aug.fn , fwdargs);
2728
+ cal->setCallingConv (aug.fn ->getCallingConv ());
2729
+
2730
+ llvm::Value *tape = nullptr ;
2731
+
2732
+ if (aug.returns .find (AugmentedStruct::Tape) != aug.returns .end ()) {
2733
+ auto tapeIdx = aug.returns .find (AugmentedStruct::Tape)->second ;
2734
+ tape = (tapeIdx == -1 ) ? cal : bb.CreateExtractValue (cal, tapeIdx);
2735
+ }
2736
+
2737
+ if (aug.tapeType ) {
2738
+ assert (tape);
2739
+ auto tapep =
2740
+ bb.CreatePointerCast (tape, PointerType::getUnqual (aug.tapeType ));
2741
+ auto truetape = bb.CreateLoad (tapep, " tapeld" );
2742
+ truetape->setMetadata (" enzyme_mustcache" ,
2743
+ MDNode::get (truetape->getContext (), {}));
2744
+
2745
+ CallInst *ci = cast<CallInst>(CallInst::CreateFree (tape, BB));
2746
+ bb.Insert (ci);
2747
+ ci->addAttribute (AttributeList::FirstArgIndex, Attribute::NonNull);
2748
+ tape = truetape;
2749
+ }
2750
+
2751
+ auto revfn = CreatePrimalAndGradient (
2752
+ todiff, retType, constant_args, TLI, TA,
2753
+ /* returnUsed*/ false , /* dretPtr*/ false ,
2754
+ /* mode*/ DerivativeMode::ReverseModeGradient,
2755
+ /* additionalArg*/ tape ? tape->getType () : nullptr , oldTypeInfo_,
2756
+ _uncacheable_args, &aug, AtomicAdd, PostOpt, omp);
2757
+
2758
+ SmallVector<Value *, 4 > revargs;
2759
+ for (auto &a : NewF->args ()) {
2760
+ revargs.push_back (&a);
2761
+ }
2762
+ if (tape) {
2763
+ revargs.push_back (tape);
2764
+ }
2765
+ auto revcal = bb.CreateCall (revfn, revargs);
2766
+ revcal->setCallingConv (revfn->getCallingConv ());
2767
+ if (NewF->getReturnType ()->isEmptyTy ())
2768
+ bb.CreateRet (UndefValue::get (NewF->getReturnType ()));
2769
+ else
2770
+ bb.CreateRet (revcal);
2771
+ assert (!returnUsed);
2772
+
2773
+ return insert_or_assign2<ReverseCacheKey, Function *>(
2774
+ ReverseCachedFunctions, tup, NewF)
2775
+ ->second ;
2776
+ }
2694
2777
2695
2778
auto md = todiff->getMetadata (" enzyme_gradient" );
2696
2779
if (!isa<MDTuple>(md)) {
@@ -2704,14 +2787,6 @@ Function *EnzymeLogic::CreatePrimalAndGradient(
2704
2787
auto gvemd = cast<ConstantAsMetadata>(md2->getOperand (0 ));
2705
2788
auto foundcalled = cast<Function>(gvemd->getValue ());
2706
2789
2707
- DIFFE_TYPE subretType = todiff->getReturnType ()->isFPOrFPVectorTy ()
2708
- ? DIFFE_TYPE::OUT_DIFF
2709
- : DIFFE_TYPE::DUP_ARG;
2710
- if (todiff->getReturnType ()->isVoidTy () ||
2711
- todiff->getReturnType ()->isEmptyTy ())
2712
- subretType = DIFFE_TYPE::CONSTANT;
2713
- auto res = getDefaultFunctionTypeForGradient (todiff->getFunctionType (),
2714
- /* retType*/ subretType);
2715
2790
assert (augmenteddata);
2716
2791
if (foundcalled->arg_size () == res.first .size () + 1 /* tape*/ ) {
2717
2792
auto lastarg = foundcalled->arg_end ();
0 commit comments