@@ -86,6 +86,7 @@ cl::opt<bool> nonmarkedglobals_inactiveloads(
86
86
}
87
87
88
88
struct CacheAnalysis {
89
+ TypeResults &TR;
89
90
AAResults &AA;
90
91
Function *oldFunc;
91
92
ScalarEvolution &SE;
@@ -99,13 +100,14 @@ struct CacheAnalysis {
99
100
bool omp;
100
101
SmallVector<CallInst *, 0 > kmpcCall;
101
102
CacheAnalysis (
102
- AAResults &AA, Function *oldFunc, ScalarEvolution &SE, LoopInfo &OrigLI ,
103
- DominatorTree &OrigDT, TargetLibraryInfo &TLI,
103
+ TypeResults &TR, AAResults &AA, Function *oldFunc, ScalarEvolution &SE,
104
+ LoopInfo &OrigLI, DominatorTree &OrigDT, TargetLibraryInfo &TLI,
104
105
const SmallPtrSetImpl<const Instruction *> &unnecessaryInstructions,
105
106
const std::map<Argument *, bool > &uncacheable_args, DerivativeMode mode,
106
107
bool omp)
107
- : AA(AA), oldFunc(oldFunc), SE(SE), OrigLI(OrigLI), OrigDT(OrigDT),
108
- TLI (TLI), unnecessaryInstructions(unnecessaryInstructions),
108
+ : TR(TR), AA(AA), oldFunc(oldFunc), SE(SE), OrigLI(OrigLI),
109
+ OrigDT (OrigDT), TLI(TLI),
110
+ unnecessaryInstructions(unnecessaryInstructions),
109
111
uncacheable_args(uncacheable_args), mode(mode), omp(omp) {
110
112
111
113
for (auto &BB : *oldFunc)
@@ -559,6 +561,9 @@ struct CacheAnalysis {
559
561
#endif
560
562
561
563
bool init_safe = !is_value_mustcache_from_origin (obj);
564
+ auto CD = TR.query (obj)[{-1 }];
565
+ if (CD == BaseType::Integer || CD.isFloat ())
566
+ init_safe = true ;
562
567
if (!init_safe && !isa<ConstantInt>(obj) && !isa<Function>(obj)) {
563
568
EmitWarning (" UncacheableOrigin" , callsite_op->getDebugLoc (), oldFunc,
564
569
callsite_op->getParent (), " Callsite " , *callsite_op,
@@ -615,6 +620,10 @@ struct CacheAnalysis {
615
620
return false ;
616
621
617
622
for (unsigned i = 0 ; i < args.size (); ++i) {
623
+ auto CD = TR.query (args[i])[{-1 }];
624
+ if (CD == BaseType::Integer || CD.isFloat ())
625
+ continue ;
626
+
618
627
if (llvm::isModSet (AA.getModRefInfo (
619
628
inst2, MemoryLocation::getForArgument (callsite_op, i, TLI)))) {
620
629
if (!isa<ConstantInt>(callsite_op->getArgOperand (i)))
@@ -1534,26 +1543,6 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
1534
1543
++in_arg;
1535
1544
}
1536
1545
}
1537
- // TODO actually populate unnecessaryInstructions with what can be
1538
- // derived without activity info
1539
- SmallPtrSet<const Instruction *, 4 > unnecessaryInstructionsTmp;
1540
- for (auto BB : guaranteedUnreachable) {
1541
- for (auto &I : *BB)
1542
- unnecessaryInstructionsTmp.insert (&I);
1543
- }
1544
- CacheAnalysis CA (gutils->OrigAA , gutils->oldFunc ,
1545
- PPC.FAM .getResult <ScalarEvolutionAnalysis>(*gutils->oldFunc ),
1546
- gutils->OrigLI , gutils->OrigDT , TLI,
1547
- unnecessaryInstructionsTmp, _uncacheable_argsPP,
1548
- DerivativeMode::ReverseModePrimal, omp);
1549
- const std::map<CallInst *, const std::map<Argument *, bool >>
1550
- uncacheable_args_map = CA.compute_uncacheable_args_for_callsites ();
1551
-
1552
- const std::map<Instruction *, bool > can_modref_map =
1553
- CA.compute_uncacheable_load_map ();
1554
- gutils->can_modref_map = &can_modref_map;
1555
-
1556
- // gutils->forceContexts();
1557
1546
1558
1547
FnTypeInfo typeInfo (gutils->oldFunc );
1559
1548
{
@@ -1579,6 +1568,26 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
1579
1568
}
1580
1569
TypeResults TR = TA.analyzeFunction (typeInfo);
1581
1570
assert (TR.info .Function == gutils->oldFunc );
1571
+
1572
+ // TODO actually populate unnecessaryInstructions with what can be
1573
+ // derived without activity info
1574
+ SmallPtrSet<const Instruction *, 4 > unnecessaryInstructionsTmp;
1575
+ for (auto BB : guaranteedUnreachable) {
1576
+ for (auto &I : *BB)
1577
+ unnecessaryInstructionsTmp.insert (&I);
1578
+ }
1579
+ CacheAnalysis CA (TR, gutils->OrigAA , gutils->oldFunc ,
1580
+ PPC.FAM .getResult <ScalarEvolutionAnalysis>(*gutils->oldFunc ),
1581
+ gutils->OrigLI , gutils->OrigDT , TLI,
1582
+ unnecessaryInstructionsTmp, _uncacheable_argsPP,
1583
+ DerivativeMode::ReverseModePrimal, omp);
1584
+ const std::map<CallInst *, const std::map<Argument *, bool >>
1585
+ uncacheable_args_map = CA.compute_uncacheable_args_for_callsites ();
1586
+
1587
+ const std::map<Instruction *, bool > can_modref_map =
1588
+ CA.compute_uncacheable_load_map ();
1589
+ gutils->can_modref_map = &can_modref_map;
1590
+
1582
1591
gutils->forceActiveDetection (TR);
1583
1592
1584
1593
gutils->forceAugmentedReturns (TR, guaranteedUnreachable);
@@ -2780,28 +2789,6 @@ Function *EnzymeLogic::CreatePrimalAndGradient(
2780
2789
++in_arg;
2781
2790
}
2782
2791
}
2783
- // TODO populate with actual unnecessaryInstructions once the dependency
2784
- // cycle with activity analysis is removed
2785
- SmallPtrSet<const Instruction *, 4 > unnecessaryInstructionsTmp;
2786
- for (auto BB : guaranteedUnreachable) {
2787
- for (auto &I : *BB)
2788
- unnecessaryInstructionsTmp.insert (&I);
2789
- }
2790
- CacheAnalysis CA (gutils->OrigAA , gutils->oldFunc ,
2791
- PPC.FAM .getResult <ScalarEvolutionAnalysis>(*gutils->oldFunc ),
2792
- gutils->OrigLI , gutils->OrigDT , TLI,
2793
- unnecessaryInstructionsTmp, _uncacheable_argsPP, mode, omp);
2794
- const std::map<CallInst *, const std::map<Argument *, bool >>
2795
- uncacheable_args_map =
2796
- (augmenteddata) ? augmenteddata->uncacheable_args_map
2797
- : CA.compute_uncacheable_args_for_callsites ();
2798
-
2799
- const std::map<Instruction *, bool > can_modref_map =
2800
- augmenteddata ? augmenteddata->can_modref_map
2801
- : CA.compute_uncacheable_load_map ();
2802
- gutils->can_modref_map = &can_modref_map;
2803
-
2804
- // gutils->forceContexts();
2805
2792
2806
2793
FnTypeInfo typeInfo (gutils->oldFunc );
2807
2794
{
@@ -2829,6 +2816,27 @@ Function *EnzymeLogic::CreatePrimalAndGradient(
2829
2816
TypeResults TR = TA.analyzeFunction (typeInfo);
2830
2817
assert (TR.info .Function == gutils->oldFunc );
2831
2818
2819
+ // TODO populate with actual unnecessaryInstructions once the dependency
2820
+ // cycle with activity analysis is removed
2821
+ SmallPtrSet<const Instruction *, 4 > unnecessaryInstructionsTmp;
2822
+ for (auto BB : guaranteedUnreachable) {
2823
+ for (auto &I : *BB)
2824
+ unnecessaryInstructionsTmp.insert (&I);
2825
+ }
2826
+ CacheAnalysis CA (TR, gutils->OrigAA , gutils->oldFunc ,
2827
+ PPC.FAM .getResult <ScalarEvolutionAnalysis>(*gutils->oldFunc ),
2828
+ gutils->OrigLI , gutils->OrigDT , TLI,
2829
+ unnecessaryInstructionsTmp, _uncacheable_argsPP, mode, omp);
2830
+ const std::map<CallInst *, const std::map<Argument *, bool >>
2831
+ uncacheable_args_map =
2832
+ (augmenteddata) ? augmenteddata->uncacheable_args_map
2833
+ : CA.compute_uncacheable_args_for_callsites ();
2834
+
2835
+ const std::map<Instruction *, bool > can_modref_map =
2836
+ augmenteddata ? augmenteddata->can_modref_map
2837
+ : CA.compute_uncacheable_load_map ();
2838
+ gutils->can_modref_map = &can_modref_map;
2839
+
2832
2840
gutils->forceActiveDetection (TR);
2833
2841
gutils->forceAugmentedReturns (TR, guaranteedUnreachable);
2834
2842
0 commit comments