Skip to content

Commit 5dac0d3

Browse files
committed
Use custom reverse in combined mode
1 parent 788cc8e commit 5dac0d3

File tree

2 files changed

+144
-10
lines changed

2 files changed

+144
-10
lines changed

enzyme/Enzyme/EnzymeLogic.cpp

Lines changed: 85 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2689,8 +2689,91 @@ Function *EnzymeLogic::CreatePrimalAndGradient(
26892689
}
26902690
}
26912691

2692-
if (!hasconstant && mode != DerivativeMode::ReverseModeCombined &&
2693-
!returnValue && hasMetadata(todiff, "enzyme_gradient")) {
2692+
if (!hasconstant && !returnValue && hasMetadata(todiff, "enzyme_gradient")) {
2693+
2694+
DIFFE_TYPE subretType = todiff->getReturnType()->isFPOrFPVectorTy()
2695+
? DIFFE_TYPE::OUT_DIFF
2696+
: DIFFE_TYPE::DUP_ARG;
2697+
if (todiff->getReturnType()->isVoidTy() ||
2698+
todiff->getReturnType()->isEmptyTy())
2699+
subretType = DIFFE_TYPE::CONSTANT;
2700+
assert(subretType == retType);
2701+
2702+
auto res = getDefaultFunctionTypeForGradient(todiff->getFunctionType(),
2703+
/*retType*/ retType);
2704+
2705+
if (mode == DerivativeMode::ReverseModeCombined) {
2706+
2707+
FunctionType *FTy =
2708+
FunctionType::get(StructType::get(todiff->getContext(), {res.second}),
2709+
res.first, todiff->getFunctionType()->isVarArg());
2710+
Function *NewF = Function::Create(
2711+
FTy, Function::LinkageTypes::InternalLinkage,
2712+
"fixgradient_" + todiff->getName(), todiff->getParent());
2713+
2714+
BasicBlock *BB = BasicBlock::Create(NewF->getContext(), "entry", NewF);
2715+
IRBuilder<> bb(BB);
2716+
2717+
auto &aug = CreateAugmentedPrimal(
2718+
todiff, retType, constant_args, TLI, TA, returnUsed, oldTypeInfo_,
2719+
_uncacheable_args, /*forceAnonymousTape*/ false, AtomicAdd, PostOpt,
2720+
omp);
2721+
2722+
SmallVector<Value *, 4> fwdargs;
2723+
for (auto &a : NewF->args())
2724+
fwdargs.push_back(&a);
2725+
if (retType == DIFFE_TYPE::OUT_DIFF)
2726+
fwdargs.pop_back();
2727+
auto cal = bb.CreateCall(aug.fn, fwdargs);
2728+
cal->setCallingConv(aug.fn->getCallingConv());
2729+
2730+
llvm::Value *tape = nullptr;
2731+
2732+
if (aug.returns.find(AugmentedStruct::Tape) != aug.returns.end()) {
2733+
auto tapeIdx = aug.returns.find(AugmentedStruct::Tape)->second;
2734+
tape = (tapeIdx == -1) ? cal : bb.CreateExtractValue(cal, tapeIdx);
2735+
}
2736+
2737+
if (aug.tapeType) {
2738+
assert(tape);
2739+
auto tapep =
2740+
bb.CreatePointerCast(tape, PointerType::getUnqual(aug.tapeType));
2741+
auto truetape = bb.CreateLoad(tapep, "tapeld");
2742+
truetape->setMetadata("enzyme_mustcache",
2743+
MDNode::get(truetape->getContext(), {}));
2744+
2745+
CallInst *ci = cast<CallInst>(CallInst::CreateFree(tape, BB));
2746+
bb.Insert(ci);
2747+
ci->addAttribute(AttributeList::FirstArgIndex, Attribute::NonNull);
2748+
tape = truetape;
2749+
}
2750+
2751+
auto revfn = CreatePrimalAndGradient(
2752+
todiff, retType, constant_args, TLI, TA,
2753+
/*returnUsed*/ false, /*dretPtr*/ false,
2754+
/*mode*/ DerivativeMode::ReverseModeGradient,
2755+
/*additionalArg*/ tape ? tape->getType() : nullptr, oldTypeInfo_,
2756+
_uncacheable_args, &aug, AtomicAdd, PostOpt, omp);
2757+
2758+
SmallVector<Value *, 4> revargs;
2759+
for (auto &a : NewF->args()) {
2760+
revargs.push_back(&a);
2761+
}
2762+
if (tape) {
2763+
revargs.push_back(tape);
2764+
}
2765+
auto revcal = bb.CreateCall(revfn, revargs);
2766+
revcal->setCallingConv(revfn->getCallingConv());
2767+
if (NewF->getReturnType()->isEmptyTy())
2768+
bb.CreateRet(UndefValue::get(NewF->getReturnType()));
2769+
else
2770+
bb.CreateRet(revcal);
2771+
assert(!returnUsed);
2772+
2773+
return insert_or_assign2<ReverseCacheKey, Function *>(
2774+
ReverseCachedFunctions, tup, NewF)
2775+
->second;
2776+
}
26942777

26952778
auto md = todiff->getMetadata("enzyme_gradient");
26962779
if (!isa<MDTuple>(md)) {
@@ -2704,14 +2787,6 @@ Function *EnzymeLogic::CreatePrimalAndGradient(
27042787
auto gvemd = cast<ConstantAsMetadata>(md2->getOperand(0));
27052788
auto foundcalled = cast<Function>(gvemd->getValue());
27062789

2707-
DIFFE_TYPE subretType = todiff->getReturnType()->isFPOrFPVectorTy()
2708-
? DIFFE_TYPE::OUT_DIFF
2709-
: DIFFE_TYPE::DUP_ARG;
2710-
if (todiff->getReturnType()->isVoidTy() ||
2711-
todiff->getReturnType()->isEmptyTy())
2712-
subretType = DIFFE_TYPE::CONSTANT;
2713-
auto res = getDefaultFunctionTypeForGradient(todiff->getFunctionType(),
2714-
/*retType*/ subretType);
27152790
assert(augmenteddata);
27162791
if (foundcalled->arg_size() == res.first.size() + 1 /*tape*/) {
27172792
auto lastarg = foundcalled->arg_end();
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
// RUN: %clang -std=c11 -O0 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -S | %lli -
2+
// RUN: %clang -std=c11 -O1 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -S | %lli -
3+
// RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -S | %lli -
4+
// RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -S | %lli -
5+
// RUN: %clang -std=c11 -O0 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -enzyme-inline=1 -S | %lli -
6+
// RUN: %clang -std=c11 -O1 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -enzyme-inline=1 -S | %lli -
7+
// RUN: %clang -std=c11 -O2 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -enzyme-inline=1 -S | %lli -
8+
// RUN: %clang -std=c11 -O3 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -enzyme-inline=1 -S | %lli -
9+
10+
#include "test_utils.h"
11+
12+
double __enzyme_autodiff(void*, double);
13+
14+
// Primal kernel under test: stores the square of *src into *dest.
// Marked noinline -- presumably so the call to square_ survives optimization
// and can be matched against the registered custom gradient; TODO confirm.
__attribute__((noinline))
15+
void square_(const double* src, double* dest) {
16+
*dest = *src * *src;
17+
}
18+
19+
int augment = 0;
20+
// Hand-written augmented-forward stand-in for square_.
// Increments the global `augment` counter so main() can check how many times
// it ran, writes fixed sentinel values (7.0 / 11.0) rather than a real result,
// and returns NULL. The void* return is presumably the tape later passed to
// gradient_square_ -- verify against the Enzyme custom-gradient docs.
void* augment_square_(const double* src, const double *d_src, double* dest, double* d_dest) {
21+
augment++;
22+
// intentionally incorrect for debugging
23+
*dest = 7.0;
24+
*d_dest = 11.0;
25+
return NULL;
26+
}
27+
28+
int gradient = 0;
29+
// Hand-written reverse-pass stand-in for square_.
// Increments the global `gradient` counter and overwrites *d_src with the
// sentinel 13.0; src, dest, d_dest and tape are not read. main() expects this
// sentinel to come back as the derivative, which shows this function (not a
// derivative synthesized from square_) was the one executed.
void gradient_square_(const double* src, double *d_src, const double* dest, const double* d_dest, void* tape) {
30+
gradient++;
31+
// intentionally incorrect for debugging
32+
*d_src = 13.0;
33+
}
34+
35+
// Table of three function pointers: the primal (square_), its augmented
// forward pass, and its reverse pass, in that order.
// NOTE(review): the `__enzyme_register_gradient` name prefix is presumably what
// makes Enzyme pick this table up as a custom-gradient registration -- confirm
// against the Enzyme documentation.
void* __enzyme_register_gradient_square[] = {
36+
(void*)square_,
37+
(void*)augment_square_,
38+
(void*)gradient_square_,
39+
};
40+
41+
42+
// Wrapper that is actually differentiated: calls square_ on the address of its
// argument and returns the value square_ wrote into the local `y`.
double square(double x) {
43+
double y;
44+
square_(&x, &y);
45+
return y;
46+
}
47+
48+
// Returns the result of __enzyme_autodiff applied to `square` at x.
// __enzyme_autodiff is only declared in this file; presumably the Enzyme pass
// replaces the call with the generated derivative -- TODO confirm.
double dsquare(double x) {
49+
return __enzyme_autodiff((void*)square, x);
50+
}
51+
52+
53+
// Test driver: differentiates square at 3.0, prints the result and both call
// counters, then checks that the result is the 13.0 sentinel written by
// gradient_square_ and that `augment` and `gradient` are each 1.
// APPROX_EQ comes from test_utils.h (not shown); presumably it fails the test
// when the values differ by more than the given tolerance.
int main() {
54+
double res = dsquare(3.0);
55+
printf("res=%f augment=%d gradient=%d\n", res, augment, gradient);
56+
APPROX_EQ(res, 13.0, 1e-10);
57+
APPROX_EQ(augment, 1.0, 1e-10);
58+
APPROX_EQ(gradient, 1.0, 1e-10);
59+
}

0 commit comments

Comments
 (0)