Skip to content

Commit dbb6552

Browse files
committed
Sugar for virtual functions
1 parent 8f0a66a commit dbb6552

File tree

4 files changed

+185
-75
lines changed

4 files changed

+185
-75
lines changed

enzyme/Enzyme/Enzyme.cpp

Lines changed: 51 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -725,6 +725,7 @@ class Enzyme : public ModulePass {
725725
Fn->getName() == "__enzyme_double" ||
726726
Fn->getName() == "__enzyme_integer" ||
727727
Fn->getName() == "__enzyme_pointer" ||
728+
Fn->getName().contains("__enzyme_virtualreverse") ||
728729
Fn->getName().contains("__enzyme_call_inactive") ||
729730
Fn->getName().contains("__enzyme_autodiff") ||
730731
Fn->getName().contains("__enzyme_fwddiff") ||
@@ -763,6 +764,7 @@ class Enzyme : public ModulePass {
763764
}
764765

765766
std::map<CallInst *, DerivativeMode> toLower;
767+
std::map<CallInst *, DerivativeMode> toVirtual;
766768
std::set<CallInst *> InactiveCalls;
767769
std::set<CallInst *> IterCalls;
768770
retry:;
@@ -824,6 +826,10 @@ class Enzyme : public ModulePass {
824826
}
825827
}
826828
}
829+
if (Fn->getName() == "__enzyme_virtualreverse") {
830+
Fn->addFnAttr(Attribute::ReadNone);
831+
CI->addAttribute(AttributeList::FunctionIndex, Attribute::ReadNone);
832+
}
827833
if (Fn->getName() == "__enzyme_iter") {
828834
Fn->addFnAttr(Attribute::ReadNone);
829835
CI->addAttribute(AttributeList::FunctionIndex, Attribute::ReadNone);
@@ -960,6 +966,7 @@ class Enzyme : public ModulePass {
960966
}
961967

962968
bool enableEnzyme = false;
969+
bool virtualCall = false;
963970
DerivativeMode mode;
964971
if (Fn->getName().contains("__enzyme_autodiff")) {
965972
enableEnzyme = true;
@@ -973,10 +980,13 @@ class Enzyme : public ModulePass {
973980
} else if (Fn->getName().contains("__enzyme_reverse")) {
974981
enableEnzyme = true;
975982
mode = DerivativeMode::ReverseModeGradient;
983+
} else if (Fn->getName().contains("__enzyme_virtualreverse")) {
984+
enableEnzyme = true;
985+
virtualCall = true;
986+
mode = DerivativeMode::ReverseModeCombined;
976987
}
977988

978989
if (enableEnzyme) {
979-
toLower[CI] = mode;
980990

981991
Value *fn = CI->getArgOperand(0);
982992
while (auto ci = dyn_cast<CastInst>(fn)) {
@@ -1013,9 +1023,16 @@ class Enzyme : public ModulePass {
10131023
}
10141024
goto retry;
10151025
}
1016-
if (auto dc = dyn_cast<Function>(fn))
1026+
1027+
if (virtualCall)
1028+
toVirtual[CI] = mode;
1029+
else
1030+
toLower[CI] = mode;
1031+
1032+
if (auto dc = dyn_cast<Function>(fn)) {
10171033
Changed |=
10181034
lowerEnzymeCalls(*dc, /*PostOpt*/ true, successful, done);
1035+
}
10191036
}
10201037
}
10211038
}
@@ -1048,6 +1065,36 @@ class Enzyme : public ModulePass {
10481065
break;
10491066
}
10501067

