Skip to content

Commit 9b7c449

Browse files
authored
Handle atomicrmw of inactive values (rust-lang#259)
1 parent 6117bbd commit 9b7c449

File tree

6 files changed

+425
-232
lines changed

6 files changed

+425
-232
lines changed

enzyme/Enzyme/ActivityAnalysis.cpp

Lines changed: 37 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -313,6 +313,38 @@ bool ActivityAnalyzer::isConstantInstruction(TypeResults &TR, Instruction *I) {
313313
auto dt = q[{i}];
314314
if (dt.isIntegral() || dt == BaseType::Anything) {
315315
SeenInteger = true;
316+
if (i == -1)
317+
break;
318+
} else if (dt.isKnown()) {
319+
AllIntegral = false;
320+
break;
321+
}
322+
}
323+
324+
if (AllIntegral && SeenInteger) {
325+
if (EnzymePrintActivity)
326+
llvm::errs() << " constant instruction from TA " << *I << "\n";
327+
InsertConstantInstruction(TR, I);
328+
return true;
329+
}
330+
}
331+
if (auto SI = dyn_cast<AtomicRMWInst>(I)) {
332+
auto StoreSize = SI->getParent()
333+
->getParent()
334+
->getParent()
335+
->getDataLayout()
336+
.getTypeSizeInBits(I->getType()) /
337+
8;
338+
339+
bool AllIntegral = true;
340+
bool SeenInteger = false;
341+
auto q = TR.query(SI->getOperand(0)).Data0();
342+
for (int i = -1; i < (int)StoreSize; ++i) {
343+
auto dt = q[{i}];
344+
if (dt.isIntegral() || dt == BaseType::Anything) {
345+
SeenInteger = true;
346+
if (i == -1)
347+
break;
316348
} else if (dt.isKnown()) {
317349
AllIntegral = false;
318350
break;
@@ -1063,6 +1095,11 @@ bool ActivityAnalyzer::isConstantValue(TypeResults &TR, Value *Val) {
10631095
F->getName() == "__fd_sincos_1") {
10641096
continue;
10651097
}
1098+
for (auto FuncName : KnownInactiveFunctionsStartingWith) {
1099+
if (F->getName().startswith(FuncName)) {
1100+
return true;
1101+
}
1102+
}
10661103

10671104
if (F->getName() == "__cxa_guard_acquire" ||
10681105
F->getName() == "__cxa_guard_release" ||

enzyme/Enzyme/AdjointGenerator.h

Lines changed: 14 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -560,6 +560,20 @@ class AdjointGenerator
560560
eraseIfUnused(LI);
561561
}
562562

563+
/// Visitor hook for `atomicrmw` instructions.
///
/// Only inactive atomic read-modify-write operations are accepted here: both
/// `gutils->isConstantInstruction(&I)` and `gutils->isConstantValue(&I)` must
/// hold. When either query fails, `TR` is dumped and the function held in
/// `gutils->newFunc` plus the offending instruction are written to
/// `llvm::errs()` before the corresponding assertion fires.
///
/// In `DerivativeMode::ReverseModeGradient` the instruction is handed to
/// `eraseIfUnused` with its erase flag set and its check flag cleared.
///
/// \param I the `atomicrmw` instruction being visited.
void visitAtomicRMWInst(llvm::AtomicRMWInst &I) {
  // Evaluate each activity query once and reuse the result for both the
  // diagnostic and the assertions. The value query is only issued when the
  // instruction query succeeded, matching the original short-circuit order.
  const bool constInst = gutils->isConstantInstruction(&I);
  const bool constVal = constInst && gutils->isConstantValue(&I);

  if (!constVal) {
    TR.dump();
    // The label names the function that is actually printed (newFunc).
    llvm::errs() << "newFunc: " << *gutils->newFunc << "\n";
    llvm::errs() << "I: " << I << "\n";
  }
  assert(constInst && "atomicrmw instruction must be inactive");
  assert(constVal && "atomicrmw result value must be inactive");

  if (Mode == DerivativeMode::ReverseModeGradient) {
    eraseIfUnused(I, /*erase*/ true, /*check*/ false);
  }
}
576+
563577
void visitStoreInst(llvm::StoreInst &SI) {
564578
Value *orig_ptr = SI.getPointerOperand();
565579
Value *orig_val = SI.getValueOperand();

enzyme/Enzyme/EnzymeLogic.cpp

Lines changed: 54 additions & 46 deletions
Original file line number | Diff line number | Diff line change
@@ -86,6 +86,7 @@ cl::opt<bool> nonmarkedglobals_inactiveloads(
8686
}
8787

8888
struct CacheAnalysis {
89+
TypeResults &TR;
8990
AAResults &AA;
9091
Function *oldFunc;
9192
ScalarEvolution &SE;
@@ -99,13 +100,14 @@ struct CacheAnalysis {
99100
bool omp;
100101
SmallVector<CallInst *, 0> kmpcCall;
101102
CacheAnalysis(
102-
AAResults &AA, Function *oldFunc, ScalarEvolution &SE, LoopInfo &OrigLI,
103-
DominatorTree &OrigDT, TargetLibraryInfo &TLI,
103+
TypeResults &TR, AAResults &AA, Function *oldFunc, ScalarEvolution &SE,
104+
LoopInfo &OrigLI, DominatorTree &OrigDT, TargetLibraryInfo &TLI,
104105
const SmallPtrSetImpl<const Instruction *> &unnecessaryInstructions,
105106
const std::map<Argument *, bool> &uncacheable_args, DerivativeMode mode,
106107
bool omp)
107-
: AA(AA), oldFunc(oldFunc), SE(SE), OrigLI(OrigLI), OrigDT(OrigDT),
108-
TLI(TLI), unnecessaryInstructions(unnecessaryInstructions),
108+
: TR(TR), AA(AA), oldFunc(oldFunc), SE(SE), OrigLI(OrigLI),
109+
OrigDT(OrigDT), TLI(TLI),
110+
unnecessaryInstructions(unnecessaryInstructions),
109111
uncacheable_args(uncacheable_args), mode(mode), omp(omp) {
110112

111113
for (auto &BB : *oldFunc)
@@ -559,6 +561,9 @@ struct CacheAnalysis {
559561
#endif
560562

561563
bool init_safe = !is_value_mustcache_from_origin(obj);
564+
auto CD = TR.query(obj)[{-1}];
565+
if (CD == BaseType::Integer || CD.isFloat())
566+
init_safe = true;
562567
if (!init_safe && !isa<ConstantInt>(obj) && !isa<Function>(obj)) {
563568
EmitWarning("UncacheableOrigin", callsite_op->getDebugLoc(), oldFunc,
564569
callsite_op->getParent(), "Callsite ", *callsite_op,
@@ -615,6 +620,10 @@ struct CacheAnalysis {
615620
return false;
616621

617622
for (unsigned i = 0; i < args.size(); ++i) {
623+
auto CD = TR.query(args[i])[{-1}];
624+
if (CD == BaseType::Integer || CD.isFloat())
625+
continue;
626+
618627
if (llvm::isModSet(AA.getModRefInfo(
619628
inst2, MemoryLocation::getForArgument(callsite_op, i, TLI)))) {
620629
if (!isa<ConstantInt>(callsite_op->getArgOperand(i)))
@@ -1534,26 +1543,6 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
15341543
++in_arg;
15351544
}
15361545
}
1537-
// TODO actually populate unnecessaryInstructions with what can be
1538-
// derived without activity info
1539-
SmallPtrSet<const Instruction *, 4> unnecessaryInstructionsTmp;
1540-
for (auto BB : guaranteedUnreachable) {
1541-
for (auto &I : *BB)
1542-
unnecessaryInstructionsTmp.insert(&I);
1543-
}
1544-
CacheAnalysis CA(gutils->OrigAA, gutils->oldFunc,
1545-
PPC.FAM.getResult<ScalarEvolutionAnalysis>(*gutils->oldFunc),
1546-
gutils->OrigLI, gutils->OrigDT, TLI,
1547-
unnecessaryInstructionsTmp, _uncacheable_argsPP,
1548-
DerivativeMode::ReverseModePrimal, omp);
1549-
const std::map<CallInst *, const std::map<Argument *, bool>>
1550-
uncacheable_args_map = CA.compute_uncacheable_args_for_callsites();
1551-
1552-
const std::map<Instruction *, bool> can_modref_map =
1553-
CA.compute_uncacheable_load_map();
1554-
gutils->can_modref_map = &can_modref_map;
1555-
1556-
// gutils->forceContexts();
15571546

15581547
FnTypeInfo typeInfo(gutils->oldFunc);
15591548
{
@@ -1579,6 +1568,26 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
15791568
}
15801569
TypeResults TR = TA.analyzeFunction(typeInfo);
15811570
assert(TR.info.Function == gutils->oldFunc);
1571+
1572+
// TODO actually populate unnecessaryInstructions with what can be
1573+
// derived without activity info
1574+
SmallPtrSet<const Instruction *, 4> unnecessaryInstructionsTmp;
1575+
for (auto BB : guaranteedUnreachable) {
1576+
for (auto &I : *BB)
1577+
unnecessaryInstructionsTmp.insert(&I);
1578+
}
1579+
CacheAnalysis CA(TR, gutils->OrigAA, gutils->oldFunc,
1580+
PPC.FAM.getResult<ScalarEvolutionAnalysis>(*gutils->oldFunc),
1581+
gutils->OrigLI, gutils->OrigDT, TLI,
1582+
unnecessaryInstructionsTmp, _uncacheable_argsPP,
1583+
DerivativeMode::ReverseModePrimal, omp);
1584+
const std::map<CallInst *, const std::map<Argument *, bool>>
1585+
uncacheable_args_map = CA.compute_uncacheable_args_for_callsites();
1586+
1587+
const std::map<Instruction *, bool> can_modref_map =
1588+
CA.compute_uncacheable_load_map();
1589+
gutils->can_modref_map = &can_modref_map;
1590+
15821591
gutils->forceActiveDetection(TR);
15831592

15841593
gutils->forceAugmentedReturns(TR, guaranteedUnreachable);
@@ -2780,28 +2789,6 @@ Function *EnzymeLogic::CreatePrimalAndGradient(
27802789
++in_arg;
27812790
}
27822791
}
2783-
// TODO populate with actual unnecessaryInstructions once the dependency
2784-
// cycle with activity analysis is removed
2785-
SmallPtrSet<const Instruction *, 4> unnecessaryInstructionsTmp;
2786-
for (auto BB : guaranteedUnreachable) {
2787-
for (auto &I : *BB)
2788-
unnecessaryInstructionsTmp.insert(&I);
2789-
}
2790-
CacheAnalysis CA(gutils->OrigAA, gutils->oldFunc,
2791-
PPC.FAM.getResult<ScalarEvolutionAnalysis>(*gutils->oldFunc),
2792-
gutils->OrigLI, gutils->OrigDT, TLI,
2793-
unnecessaryInstructionsTmp, _uncacheable_argsPP, mode, omp);
2794-
const std::map<CallInst *, const std::map<Argument *, bool>>
2795-
uncacheable_args_map =
2796-
(augmenteddata) ? augmenteddata->uncacheable_args_map
2797-
: CA.compute_uncacheable_args_for_callsites();
2798-
2799-
const std::map<Instruction *, bool> can_modref_map =
2800-
augmenteddata ? augmenteddata->can_modref_map
2801-
: CA.compute_uncacheable_load_map();
2802-
gutils->can_modref_map = &can_modref_map;
2803-
2804-
// gutils->forceContexts();
28052792

28062793
FnTypeInfo typeInfo(gutils->oldFunc);
28072794
{
@@ -2829,6 +2816,27 @@ Function *EnzymeLogic::CreatePrimalAndGradient(
28292816
TypeResults TR = TA.analyzeFunction(typeInfo);
28302817
assert(TR.info.Function == gutils->oldFunc);
28312818

2819+
// TODO populate with actual unnecessaryInstructions once the dependency
2820+
// cycle with activity analysis is removed
2821+
SmallPtrSet<const Instruction *, 4> unnecessaryInstructionsTmp;
2822+
for (auto BB : guaranteedUnreachable) {
2823+
for (auto &I : *BB)
2824+
unnecessaryInstructionsTmp.insert(&I);
2825+
}
2826+
CacheAnalysis CA(TR, gutils->OrigAA, gutils->oldFunc,
2827+
PPC.FAM.getResult<ScalarEvolutionAnalysis>(*gutils->oldFunc),
2828+
gutils->OrigLI, gutils->OrigDT, TLI,
2829+
unnecessaryInstructionsTmp, _uncacheable_argsPP, mode, omp);
2830+
const std::map<CallInst *, const std::map<Argument *, bool>>
2831+
uncacheable_args_map =
2832+
(augmenteddata) ? augmenteddata->uncacheable_args_map
2833+
: CA.compute_uncacheable_args_for_callsites();
2834+
2835+
const std::map<Instruction *, bool> can_modref_map =
2836+
augmenteddata ? augmenteddata->can_modref_map
2837+
: CA.compute_uncacheable_load_map();
2838+
gutils->can_modref_map = &can_modref_map;
2839+
28322840
gutils->forceActiveDetection(TR);
28332841
gutils->forceAugmentedReturns(TR, guaranteedUnreachable);
28342842

enzyme/Enzyme/GradientUtils.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2593,7 +2593,7 @@ Value *GradientUtils::invertPointerM(Value *oval, IRBuilder<> &BuilderM) {
25932593
Type *type = cast<PointerType>(arg->getType())->getElementType();
25942594
auto shadow = new GlobalVariable(
25952595
*arg->getParent(), type, arg->isConstant(), arg->getLinkage(),
2596-
ConstantAggregateZero::get(type), arg->getName() + "_shadow", arg,
2596+
Constant::getNullValue(type), arg->getName() + "_shadow", arg,
25972597
arg->getThreadLocalMode(), arg->getType()->getAddressSpace(),
25982598
arg->isExternallyInitialized());
25992599
arg->setMetadata("enzyme_shadow",

0 commit comments

Comments
 (0)