Skip to content

Commit f92c0eb

Browse files
authored
Simplify differential argument code (rust-lang#803)
* Simplify diffe ABI * Correct writetomem * Fix bug * Fix token ty bug
1 parent 04caf58 commit f92c0eb

File tree

5 files changed

+204
-166
lines changed

5 files changed

+204
-166
lines changed

enzyme/Enzyme/AdjointGenerator.h

Lines changed: 26 additions & 129 deletions
Original file line numberDiff line numberDiff line change
@@ -4655,32 +4655,16 @@ class AdjointGenerator
46554655
args.push_back(lookup(argi, Builder2));
46564656
}
46574657

4658-
if (gutils->isConstantValue(call.getArgOperand(i)) && !foreignFunction) {
4659-
argsInverted.push_back(DIFFE_TYPE::CONSTANT);
4658+
auto argTy = gutils->getDiffeType(call.getArgOperand(i), foreignFunction);
4659+
argsInverted.push_back(argTy);
4660+
4661+
if (argTy == DIFFE_TYPE::CONSTANT) {
46604662
continue;
46614663
}
46624664

46634665
auto argType = argi->getType();
46644666

4665-
if (!argType->isFPOrFPVectorTy() &&
4666-
TR.query(call.getArgOperand(i)).Inner0().isPossiblePointer()) {
4667-
DIFFE_TYPE ty = DIFFE_TYPE::DUP_ARG;
4668-
if (argType->isPointerTy()) {
4669-
#if LLVM_VERSION_MAJOR >= 12
4670-
auto at = getUnderlyingObject(call.getArgOperand(i), 100);
4671-
#else
4672-
auto at = GetUnderlyingObject(
4673-
call.getArgOperand(i),
4674-
gutils->oldFunc->getParent()->getDataLayout(), 100);
4675-
#endif
4676-
if (auto arg = dyn_cast<Argument>(at)) {
4677-
if (constant_args[arg->getArgNo()] == DIFFE_TYPE::DUP_NONEED) {
4678-
ty = DIFFE_TYPE::DUP_NONEED;
4679-
}
4680-
}
4681-
}
4682-
argsInverted.push_back(ty);
4683-
4667+
if (argTy == DIFFE_TYPE::DUP_ARG || argTy == DIFFE_TYPE::DUP_NONEED) {
46844668
if (Mode != DerivativeMode::ReverseModePrimal) {
46854669
IRBuilder<> Builder2(call.getParent());
46864670
getReverseBuilder(Builder2);
@@ -4699,7 +4683,6 @@ class AdjointGenerator
46994683
assert(TR.query(call.getArgOperand(i)).Inner0().isFloat());
47004684
OutTypes.push_back(call.getArgOperand(i));
47014685
OutFPTypes.push_back(argType);
4702-
argsInverted.push_back(DIFFE_TYPE::OUT_DIFF);
47034686
assert(whatType(argType, Mode) == DIFFE_TYPE::OUT_DIFF ||
47044687
whatType(argType, Mode) == DIFFE_TYPE::CONSTANT);
47054688
}
@@ -8484,37 +8467,10 @@ class AdjointGenerator
84848467
funcName = called->getName();
84858468
}
84868469

8487-
bool subretused = unnecessaryValues.find(orig) == unnecessaryValues.end();
8488-
if (gutils->knownRecomputeHeuristic.find(orig) !=
8489-
gutils->knownRecomputeHeuristic.end()) {
8490-
if (!gutils->knownRecomputeHeuristic[orig]) {
8491-
subretused = true;
8492-
}
8493-
}
8470+
bool subretused = false;
84948471
bool shadowReturnUsed = false;
8495-
8496-
DIFFE_TYPE subretType;
8497-
if (gutils->isConstantValue(orig)) {
8498-
subretType = DIFFE_TYPE::CONSTANT;
8499-
} else {
8500-
if (Mode == DerivativeMode::ForwardMode ||
8501-
Mode == DerivativeMode::ForwardModeSplit) {
8502-
subretType = DIFFE_TYPE::DUP_ARG;
8503-
shadowReturnUsed = true;
8504-
} else {
8505-
if (!orig->getType()->isFPOrFPVectorTy() &&
8506-
TR.query(orig).Inner0().isPossiblePointer()) {
8507-
if (is_value_needed_in_reverse<ValueType::Shadow>(gutils, orig, Mode,
8508-
oldUnreachable)) {
8509-
subretType = DIFFE_TYPE::DUP_ARG;
8510-
shadowReturnUsed = true;
8511-
} else
8512-
subretType = DIFFE_TYPE::CONSTANT;
8513-
} else {
8514-
subretType = DIFFE_TYPE::OUT_DIFF;
8515-
}
8516-
}
8517-
}
8472+
DIFFE_TYPE subretType =
8473+
gutils->getReturnDiffeType(orig, &subretused, &shadowReturnUsed);
85188474

85198475
if (Mode == DerivativeMode::ForwardMode) {
85208476
auto found = customFwdCallHandlers.find(funcName.str());
@@ -8576,22 +8532,9 @@ class AdjointGenerator
85768532
getReverseBuilder(Builder2);
85778533

85788534
Value *invertedReturn = nullptr;
8579-
bool hasNonReturnUse = false;
85808535
auto ifound = gutils->invertedPointers.find(orig);
85818536
if (ifound != gutils->invertedPointers.end()) {
8582-
//! We only need the shadow pointer for non-forward Mode if it is used
8583-
//! in a non return setting
8584-
if (!gutils->isConstantValue(orig)) {
8585-
if (!orig->getType()->isFPOrFPVectorTy() &&
8586-
TR.query(orig).Inner0().isPossiblePointer()) {
8587-
if (is_value_needed_in_reverse<ValueType::Shadow>(
8588-
gutils, orig, DerivativeMode::ReverseModePrimal,
8589-
oldUnreachable)) {
8590-
hasNonReturnUse = true;
8591-
}
8592-
}
8593-
}
8594-
if (hasNonReturnUse)
8537+
if (shadowReturnUsed)
85958538
invertedReturn = cast<PHINode>(&*ifound->second);
85968539
}
85978540

@@ -8627,7 +8570,7 @@ class AdjointGenerator
86278570

86288571
if (ifound != gutils->invertedPointers.end()) {
86298572
auto placeholder = cast<PHINode>(&*ifound->second);
8630-
if (!hasNonReturnUse) {
8573+
if (!shadowReturnUsed) {
86318574
gutils->invertedPointers.erase(ifound);
86328575
gutils->erase(placeholder);
86338576
} else {
@@ -8687,8 +8630,7 @@ class AdjointGenerator
86878630
gutils->replaceAWithB(newCall, normalReturn);
86888631
BuilderZ.SetInsertPoint(newCall->getNextNode());
86898632
gutils->erase(newCall);
8690-
} else if ((!orig->mayWriteToMemory() ||
8691-
Mode == DerivativeMode::ReverseModeGradient) &&
8633+
} else if (Mode == DerivativeMode::ReverseModeGradient &&
86928634
!orig->getType()->isTokenTy())
86938635
eraseIfUnused(*orig, /*erase*/ true, /*check*/ false);
86948636
}
@@ -11244,47 +11186,18 @@ class AdjointGenerator
1124411186
#endif
1124511187
args.push_back(argi);
1124611188

11247-
if (gutils->isConstantValue(orig->getArgOperand(i)) &&
11248-
!foreignFunction) {
11249-
argsInverted.push_back(DIFFE_TYPE::CONSTANT);
11189+
auto argTy =
11190+
gutils->getDiffeType(orig->getArgOperand(i), foreignFunction);
11191+
argsInverted.push_back(argTy);
11192+
11193+
if (argTy == DIFFE_TYPE::CONSTANT) {
1125011194
continue;
1125111195
}
1125211196

11253-
auto argType = argi->getType();
11254-
11255-
if (!argType->isFPOrFPVectorTy() &&
11256-
(TR.query(orig->getArgOperand(i)).Inner0().isPossiblePointer() ||
11257-
foreignFunction)) {
11258-
DIFFE_TYPE ty = DIFFE_TYPE::DUP_ARG;
11259-
if (argType->isPointerTy()) {
11260-
#if LLVM_VERSION_MAJOR >= 12
11261-
auto at = getUnderlyingObject(orig->getArgOperand(i), 100);
11262-
#else
11263-
auto at = GetUnderlyingObject(
11264-
orig->getArgOperand(i),
11265-
gutils->oldFunc->getParent()->getDataLayout(), 100);
11266-
#endif
11267-
if (auto arg = dyn_cast<Argument>(at)) {
11268-
if (constant_args[arg->getArgNo()] == DIFFE_TYPE::DUP_NONEED) {
11269-
ty = DIFFE_TYPE::DUP_NONEED;
11270-
}
11271-
}
11272-
}
11273-
args.push_back(
11274-
gutils->invertPointerM(orig->getArgOperand(i), Builder2));
11275-
argsInverted.push_back(ty);
11197+
assert(argTy == DIFFE_TYPE::DUP_ARG || argTy == DIFFE_TYPE::DUP_NONEED);
1127611198

11277-
// Note sometimes whattype mistakenly says something should be
11278-
// constant [because composed of integer pointers alone]
11279-
assert(whatType(argType, Mode) == DIFFE_TYPE::DUP_ARG ||
11280-
whatType(argType, Mode) == DIFFE_TYPE::CONSTANT);
11281-
} else {
11282-
if (foreignFunction)
11283-
assert(!argType->isIntOrIntVectorTy());
11284-
11285-
args.push_back(diffe(orig->getArgOperand(i), Builder2));
11286-
argsInverted.push_back(DIFFE_TYPE::DUP_ARG);
11287-
}
11199+
args.push_back(
11200+
gutils->invertPointerM(orig->getArgOperand(i), Builder2));
1128811201
}
1128911202

1129011203
Optional<int> tapeIdx;
@@ -11478,33 +11391,18 @@ class AdjointGenerator
1147811391
args.push_back(lookup(argi, Builder2));
1147911392
}
1148011393

11481-
if (gutils->isConstantValue(orig->getArgOperand(i)) && !foreignFunction) {
11482-
argsInverted.push_back(DIFFE_TYPE::CONSTANT);
11394+
auto argTy =
11395+
gutils->getDiffeType(orig->getArgOperand(i), foreignFunction);
11396+
11397+
argsInverted.push_back(argTy);
11398+
11399+
if (argTy == DIFFE_TYPE::CONSTANT) {
1148311400
continue;
1148411401
}
1148511402

1148611403
auto argType = argi->getType();
1148711404

11488-
if (!argType->isFPOrFPVectorTy() &&
11489-
(TR.query(orig->getArgOperand(i)).Inner0().isPossiblePointer() ||
11490-
foreignFunction)) {
11491-
DIFFE_TYPE ty = DIFFE_TYPE::DUP_ARG;
11492-
if (argType->isPointerTy()) {
11493-
#if LLVM_VERSION_MAJOR >= 12
11494-
auto at = getUnderlyingObject(orig->getArgOperand(i), 100);
11495-
#else
11496-
auto at = GetUnderlyingObject(
11497-
orig->getArgOperand(i),
11498-
gutils->oldFunc->getParent()->getDataLayout(), 100);
11499-
#endif
11500-
if (auto arg = dyn_cast<Argument>(at)) {
11501-
if (constant_args[arg->getArgNo()] == DIFFE_TYPE::DUP_NONEED) {
11502-
ty = DIFFE_TYPE::DUP_NONEED;
11503-
}
11504-
}
11505-
}
11506-
argsInverted.push_back(ty);
11507-
11405+
if (argTy == DIFFE_TYPE::DUP_ARG || argTy == DIFFE_TYPE::DUP_NONEED) {
1150811406
if (Mode != DerivativeMode::ReverseModePrimal) {
1150911407
IRBuilder<> Builder2(call.getParent());
1151011408
getReverseBuilder(Builder2);
@@ -11522,7 +11420,6 @@ class AdjointGenerator
1152211420
} else {
1152311421
if (foreignFunction)
1152411422
assert(!argType->isIntOrIntVectorTy());
11525-
argsInverted.push_back(DIFFE_TYPE::OUT_DIFF);
1152611423
assert(whatType(argType, Mode) == DIFFE_TYPE::OUT_DIFF ||
1152711424
whatType(argType, Mode) == DIFFE_TYPE::CONSTANT);
1152811425
}

enzyme/Enzyme/CApi.cpp

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,27 @@ CDerivativeMode EnzymeGradientUtilsGetMode(GradientUtils *gutils) {
318318
return (CDerivativeMode)gutils->mode;
319319
}
320320

321+
CDIFFE_TYPE
322+
EnzymeGradientUtilsGetDiffeType(GradientUtils *G, LLVMValueRef oval,
323+
uint8_t foreignFunction) {
324+
return (CDIFFE_TYPE)(G->getDiffeType(unwrap(oval), foreignFunction != 0));
325+
}
326+
327+
CDIFFE_TYPE
328+
EnzymeGradientUtilsGetReturnDiffeType(GradientUtils *G, LLVMValueRef oval,
329+
uint8_t *needsPrimal,
330+
uint8_t *needsShadow) {
331+
bool needsPrimalB;
332+
bool needsShadowB;
333+
auto res = (CDIFFE_TYPE)(G->getReturnDiffeType(cast<CallInst>(unwrap(oval)),
334+
&needsPrimalB, &needsShadowB));
335+
if (needsPrimal)
336+
*needsPrimal = needsPrimalB;
337+
if (needsShadow)
338+
*needsShadow = needsShadowB;
339+
return res;
340+
}
341+
321342
void EnzymeGradientUtilsSetDebugLocFromOriginal(GradientUtils *gutils,
322343
LLVMValueRef val,
323344
LLVMValueRef orig) {
@@ -367,6 +388,31 @@ LLVMBasicBlockRef EnzymeGradientUtilsAllocationBlock(GradientUtils *gutils) {
367388
return wrap(gutils->inversionAllocs);
368389
}
369390

391+
void EnzymeGradientUtilsGetUncacheableArgs(GradientUtils *gutils,
392+
LLVMValueRef orig, uint8_t *data,
393+
uint64_t size) {
394+
if (gutils->mode == DerivativeMode::ForwardMode)
395+
return;
396+
397+
CallInst *call = cast<CallInst>(unwrap(orig));
398+
399+
auto found = gutils->uncacheable_args_map_ptr->find(call);
400+
assert(found != gutils->uncacheable_args_map_ptr->end());
401+
402+
const std::map<Argument *, bool> &uncacheable_args = found->second;
403+
404+
auto Fn = getFunctionFromCall(call);
405+
assert(Fn);
406+
size_t cur = 0;
407+
for (auto &arg : Fn->args()) {
408+
assert(cur < size);
409+
auto found2 = uncacheable_args.find(&arg);
410+
assert(found2 != uncacheable_args.end());
411+
data[cur] = found2->second;
412+
cur++;
413+
}
414+
}
415+
370416
CTypeTreeRef EnzymeGradientUtilsAllocAndGetTypeTree(GradientUtils *gutils,
371417
LLVMValueRef val) {
372418
auto v = unwrap(val);

0 commit comments

Comments
 (0)