1068+
for (auto pair : toVirtual) {
1069+
auto CI = pair.first;
1070+
Value *fn = CI->getArgOperand(0);
1071+
while (auto ci = dyn_cast<CastInst>(fn)) {
1072+
fn = ci->getOperand(0);
1073+
}
1074+
while (auto ci = dyn_cast<BlockAddress>(fn)) {
1075+
fn = ci->getFunction();
1076+
}
1077+
while (auto ci = dyn_cast<ConstantExpr>(fn)) {
1078+
fn = ci->getOperand(0);
1079+
}
1080+
auto F = cast<Function>(fn);
1081+
TypeAnalysis TA(TLI);
1082+
1083+
auto Arch =
1084+
llvm::Triple(
1085+
CI->getParent()->getParent()->getParent()->getTargetTriple())
1086+
.getArch();
1087+
1088+
bool AtomicAdd = Arch == Triple::nvptx || Arch == Triple::nvptx64 ||
1089+
Arch == Triple::amdgcn;
1090+
1091+
auto val = GradientUtils::GetOrCreateShadowFunction(Logic, TLI, TA, F,
1092+
AtomicAdd, PostOpt);
1093+
CI->replaceAllUsesWith(ConstantExpr::getPointerCast(val, CI->getType()));
1094+
CI->eraseFromParent();
1095+
Changed = true;
1096+
}
1097+
10511098
if (Changed) {
10521099
// TODO consider enabling when attributor does not delete
10531100
// dead internal functions, which invalidates Enzyme's cache
@@ -1199,7 +1246,8 @@ class Enzyme : public ModulePass {
11991246
for (Function &F : M) {
12001247
if (F.getName() == "__enzyme_float" || F.getName() == "__enzyme_double" ||
12011248
F.getName() == "__enzyme_integer" ||
1202-
F.getName() == "__enzyme_pointer") {
1249+
F.getName() == "__enzyme_pointer" ||
1250+
F.getName().contains("__enzyme_virtualreverse")) {
12031251
F.addFnAttr(Attribute::ReadNone);
12041252
for (auto &arg : F.args()) {
12051253
if (arg.getType()->isPointerTy()) {

enzyme/Enzyme/GradientUtils.cpp

Lines changed: 79 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -2487,6 +2487,84 @@ DiffeGradientUtils *DiffeGradientUtils::CreateFromClone(
24872487
return res;
24882488
}
24892489

2490+
Constant *GradientUtils::GetOrCreateShadowFunction(EnzymeLogic &Logic,
2491+
TargetLibraryInfo &TLI,
2492+
TypeAnalysis &TA,
2493+
Function *fn, bool AtomicAdd,
2494+
bool PostOpt) {
2495+
//! Todo allow tape propagation
2496+
// Note that specifically this should _not_ be called with topLevel=true
2497+
// (since it may not be valid to always assume we can recompute the
2498+
// augmented primal) However, in the absence of a way to pass tape data
2499+
// from an indirect augmented (and also since we dont presently allow
2500+
// indirect augmented calls), topLevel MUST be true otherwise subcalls will
2501+
// not be able to lookup the augmenteddata/subdata (triggering an assertion
2502+
// failure, among much worse)
2503+
std::map<Argument *, bool> uncacheable_args;
2504+
FnTypeInfo type_args(fn);
2505+
2506+
// conservatively assume that we can only cache existing floating types
2507+
// (i.e. that all args are uncacheable)
2508+
std::vector<DIFFE_TYPE> types;
2509+
for (auto &a : fn->args()) {
2510+
uncacheable_args[&a] = !a.getType()->isFPOrFPVectorTy();
2511+
type_args.Arguments.insert(std::pair<Argument *, TypeTree>(&a, {}));
2512+
type_args.KnownValues.insert(
2513+
std::pair<Argument *, std::set<int64_t>>(&a, {}));
2514+
DIFFE_TYPE typ;
2515+
if (a.getType()->isFPOrFPVectorTy()) {
2516+
typ = DIFFE_TYPE::OUT_DIFF;
2517+
} else if (a.getType()->isIntegerTy() &&
2518+
cast<IntegerType>(a.getType())->getBitWidth() < 16) {
2519+
typ = DIFFE_TYPE::CONSTANT;
2520+
} else if (a.getType()->isVoidTy() || a.getType()->isEmptyTy()) {
2521+
typ = DIFFE_TYPE::CONSTANT;
2522+
} else {
2523+
typ = DIFFE_TYPE::DUP_ARG;
2524+
}
2525+
types.push_back(typ);
2526+
}
2527+
2528+
DIFFE_TYPE retType = fn->getReturnType()->isFPOrFPVectorTy()
2529+
? DIFFE_TYPE::OUT_DIFF
2530+
: DIFFE_TYPE::DUP_ARG;
2531+
if (fn->getReturnType()->isVoidTy() || fn->getReturnType()->isEmptyTy() ||
2532+
(fn->getReturnType()->isIntegerTy() &&
2533+
cast<IntegerType>(fn->getReturnType())->getBitWidth() < 16))
2534+
retType = DIFFE_TYPE::CONSTANT;
2535+
2536+
// TODO re atomic add consider forcing it to be atomic always as fallback if
2537+
// used in a parallel context
2538+
auto &augdata = Logic.CreateAugmentedPrimal(
2539+
fn, retType, /*constant_args*/ types, TLI, TA,
2540+
/*returnUsed*/ !fn->getReturnType()->isEmptyTy() &&
2541+
!fn->getReturnType()->isVoidTy(),
2542+
type_args, uncacheable_args, /*forceAnonymousTape*/ true, AtomicAdd,
2543+
PostOpt);
2544+
Constant *newf = Logic.CreatePrimalAndGradient(
2545+
fn, retType, /*constant_args*/ types, TLI, TA,
2546+
/*returnValue*/ false, /*dretPtr*/ false,
2547+
DerivativeMode::ReverseModeGradient,
2548+
/*additionalArg*/ Type::getInt8PtrTy(fn->getContext()), type_args,
2549+
uncacheable_args,
2550+
/*map*/ &augdata, AtomicAdd);
2551+
if (!newf)
2552+
newf = UndefValue::get(fn->getType());
2553+
auto cdata = ConstantStruct::get(
2554+
StructType::get(newf->getContext(),
2555+
{augdata.fn->getType(), newf->getType()}),
2556+
{augdata.fn, newf});
2557+
std::string globalname = ("_enzyme_" + fn->getName() + "'").str();
2558+
auto GV = fn->getParent()->getNamedValue(globalname);
2559+
2560+
if (GV == nullptr) {
2561+
GV = new GlobalVariable(*fn->getParent(), cdata->getType(), true,
2562+
GlobalValue::LinkageTypes::InternalLinkage, cdata,
2563+
globalname);
2564+
}
2565+
return ConstantExpr::getPointerCast(GV, fn->getType());
2566+
}
2567+
24902568
Value *GradientUtils::invertPointerM(Value *const oval, IRBuilder<> &BuilderM,
24912569
bool nullShadow) {
24922570
assert(oval);
@@ -2768,78 +2846,7 @@ Value *GradientUtils::invertPointerM(Value *const oval, IRBuilder<> &BuilderM,
27682846
std::make_pair((const Value *)oval, InvertedPointerVH(this, cs)));
27692847
return cs;
27702848
} else if (auto fn = dyn_cast<Function>(oval)) {
2771-
//! Todo allow tape propagation
2772-
// Note that specifically this should _not_ be called with topLevel=true
2773-
// (since it may not be valid to always assume we can recompute the
2774-
// augmented primal) However, in the absence of a way to pass tape data
2775-
// from an indirect augmented (and also since we dont presently allow
2776-
// indirect augmented calls), topLevel MUST be true otherwise subcalls will
2777-
// not be able to lookup the augmenteddata/subdata (triggering an assertion
2778-
// failure, among much worse)
2779-
std::map<Argument *, bool> uncacheable_args;
2780-
FnTypeInfo type_args(fn);
2781-
2782-
// conservatively assume that we can only cache existing floating types
2783-
// (i.e. that all args are uncacheable)
2784-
std::vector<DIFFE_TYPE> types;
2785-
for (auto &a : fn->args()) {
2786-
uncacheable_args[&a] = !a.getType()->isFPOrFPVectorTy();
2787-
type_args.Arguments.insert(std::pair<Argument *, TypeTree>(&a, {}));
2788-
type_args.KnownValues.insert(
2789-
std::pair<Argument *, std::set<int64_t>>(&a, {}));
2790-
DIFFE_TYPE typ;
2791-
if (a.getType()->isFPOrFPVectorTy()) {
2792-
typ = DIFFE_TYPE::OUT_DIFF;
2793-
} else if (a.getType()->isIntegerTy() &&
2794-
cast<IntegerType>(a.getType())->getBitWidth() < 16) {
2795-
typ = DIFFE_TYPE::CONSTANT;
2796-
} else if (a.getType()->isVoidTy() || a.getType()->isEmptyTy()) {
2797-
typ = DIFFE_TYPE::CONSTANT;
2798-
} else {
2799-
typ = DIFFE_TYPE::DUP_ARG;
2800-
}
2801-
types.push_back(typ);
2802-
}
2803-
2804-
DIFFE_TYPE retType = fn->getReturnType()->isFPOrFPVectorTy()
2805-
? DIFFE_TYPE::OUT_DIFF
2806-
: DIFFE_TYPE::DUP_ARG;
2807-
if (fn->getReturnType()->isVoidTy() || fn->getReturnType()->isEmptyTy() ||
2808-
(fn->getReturnType()->isIntegerTy() &&
2809-
cast<IntegerType>(fn->getReturnType())->getBitWidth() < 16))
2810-
retType = DIFFE_TYPE::CONSTANT;
2811-
2812-
// TODO re atomic add consider forcing it to be atomic always as fallback if
2813-
// used in a parallel context
2814-
auto &augdata = Logic.CreateAugmentedPrimal(
2815-
fn, retType, /*constant_args*/ types, TLI, TA,
2816-
/*returnUsed*/ !fn->getReturnType()->isEmptyTy() &&
2817-
!fn->getReturnType()->isVoidTy(),
2818-
type_args, uncacheable_args, /*forceAnonymousTape*/ true, AtomicAdd,
2819-
/*PostOpt*/ false);
2820-
Constant *newf = Logic.CreatePrimalAndGradient(
2821-
fn, retType, /*constant_args*/ types, TLI, TA,
2822-
/*returnValue*/ false, /*dretPtr*/ false,
2823-
DerivativeMode::ReverseModeGradient,
2824-
/*additionalArg*/ Type::getInt8PtrTy(fn->getContext()), type_args,
2825-
uncacheable_args,
2826-
/*map*/ &augdata, AtomicAdd);
2827-
if (!newf)
2828-
newf = UndefValue::get(fn->getType());
2829-
auto cdata = ConstantStruct::get(
2830-
StructType::get(newf->getContext(),
2831-
{augdata.fn->getType(), newf->getType()}),
2832-
{augdata.fn, newf});
2833-
std::string globalname = ("_enzyme_" + fn->getName() + "'").str();
2834-
auto GV = fn->getParent()->getNamedValue(globalname);
2835-
2836-
if (GV == nullptr) {
2837-
GV = new GlobalVariable(*fn->getParent(), cdata->getType(), true,
2838-
GlobalValue::LinkageTypes::InternalLinkage, cdata,
2839-
globalname);
2840-
}
2841-
2842-
return BuilderM.CreatePointerCast(GV, fn->getType());
2849+
return GetOrCreateShadowFunction(Logic, TLI, TA, fn, AtomicAdd);
28432850
} else if (auto arg = dyn_cast<CastInst>(oval)) {
28442851
IRBuilder<> bb(getNewFromOriginal(arg));
28452852
Value *invertOp = invertPointerM(arg->getOperand(0), bb);

enzyme/Enzyme/GradientUtils.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1323,6 +1323,12 @@ class GradientUtils : public CacheUtility {
13231323
Value *invertPointerM(Value *val, IRBuilder<> &BuilderM,
13241324
bool nullShadow = false);
13251325

1326+
static Constant *GetOrCreateShadowFunction(EnzymeLogic &Logic,
1327+
TargetLibraryInfo &TLI,
1328+
TypeAnalysis &TA, Function *F,
1329+
bool AtomicAdd = true,
1330+
bool PostOpt = false);
1331+
13261332
void branchToCorrespondingTarget(
13271333
BasicBlock *ctx, IRBuilder<> &BuilderM,
13281334
const std::map<BasicBlock *,
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
// RUN: %clang++ -std=c++11 -fno-exceptions -ffast-math -O0 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -S | %lli -
2+
// RUN: %clang++ -std=c++11 -fno-exceptions -ffast-math -O1 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -S | %lli -
3+
// RUN: %clang++ -std=c++11 -fno-exceptions -ffast-math -O2 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -S | %lli -
4+
// RUN: %clang++ -std=c++11 -fno-exceptions -ffast-math -O3 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -S | %lli -
5+
// RUN: %clang++ -std=c++11 -fno-exceptions -ffast-math -O0 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -enzyme-inline=1 -S | %lli -
6+
// RUN: %clang++ -std=c++11 -fno-exceptions -ffast-math -O1 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -enzyme-inline=1 -S | %lli -
7+
// RUN: %clang++ -std=c++11 -fno-exceptions -ffast-math -O2 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -enzyme-inline=1 -S | %lli -
8+
// RUN: %clang++ -std=c++11 -fno-exceptions -ffast-math -O3 %s -S -emit-llvm -o - | %opt - %loadEnzyme -enzyme -enzyme-inline=1 -S | %lli -
9+
10+
#include <stdio.h>
11+
#include "test_utils.h"
12+
13+
struct S {
14+
double (*fn)(double);
15+
double val;
16+
};
17+
18+
double square(double x){ return x * x; }
19+
20+
double foo(struct S* s) {
21+
return square(s->val);
22+
}
23+
24+
25+
void primal() {
26+
struct S s;
27+
s.fn = square;
28+
s.val = 3.0;
29+
printf("%f\n", foo(&s));
30+
}
31+
32+
void* __enzyme_virtualreverse(void*);
33+
void __enzyme_autodiff(void*, void*, void*);
34+
void reverse() {
35+
struct S s;
36+
s.fn = square;
37+
s.val = 3.0;
38+
struct S d_s;
39+
d_s.fn = (double (*)(double))__enzyme_virtualreverse((void*)square);
40+
d_s.val = 0.0;
41+
__enzyme_autodiff((void*)foo, &s, &d_s);
42+
printf("shadow res=%f\n", d_s.val);
43+
APPROX_EQ(d_s.val, 6.0, 1e-7);
44+
}
45+
46+
int main() {
47+
primal();
48+
reverse();
49+
}

0 commit comments

Comments
 (0)