Skip to content

Commit 82d04cc

Browse files
authored
Add probabilistic programming capabilities (rust-lang#909)
* add tracing * add conditioning * add dynamic interface support * add tests
1 parent 038f63c commit 82d04cc

15 files changed

+2021
-22
lines changed

enzyme/Enzyme/Enzyme.cpp

Lines changed: 235 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
#include "llvm/Transforms/Utils/Cloning.h"
5252
#if LLVM_VERSION_MAJOR >= 11
5353
#include "llvm/Analysis/InlineAdvisor.h"
54+
#include "llvm/IR/AbstractCallSite.h"
5455
#endif
5556
#include "llvm/Analysis/BasicAliasAnalysis.h"
5657
#include "llvm/Analysis/GlobalsModRef.h"
@@ -61,6 +62,7 @@
6162
#include "ActivityAnalysis.h"
6263
#include "EnzymeLogic.h"
6364
#include "GradientUtils.h"
65+
#include "TraceUtils.h"
6466
#include "Utils.h"
6567

6668
#include "InstructionBatcher.h"
@@ -704,15 +706,8 @@ class EnzymeBase {
704706
fn = CI->getArgOperand(1);
705707
}
706708

707-
while (auto ci = dyn_cast<CastInst>(fn)) {
708-
fn = ci->getOperand(0);
709-
}
710-
while (auto ci = dyn_cast<BlockAddress>(fn)) {
711-
fn = ci->getFunction();
712-
}
713-
while (auto ci = dyn_cast<ConstantExpr>(fn)) {
714-
fn = ci->getOperand(0);
715-
}
709+
fn = GetFunctionFromValue(fn);
710+
716711
if (!isa<Function>(fn)) {
717712
EmitFailure("NoFunctionToDifferentiate", CI->getDebugLoc(), CI,
718713
"failed to find fn to differentiate", *CI, " - found - ",
@@ -1798,6 +1793,120 @@ class EnzymeBase {
17981793
return true;
17991794
}
18001795

1796+
/// Lower a probabilistic-programming call (`__enzyme_trace` or
/// `__enzyme_condition`, selected by \p mode) into a direct call to the
/// function produced by Logic.CreateTrace.
///
/// Operand 0 of \p CI is handled by parseFunctionParameter. The remaining
/// operands are forwarded to the generated function, except for the marker
/// pairs `enzyme_interface, <value>` and `enzyme_condition, <value>`, which
/// are stripped from the forwarded list and appended at the end (interface
/// first, then the conditioning trace when \p mode is Condition).
///
/// \param CI   the call to rewrite; it is erased on success.
/// \param mode whether to build a tracing or a conditioning function.
/// \returns true when \p CI was replaced; false when the function to trace
///          could not be parsed or the argument list is malformed (a failure
///          diagnostic is emitted for malformed arguments).
bool HandleProbProg(CallInst *CI, ProbProgMode mode) {
  IRBuilder<> Builder(CI);

  auto parsedFunction = parseFunctionParameter(CI);
  if (!parsedFunction.hasValue())
    return false;

  Function *F = parsedFunction.getValue();
  assert(F);

  SmallVector<Value *, 4> args;
  Value *conditioning_trace = nullptr;
  Value *dynamic_interface = nullptr;

#if LLVM_VERSION_MAJOR >= 14
  const unsigned numArgs = CI->arg_size();
#else
  const unsigned numArgs = CI->getNumArgOperands();
#endif

  // Operand 0 is the function to trace; everything after it is either a
  // plain argument or an `enzyme_*` marker followed by its payload.
  for (unsigned i = 1; i < numArgs; ++i) {
    Value *res = CI->getArgOperand(i);
    Optional<StringRef> metaString = getMetadataName(res);

    // handle metadata
    if (metaString && metaString.getValue().startswith("enzyme_")) {
      Value **payload = nullptr;
      if (*metaString == "enzyme_interface") {
        payload = &dynamic_interface;
      } else if (*metaString == "enzyme_condition") {
        payload = &conditioning_trace;
      } else {
        EmitFailure("IllegalDiffeType", CI->getDebugLoc(), CI,
                    "illegal enzyme metadata classification ", *CI,
                    *metaString);
        return false;
      }

      // The marker must be followed by its value; without this check a
      // trailing marker would index past the end of the operand list.
      if (i + 1 >= numArgs) {
        EmitFailure("MissingArgument", CI->getDebugLoc(), CI,
                    "missing argument after enzyme metadata ", *metaString,
                    " in ", *CI);
        return false;
      }

      ++i;
      *payload = CI->getArgOperand(i);
      continue;
    }

    args.push_back(res);
  }

  // Interface

  // Locate the sample declaration in the module; when several functions
  // match, the last one encountered wins.
  Function *sample = nullptr;
  for (auto &&interface_func : F->getParent()->functions()) {
    if (interface_func.getName().contains("__enzyme_sample")) {
      assert(interface_func.getFunctionType()->getNumParams() >= 3);
      sample = &interface_func;
    }
  }

  // Report a missing sample function instead of relying on an assert that
  // disappears in release builds and would leave a null dereference below.
  if (!sample) {
    EmitFailure("NoSampleFunction", CI->getDebugLoc(), CI,
                "failed to find __enzyme_sample in module for ", *CI);
    return false;
  }

  if (dynamic_interface)
    args.push_back(dynamic_interface);

  if (mode == ProbProgMode::Condition) {
    // Conditioning needs a trace to condition on; never forward a null
    // Value to the generated call.
    if (!conditioning_trace) {
      EmitFailure("MissingConditioningTrace", CI->getDebugLoc(), CI,
                  "no enzyme_condition trace passed to ", *CI);
      return false;
    }
    args.push_back(conditioning_trace);
  }

  // Determine generative functions: the sample function plus, transitively,
  // every function that contains a call site of an already-collected one.
  SmallPtrSet<Function *, 4> generativeFunctions;
  SetVector<Function *, std::deque<Function *>> workList;
  workList.insert(sample);
  generativeFunctions.insert(sample);

  while (!workList.empty()) {
    auto todo = *workList.begin();
    workList.erase(workList.begin());

#if LLVM_VERSION_MAJOR > 10
    for (auto &&U : todo->uses()) {
      if (auto ACS = AbstractCallSite(&U)) {
        Function *caller = ACS.getInstruction()->getFunction();
        if (generativeFunctions.insert(caller).second)
          workList.insert(caller);
      }
    }
#else
    for (auto &&U : todo->uses()) {
      if (auto *call = dyn_cast<CallInst>(U.getUser())) {
        Function *caller = call->getFunction();
        if (generativeFunctions.insert(caller).second)
          workList.insert(caller);
      }
    }
#endif
  }

  auto newFunc = Logic.CreateTrace(F, generativeFunctions, mode,
                                   dynamic_interface != nullptr);

  Value *trace =
      Builder.CreateCall(newFunc->getFunctionType(), newFunc, args);
  // For a non-void traced function the generated call returns an aggregate;
  // element 1 is used as the trace.
  if (!F->getReturnType()->isVoidTy())
    trace = Builder.CreateExtractValue(trace, {1});

  // try to cast the value returned from trace to the type of the original
  // call so that its users can be rewired directly
  if (CI->getType() != trace->getType())
    trace = Builder.CreatePointerCast(trace, CI->getType());

  CI->replaceAllUsesWith(trace);
  CI->eraseFromParent();

  return true;
}
1909+
18011910
bool lowerEnzymeCalls(Function &F, std::set<Function *> &done) {
18021911
if (done.count(&F))
18031912
return false;
@@ -1838,7 +1947,9 @@ class EnzymeBase {
18381947
Fn->getName().contains("__enzyme_augmentfwd") ||
18391948
Fn->getName().contains("__enzyme_augmentsize") ||
18401949
Fn->getName().contains("__enzyme_reverse") ||
1841-
Fn->getName().contains("__enzyme_batch")))
1950+
Fn->getName().contains("__enzyme_batch") ||
1951+
Fn->getName().contains("__enzyme_trace") ||
1952+
Fn->getName().contains("__enzyme_condition")))
18421953
continue;
18431954

18441955
SmallVector<Value *, 16> CallArgs(II->arg_begin(), II->arg_end());
@@ -1875,6 +1986,7 @@ class EnzymeBase {
18751986
MapVector<CallInst *, DerivativeMode> toVirtual;
18761987
MapVector<CallInst *, DerivativeMode> toSize;
18771988
SmallVector<CallInst *, 4> toBatch;
1989+
MapVector<CallInst *, ProbProgMode> toProbProg;
18781990
SetVector<CallInst *> InactiveCalls;
18791991
SetVector<CallInst *> IterCalls;
18801992
retry:;
@@ -2094,33 +2206,43 @@ class EnzymeBase {
20942206
bool virtualCall = false;
20952207
bool sizeOnly = false;
20962208
bool batch = false;
2097-
DerivativeMode mode;
2209+
bool probProg = false;
2210+
DerivativeMode derivativeMode;
2211+
ProbProgMode probProgMode;
20982212
if (Fn->getName().contains("__enzyme_autodiff")) {
20992213
enableEnzyme = true;
2100-
mode = DerivativeMode::ReverseModeCombined;
2214+
derivativeMode = DerivativeMode::ReverseModeCombined;
21012215
} else if (Fn->getName().contains("__enzyme_fwddiff")) {
21022216
enableEnzyme = true;
2103-
mode = DerivativeMode::ForwardMode;
2217+
derivativeMode = DerivativeMode::ForwardMode;
21042218
} else if (Fn->getName().contains("__enzyme_fwdsplit")) {
21052219
enableEnzyme = true;
2106-
mode = DerivativeMode::ForwardModeSplit;
2220+
derivativeMode = DerivativeMode::ForwardModeSplit;
21072221
} else if (Fn->getName().contains("__enzyme_augmentfwd")) {
21082222
enableEnzyme = true;
2109-
mode = DerivativeMode::ReverseModePrimal;
2223+
derivativeMode = DerivativeMode::ReverseModePrimal;
21102224
} else if (Fn->getName().contains("__enzyme_augmentsize")) {
21112225
enableEnzyme = true;
21122226
sizeOnly = true;
2113-
mode = DerivativeMode::ReverseModePrimal;
2227+
derivativeMode = DerivativeMode::ReverseModePrimal;
21142228
} else if (Fn->getName().contains("__enzyme_reverse")) {
21152229
enableEnzyme = true;
2116-
mode = DerivativeMode::ReverseModeGradient;
2230+
derivativeMode = DerivativeMode::ReverseModeGradient;
21172231
} else if (Fn->getName().contains("__enzyme_virtualreverse")) {
21182232
enableEnzyme = true;
21192233
virtualCall = true;
2120-
mode = DerivativeMode::ReverseModeCombined;
2234+
derivativeMode = DerivativeMode::ReverseModeCombined;
21212235
} else if (Fn->getName().contains("__enzyme_batch")) {
21222236
enableEnzyme = true;
21232237
batch = true;
2238+
} else if (Fn->getName().contains("__enzyme_trace")) {
2239+
enableEnzyme = true;
2240+
probProgMode = ProbProgMode::Trace;
2241+
probProg = true;
2242+
} else if (Fn->getName().contains("__enzyme_condition")) {
2243+
enableEnzyme = true;
2244+
probProgMode = ProbProgMode::Condition;
2245+
probProg = true;
21242246
}
21252247

21262248
if (enableEnzyme) {
@@ -2161,13 +2283,15 @@ class EnzymeBase {
21612283
goto retry;
21622284
}
21632285
if (virtualCall)
2164-
toVirtual[CI] = mode;
2286+
toVirtual[CI] = derivativeMode;
21652287
else if (sizeOnly)
2166-
toSize[CI] = mode;
2288+
toSize[CI] = derivativeMode;
21672289
else if (batch)
21682290
toBatch.push_back(CI);
2169-
else
2170-
toLower[CI] = mode;
2291+
else if (probProg) {
2292+
toProbProg[CI] = probProgMode;
2293+
} else
2294+
toLower[CI] = derivativeMode;
21712295

21722296
if (auto dc = dyn_cast<Function>(fn)) {
21732297
// Force postopt on any inner functions in the nested
@@ -2254,6 +2378,10 @@ class EnzymeBase {
22542378
HandleBatch(call);
22552379
}
22562380

2381+
for (auto &&[call, mode] : toProbProg) {
2382+
HandleProbProg(call, mode);
2383+
}
2384+
22572385
if (Changed && EnzymeAttributor) {
22582386
// TODO consider enabling when attributor does not delete
22592387
// dead internal functions, which invalidates Enzyme's cache
@@ -2453,6 +2581,91 @@ class EnzymeBase {
24532581
changed = true;
24542582
}
24552583

2584+
SmallPtrSet<CallInst *, 4> sample_calls;
2585+
for (auto &&func : M) {
2586+
for (auto &&BB : func) {
2587+
for (auto &&Inst : BB) {
2588+
if (auto CI = dyn_cast<CallInst>(&Inst)) {
2589+
Function *enzyme_sample = CI->getCalledFunction();
2590+
if (enzyme_sample &&
2591+
enzyme_sample->getName().startswith("__enzyme_sample")) {
2592+
if (CI->getNumOperands() < 3) {
2593+
EmitFailure(
2594+
"IllegalNumberOfArguments", CI->getDebugLoc(), CI,
2595+
"Not enough arguments passed to call to __enzyme_sample");
2596+
}
2597+
Function *samplefn = GetFunctionFromValue(CI->getOperand(0));
2598+
unsigned expected =
2599+
samplefn->getFunctionType()->getNumParams() + 3;
2600+
#if LLVM_VERSION_MAJOR >= 14
2601+
unsigned actual = CI->arg_size();
2602+
#else
2603+
unsigned actual = CI->getNumArgOperands();
2604+
#endif
2605+
if (actual - 3 != samplefn->getFunctionType()->getNumParams()) {
2606+
EmitFailure("IllegalNumberOfArguments", CI->getDebugLoc(), CI,
2607+
"Illegal number of arguments passed to call to "
2608+
"__enzyme_sample.",
2609+
" Expected: ", expected, " got: ", actual);
2610+
}
2611+
Function *pdf = GetFunctionFromValue(CI->getArgOperand(1));
2612+
2613+
for (unsigned i = 0;
2614+
i < samplefn->getFunctionType()->getNumParams(); ++i) {
2615+
Value *ci_arg = CI->getArgOperand(i + 3);
2616+
Value *sample_arg = samplefn->arg_begin() + i;
2617+
Value *pdf_arg = pdf->arg_begin() + i;
2618+
2619+
if (ci_arg->getType() != sample_arg->getType()) {
2620+
EmitFailure(
2621+
"IllegalSampleType", CI->getDebugLoc(), CI,
2622+
"Type of: ", *ci_arg, " (", *ci_arg->getType(), ")",
2623+
" does not match the argument type of the sample "
2624+
"function: ",
2625+
*samplefn, " at: ", i, " (", *sample_arg->getType(), ")");
2626+
}
2627+
if (ci_arg->getType() != pdf_arg->getType()) {
2628+
EmitFailure("IllegalSampleType", CI->getDebugLoc(), CI,
2629+
"Type of: ", *ci_arg, " (", *ci_arg->getType(),
2630+
")",
2631+
" does not match the argument type of the "
2632+
"density function: ",
2633+
*pdf, " at: ", i, " (", *pdf_arg->getType(), ")");
2634+
}
2635+
}
2636+
2637+
if ((pdf->arg_end() - 1)->getType() !=
2638+
samplefn->getReturnType()) {
2639+
EmitFailure(
2640+
"IllegalSampleType", CI->getDebugLoc(), CI,
2641+
"Return type of ", *samplefn, " (",
2642+
*samplefn->getReturnType(), ")",
2643+
" does not match the last argument type of the density "
2644+
"function: ",
2645+
*pdf, " (", *(pdf->arg_end() - 1)->getType(), ")");
2646+
}
2647+
sample_calls.insert(CI);
2648+
}
2649+
}
2650+
}
2651+
}
2652+
}
2653+
2654+
// Replace calls to __enzyme_sample with the actual sample calls after
2655+
// running prob prog
2656+
for (auto call : sample_calls) {
2657+
Function *samplefn = GetFunctionFromValue(call->getArgOperand(0));
2658+
2659+
SmallVector<Value *, 2> args;
2660+
for (auto it = call->arg_begin() + 3; it != call->arg_end(); it++) {
2661+
args.push_back(*it);
2662+
}
2663+
CallInst *choice =
2664+
CallInst::Create(samplefn->getFunctionType(), samplefn, args);
2665+
2666+
ReplaceInstWithInst(call, choice);
2667+
}
2668+
24562669
for (const auto &pair : Logic.PPC.cache)
24572670
pair.second->eraseFromParent();
24582671
Logic.clear();

enzyme/Enzyme/EnzymeLogic.cpp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@
6060
#include "GradientUtils.h"
6161
#include "InstructionBatcher.h"
6262
#include "LibraryFuncs.h"
63+
#include "TraceGenerator.h"
6364
#include "Utils.h"
6465

6566
#if LLVM_VERSION_MAJOR >= 14
@@ -4805,6 +4806,40 @@ llvm::Function *EnzymeLogic::CreateBatch(Function *tobatch, unsigned width,
48054806
return BatchCachedFunctions[tup] = NewF;
48064807
};
48074808

4809+
/// Create the traced variant of \p totrace, or return the previously
/// generated one.
///
/// Results are memoized in TraceCachedFunctions, keyed on
/// (totrace, mode, dynamic_interface).
///
/// \param totrace             function whose instructions are visited.
/// \param GenerativeFunctions set forwarded to the TraceUtils constructor.
/// \param mode                probabilistic-programming mode to generate for.
/// \param dynamic_interface   forwarded to the TraceUtils constructor and
///                            part of the cache key.
/// \returns the generated function (TraceUtils::newFunc). Aborts through
///          report_fatal_error when the generated function fails LLVM
///          verification.
llvm::Function *
EnzymeLogic::CreateTrace(llvm::Function *totrace,
                         SmallPtrSetImpl<Function *> &GenerativeFunctions,
                         ProbProgMode mode, bool dynamic_interface) {
  TraceCacheKey tup = std::make_tuple(totrace, mode, dynamic_interface);

  // Single lookup instead of find() followed by a second find().
  auto cached = TraceCachedFunctions.find(tup);
  if (cached != TraceCachedFunctions.end())
    return cached->second;

  Function *NewF = nullptr;
  {
    // Scoped automatic objects replace the manual new/delete pair. They are
    // destroyed in the same order as before (tracer first, then tutils) and
    // before the result is stored in the cache.
    TraceUtils tutils(mode, dynamic_interface, totrace, GenerativeFunctions);
    TraceGenerator tracer(*this, &tutils);

    for (auto &&BB : *totrace)
      for (auto &&Inst : BB)
        tracer.visit(Inst);

    if (llvm::verifyFunction(*tutils.newFunc, &llvm::errs())) {
      llvm::errs() << *totrace << "\n";
      llvm::errs() << *tutils.newFunc << "\n";
      report_fatal_error("function failed verification (4)");
    }

    NewF = tutils.newFunc;
  }

  return TraceCachedFunctions[tup] = NewF;
}
4842+
48084843
llvm::Value *EnzymeLogic::CreateNoFree(llvm::Value *todiff) {
48094844
if (auto F = dyn_cast<Function>(todiff))
48104845
return CreateNoFree(F);

0 commit comments

Comments
 (0)