@@ -2487,6 +2487,84 @@ DiffeGradientUtils *DiffeGradientUtils::CreateFromClone(
2487
2487
return res;
2488
2488
}
2489
2489
2490
// Builds, or reuses, the module-level global that serves as the shadow of
// `fn`, and returns that global pointer-cast to `fn`'s own type.
//
// The global's initializer is a two-field constant struct holding
//   1. the function stored in the result of Logic.CreateAugmentedPrimal, and
//   2. the value returned by Logic.CreatePrimalAndGradient in
//      DerivativeMode::ReverseModeGradient (or an undef of `fn`'s type when
//      that call returns null).
//
// Parameters:
//   Logic     - used to create both the augmented primal and the gradient.
//   TLI, TA   - forwarded unchanged to both Logic calls.
//   fn        - function whose shadow is requested; the global is created in
//               fn's parent module and its name is derived from fn's name.
//   AtomicAdd - forwarded unchanged to both Logic calls.
//   PostOpt   - forwarded to CreateAugmentedPrimal only.
//
// Returns: a constant pointer-cast of the (new or pre-existing) global to
// fn->getType().
//
// NOTE(review): both Logic calls are made before the module is searched for
// an existing global of the same name, so they run even when the global is
// reused -- confirm whether Logic caches these results.
+ Constant *GradientUtils::GetOrCreateShadowFunction (EnzymeLogic &Logic,
2491
+ TargetLibraryInfo &TLI,
2492
+ TypeAnalysis &TA,
2493
+ Function *fn, bool AtomicAdd,
2494
+ bool PostOpt) {
2495
+ // ! Todo allow tape propagation
2496
+ // Note that specifically this should _not_ be called with topLevel=true
2497
+ // (since it may not be valid to always assume we can recompute the
2498
+ // augmented primal) However, in the absence of a way to pass tape data
2499
+ // from an indirect augmented (and also since we don't presently allow
2500
+ // indirect augmented calls), topLevel MUST be true otherwise subcalls will
2501
+ // not be able to lookup the augmenteddata/subdata (triggering an assertion
2502
+ // failure, among much worse)
2503
+ std::map<Argument *, bool > uncacheable_args;
2504
+ FnTypeInfo type_args (fn);
2505
+
2506
+ // conservatively assume that we can only cache existing floating types
2507
+ // (i.e. that all args are uncacheable)
2508
+ std::vector<DIFFE_TYPE> types;
2509
// Classify every formal argument of fn. `types` receives one entry per
// argument, in argument order.
+ for (auto &a : fn->args ()) {
2510
// Every argument that is not floating point (scalar or vector) is marked
// uncacheable.
+ uncacheable_args[&a] = !a.getType ()->isFPOrFPVectorTy ();
2511
// Register the argument with an empty TypeTree and an empty set of known
// integer values, i.e. no type or value information is supplied for it here.
+ type_args.Arguments .insert (std::pair<Argument *, TypeTree>(&a, {}));
2512
+ type_args.KnownValues .insert (
2513
+ std::pair<Argument *, std::set<int64_t >>(&a, {}));
2514
// Activity chosen purely from the LLVM type:
//   floating point / FP vector    -> OUT_DIFF
//   integer narrower than 16 bits -> CONSTANT
//   void or empty type            -> CONSTANT
//   anything else                 -> DUP_ARG
+ DIFFE_TYPE typ;
2515
+ if (a.getType ()->isFPOrFPVectorTy ()) {
2516
+ typ = DIFFE_TYPE::OUT_DIFF;
2517
+ } else if (a.getType ()->isIntegerTy () &&
2518
+ cast<IntegerType>(a.getType ())->getBitWidth () < 16 ) {
2519
+ typ = DIFFE_TYPE::CONSTANT;
2520
+ } else if (a.getType ()->isVoidTy () || a.getType ()->isEmptyTy ()) {
2521
+ typ = DIFFE_TYPE::CONSTANT;
2522
+ } else {
2523
+ typ = DIFFE_TYPE::DUP_ARG;
2524
+ }
2525
+ types.push_back (typ);
2526
+ }
2527
+
2528
// The return value is classified with the same type-based rules as the
// arguments: OUT_DIFF for floating point, otherwise DUP_ARG, then overridden
// to CONSTANT for void, empty, or sub-16-bit integer return types.
+ DIFFE_TYPE retType = fn->getReturnType ()->isFPOrFPVectorTy ()
2529
+ ? DIFFE_TYPE::OUT_DIFF
2530
+ : DIFFE_TYPE::DUP_ARG;
2531
+ if (fn->getReturnType ()->isVoidTy () || fn->getReturnType ()->isEmptyTy () ||
2532
+ (fn->getReturnType ()->isIntegerTy () &&
2533
+ cast<IntegerType>(fn->getReturnType ())->getBitWidth () < 16 ))
2534
+ retType = DIFFE_TYPE::CONSTANT;
2535
+
2536
+ // TODO re atomic add consider forcing it to be atomic always as fallback if
2537
+ // used in a parallel context
2538
// Request the augmented primal. returnUsed is true unless fn returns void or
// an empty type; the tape is forced to be anonymous.
+ auto &augdata = Logic.CreateAugmentedPrimal (
2539
+ fn, retType, /* constant_args*/ types, TLI, TA,
2540
+ /* returnUsed*/ !fn->getReturnType ()->isEmptyTy () &&
2541
+ !fn->getReturnType ()->isVoidTy (),
2542
+ type_args, uncacheable_args, /* forceAnonymousTape*/ true , AtomicAdd,
2543
+ PostOpt);
2544
// Request the reverse-mode gradient, handing it the augmented-primal data
// created above. The additional argument has type i8*.
// NOTE(review): presumably that i8* carries the anonymous tape -- confirm
// against CreatePrimalAndGradient.
+ Constant *newf = Logic.CreatePrimalAndGradient (
2545
+ fn, retType, /* constant_args*/ types, TLI, TA,
2546
+ /* returnValue*/ false , /* dretPtr*/ false ,
2547
+ DerivativeMode::ReverseModeGradient,
2548
+ /* additionalArg*/ Type::getInt8PtrTy (fn->getContext ()), type_args,
2549
+ uncacheable_args,
2550
+ /* map*/ &augdata, AtomicAdd);
2551
// A null result is replaced by an undef of fn's type so that the struct
// below can still be formed.
+ if (!newf)
2552
+ newf = UndefValue::get (fn->getType ());
2553
// Pack {augmented primal function, gradient} into an anonymous constant
// struct whose field types are taken from the two values themselves.
+ auto cdata = ConstantStruct::get (
2554
+ StructType::get (newf->getContext (),
2555
+ {augdata.fn ->getType (), newf->getType ()}),
2556
+ {augdata.fn , newf});
2557
// Look the global up by name in fn's module. A pre-existing value with this
// name is reused as-is (its type and initializer are not checked here), and
// in that case `cdata` goes unused.
+ std::string globalname = (" _enzyme_" + fn->getName () + " '" ).str ();
2558
+ auto GV = fn->getParent ()->getNamedValue (globalname);
2559
+
2560
+ if (GV == nullptr ) {
2561
// Otherwise create a constant global with internal linkage, initialized
// with the struct built above.
+ GV = new GlobalVariable (*fn->getParent (), cdata->getType (), true ,
2562
+ GlobalValue::LinkageTypes::InternalLinkage, cdata,
2563
+ globalname);
2564
+ }
2565
// The result is a constant expression, so no IRBuilder is involved.
+ return ConstantExpr::getPointerCast (GV, fn->getType ());
2566
+ }
2567
+
2490
2568
Value *GradientUtils::invertPointerM (Value *const oval, IRBuilder<> &BuilderM,
2491
2569
bool nullShadow) {
2492
2570
assert (oval);
@@ -2768,78 +2846,7 @@ Value *GradientUtils::invertPointerM(Value *const oval, IRBuilder<> &BuilderM,
2768
2846
std::make_pair ((const Value *)oval, InvertedPointerVH (this , cs)));
2769
2847
return cs;
2770
2848
} else if (auto fn = dyn_cast<Function>(oval)) {
2771
- // ! Todo allow tape propagation
2772
- // Note that specifically this should _not_ be called with topLevel=true
2773
- // (since it may not be valid to always assume we can recompute the
2774
- // augmented primal) However, in the absence of a way to pass tape data
2775
- // from an indirect augmented (and also since we dont presently allow
2776
- // indirect augmented calls), topLevel MUST be true otherwise subcalls will
2777
- // not be able to lookup the augmenteddata/subdata (triggering an assertion
2778
- // failure, among much worse)
2779
- std::map<Argument *, bool > uncacheable_args;
2780
- FnTypeInfo type_args (fn);
2781
-
2782
- // conservatively assume that we can only cache existing floating types
2783
- // (i.e. that all args are uncacheable)
2784
- std::vector<DIFFE_TYPE> types;
2785
- for (auto &a : fn->args ()) {
2786
- uncacheable_args[&a] = !a.getType ()->isFPOrFPVectorTy ();
2787
- type_args.Arguments .insert (std::pair<Argument *, TypeTree>(&a, {}));
2788
- type_args.KnownValues .insert (
2789
- std::pair<Argument *, std::set<int64_t >>(&a, {}));
2790
- DIFFE_TYPE typ;
2791
- if (a.getType ()->isFPOrFPVectorTy ()) {
2792
- typ = DIFFE_TYPE::OUT_DIFF;
2793
- } else if (a.getType ()->isIntegerTy () &&
2794
- cast<IntegerType>(a.getType ())->getBitWidth () < 16 ) {
2795
- typ = DIFFE_TYPE::CONSTANT;
2796
- } else if (a.getType ()->isVoidTy () || a.getType ()->isEmptyTy ()) {
2797
- typ = DIFFE_TYPE::CONSTANT;
2798
- } else {
2799
- typ = DIFFE_TYPE::DUP_ARG;
2800
- }
2801
- types.push_back (typ);
2802
- }
2803
-
2804
- DIFFE_TYPE retType = fn->getReturnType ()->isFPOrFPVectorTy ()
2805
- ? DIFFE_TYPE::OUT_DIFF
2806
- : DIFFE_TYPE::DUP_ARG;
2807
- if (fn->getReturnType ()->isVoidTy () || fn->getReturnType ()->isEmptyTy () ||
2808
- (fn->getReturnType ()->isIntegerTy () &&
2809
- cast<IntegerType>(fn->getReturnType ())->getBitWidth () < 16 ))
2810
- retType = DIFFE_TYPE::CONSTANT;
2811
-
2812
- // TODO re atomic add consider forcing it to be atomic always as fallback if
2813
- // used in a parallel context
2814
- auto &augdata = Logic.CreateAugmentedPrimal (
2815
- fn, retType, /* constant_args*/ types, TLI, TA,
2816
- /* returnUsed*/ !fn->getReturnType ()->isEmptyTy () &&
2817
- !fn->getReturnType ()->isVoidTy (),
2818
- type_args, uncacheable_args, /* forceAnonymousTape*/ true , AtomicAdd,
2819
- /* PostOpt*/ false );
2820
- Constant *newf = Logic.CreatePrimalAndGradient (
2821
- fn, retType, /* constant_args*/ types, TLI, TA,
2822
- /* returnValue*/ false , /* dretPtr*/ false ,
2823
- DerivativeMode::ReverseModeGradient,
2824
- /* additionalArg*/ Type::getInt8PtrTy (fn->getContext ()), type_args,
2825
- uncacheable_args,
2826
- /* map*/ &augdata, AtomicAdd);
2827
- if (!newf)
2828
- newf = UndefValue::get (fn->getType ());
2829
- auto cdata = ConstantStruct::get (
2830
- StructType::get (newf->getContext (),
2831
- {augdata.fn ->getType (), newf->getType ()}),
2832
- {augdata.fn , newf});
2833
- std::string globalname = (" _enzyme_" + fn->getName () + " '" ).str ();
2834
- auto GV = fn->getParent ()->getNamedValue (globalname);
2835
-
2836
- if (GV == nullptr ) {
2837
- GV = new GlobalVariable (*fn->getParent (), cdata->getType (), true ,
2838
- GlobalValue::LinkageTypes::InternalLinkage, cdata,
2839
- globalname);
2840
- }
2841
-
2842
- return BuilderM.CreatePointerCast (GV, fn->getType ());
2849
+ return GetOrCreateShadowFunction (Logic, TLI, TA, fn, AtomicAdd);
2843
2850
} else if (auto arg = dyn_cast<CastInst>(oval)) {
2844
2851
IRBuilder<> bb (getNewFromOriginal (arg));
2845
2852
Value *invertOp = invertPointerM (arg->getOperand (0 ), bb);
0 commit comments