|
51 | 51 | #include "llvm/Transforms/Utils/Cloning.h"
|
52 | 52 | #if LLVM_VERSION_MAJOR >= 11
|
53 | 53 | #include "llvm/Analysis/InlineAdvisor.h"
|
| 54 | +#include "llvm/IR/AbstractCallSite.h" |
54 | 55 | #endif
|
55 | 56 | #include "llvm/Analysis/BasicAliasAnalysis.h"
|
56 | 57 | #include "llvm/Analysis/GlobalsModRef.h"
|
|
61 | 62 | #include "ActivityAnalysis.h"
|
62 | 63 | #include "EnzymeLogic.h"
|
63 | 64 | #include "GradientUtils.h"
|
| 65 | +#include "TraceUtils.h" |
64 | 66 | #include "Utils.h"
|
65 | 67 |
|
66 | 68 | #include "InstructionBatcher.h"
|
@@ -704,15 +706,8 @@ class EnzymeBase {
|
704 | 706 | fn = CI->getArgOperand(1);
|
705 | 707 | }
|
706 | 708 |
|
707 |
| - while (auto ci = dyn_cast<CastInst>(fn)) { |
708 |
| - fn = ci->getOperand(0); |
709 |
| - } |
710 |
| - while (auto ci = dyn_cast<BlockAddress>(fn)) { |
711 |
| - fn = ci->getFunction(); |
712 |
| - } |
713 |
| - while (auto ci = dyn_cast<ConstantExpr>(fn)) { |
714 |
| - fn = ci->getOperand(0); |
715 |
| - } |
| 709 | + fn = GetFunctionFromValue(fn); |
| 710 | + |
716 | 711 | if (!isa<Function>(fn)) {
|
717 | 712 | EmitFailure("NoFunctionToDifferentiate", CI->getDebugLoc(), CI,
|
718 | 713 | "failed to find fn to differentiate", *CI, " - found - ",
|
@@ -1798,6 +1793,120 @@ class EnzymeBase {
|
1798 | 1793 | return true;
|
1799 | 1794 | }
|
1800 | 1795 |
|
| 1796 | + bool HandleProbProg(CallInst *CI, ProbProgMode mode) { |
| 1797 | + IRBuilder<> Builder(CI); |
| 1798 | + Function *F; |
| 1799 | + auto parsedFunction = parseFunctionParameter(CI); |
| 1800 | + if (parsedFunction.hasValue()) { |
| 1801 | + F = parsedFunction.getValue(); |
| 1802 | + } else { |
| 1803 | + return false; |
| 1804 | + } |
| 1805 | + |
| 1806 | + assert(F); |
| 1807 | + |
| 1808 | + bool sret = false; |
| 1809 | + SmallVector<Value *, 4> args; |
| 1810 | + Value *conditioning_trace = nullptr; |
| 1811 | + Value *dynamic_interface = nullptr; |
| 1812 | + |
| 1813 | +#if LLVM_VERSION_MAJOR >= 14 |
| 1814 | + for (unsigned i = 1 + sret; i < CI->arg_size(); ++i) |
| 1815 | +#else |
| 1816 | + for (unsigned i = 1 + sret; i < CI->getNumArgOperands(); ++i) |
| 1817 | +#endif |
| 1818 | + { |
| 1819 | + Value *res = CI->getArgOperand(i); |
| 1820 | + Optional<StringRef> metaString = getMetadataName(res); |
| 1821 | + |
| 1822 | + // handle metadata |
| 1823 | + if (metaString && metaString.getValue().startswith("enzyme_")) { |
| 1824 | + if (*metaString == "enzyme_interface") { |
| 1825 | + ++i; |
| 1826 | + dynamic_interface = CI->getArgOperand(i); |
| 1827 | + continue; |
| 1828 | + } else if (*metaString == "enzyme_condition") { |
| 1829 | + ++i; |
| 1830 | + conditioning_trace = CI->getArgOperand(i); |
| 1831 | + continue; |
| 1832 | + } else { |
| 1833 | + EmitFailure("IllegalDiffeType", CI->getDebugLoc(), CI, |
| 1834 | + "illegal enzyme metadata classification ", *CI, |
| 1835 | + *metaString); |
| 1836 | + return false; |
| 1837 | + } |
| 1838 | + } |
| 1839 | + |
| 1840 | + args.push_back(res); |
| 1841 | + } |
| 1842 | + |
| 1843 | + // Interface |
| 1844 | + |
| 1845 | + Function *sample = nullptr; |
| 1846 | + for (auto &&interface_func : F->getParent()->functions()) { |
| 1847 | + if (interface_func.getName().contains("__enzyme_sample")) { |
| 1848 | + assert(interface_func.getFunctionType()->getNumParams() >= 3); |
| 1849 | + sample = &interface_func; |
| 1850 | + } |
| 1851 | + } |
| 1852 | + |
| 1853 | + assert(sample); |
| 1854 | + |
| 1855 | + if (dynamic_interface) |
| 1856 | + args.push_back(dynamic_interface); |
| 1857 | + |
| 1858 | + if (mode == ProbProgMode::Condition) |
| 1859 | + args.push_back(conditioning_trace); |
| 1860 | + |
| 1861 | + // Determine generative functions |
| 1862 | + SmallPtrSet<Function *, 4> generativeFunctions; |
| 1863 | + SetVector<Function *, std::deque<Function *>> workList; |
| 1864 | + workList.insert(sample); |
| 1865 | + generativeFunctions.insert(sample); |
| 1866 | + |
| 1867 | + while (!workList.empty()) { |
| 1868 | + auto todo = *workList.begin(); |
| 1869 | + workList.erase(workList.begin()); |
| 1870 | + |
| 1871 | +#if LLVM_VERSION_MAJOR > 10 |
| 1872 | + for (auto &&U : todo->uses()) { |
| 1873 | + if (auto ACS = AbstractCallSite(&U)) { |
| 1874 | + auto fun = ACS.getInstruction()->getParent()->getParent(); |
| 1875 | + auto [it, inserted] = generativeFunctions.insert(fun); |
| 1876 | + if (inserted) |
| 1877 | + workList.insert(fun); |
| 1878 | + } |
| 1879 | + } |
| 1880 | +#else |
| 1881 | + for (auto &&U : todo->uses()) { |
| 1882 | + if (auto &&call = dyn_cast<CallInst>(U.getUser())) { |
| 1883 | + auto &&fun = call->getParent()->getParent(); |
| 1884 | + auto &&[it, inserted] = generativeFunctions.insert(fun); |
| 1885 | + if (inserted) |
| 1886 | + workList.insert(fun); |
| 1887 | + } |
| 1888 | + } |
| 1889 | +#endif |
| 1890 | + } |
| 1891 | + |
| 1892 | + auto newFunc = Logic.CreateTrace(F, generativeFunctions, mode, |
| 1893 | + dynamic_interface != nullptr); |
| 1894 | + |
| 1895 | + Value *trace = |
| 1896 | + Builder.CreateCall(newFunc->getFunctionType(), newFunc, args); |
| 1897 | + if (!F->getReturnType()->isVoidTy()) |
| 1898 | + trace = Builder.CreateExtractValue(trace, {1}); |
| 1899 | + |
| 1900 | + // try to cast i8* returned from trace to CI->getRetType.... |
| 1901 | + if (CI->getType() != trace->getType()) |
| 1902 | + trace = Builder.CreatePointerCast(trace, CI->getType()); |
| 1903 | + |
| 1904 | + CI->replaceAllUsesWith(trace); |
| 1905 | + CI->eraseFromParent(); |
| 1906 | + |
| 1907 | + return true; |
| 1908 | + } |
| 1909 | + |
1801 | 1910 | bool lowerEnzymeCalls(Function &F, std::set<Function *> &done) {
|
1802 | 1911 | if (done.count(&F))
|
1803 | 1912 | return false;
|
@@ -1838,7 +1947,9 @@ class EnzymeBase {
|
1838 | 1947 | Fn->getName().contains("__enzyme_augmentfwd") ||
|
1839 | 1948 | Fn->getName().contains("__enzyme_augmentsize") ||
|
1840 | 1949 | Fn->getName().contains("__enzyme_reverse") ||
|
1841 |
| - Fn->getName().contains("__enzyme_batch"))) |
| 1950 | + Fn->getName().contains("__enzyme_batch") || |
| 1951 | + Fn->getName().contains("__enzyme_trace") || |
| 1952 | + Fn->getName().contains("__enzyme_condition"))) |
1842 | 1953 | continue;
|
1843 | 1954 |
|
1844 | 1955 | SmallVector<Value *, 16> CallArgs(II->arg_begin(), II->arg_end());
|
@@ -1875,6 +1986,7 @@ class EnzymeBase {
|
1875 | 1986 | MapVector<CallInst *, DerivativeMode> toVirtual;
|
1876 | 1987 | MapVector<CallInst *, DerivativeMode> toSize;
|
1877 | 1988 | SmallVector<CallInst *, 4> toBatch;
|
| 1989 | + MapVector<CallInst *, ProbProgMode> toProbProg; |
1878 | 1990 | SetVector<CallInst *> InactiveCalls;
|
1879 | 1991 | SetVector<CallInst *> IterCalls;
|
1880 | 1992 | retry:;
|
@@ -2094,33 +2206,43 @@ class EnzymeBase {
|
2094 | 2206 | bool virtualCall = false;
|
2095 | 2207 | bool sizeOnly = false;
|
2096 | 2208 | bool batch = false;
|
2097 |
| - DerivativeMode mode; |
| 2209 | + bool probProg = false; |
| 2210 | + DerivativeMode derivativeMode; |
| 2211 | + ProbProgMode probProgMode; |
2098 | 2212 | if (Fn->getName().contains("__enzyme_autodiff")) {
|
2099 | 2213 | enableEnzyme = true;
|
2100 |
| - mode = DerivativeMode::ReverseModeCombined; |
| 2214 | + derivativeMode = DerivativeMode::ReverseModeCombined; |
2101 | 2215 | } else if (Fn->getName().contains("__enzyme_fwddiff")) {
|
2102 | 2216 | enableEnzyme = true;
|
2103 |
| - mode = DerivativeMode::ForwardMode; |
| 2217 | + derivativeMode = DerivativeMode::ForwardMode; |
2104 | 2218 | } else if (Fn->getName().contains("__enzyme_fwdsplit")) {
|
2105 | 2219 | enableEnzyme = true;
|
2106 |
| - mode = DerivativeMode::ForwardModeSplit; |
| 2220 | + derivativeMode = DerivativeMode::ForwardModeSplit; |
2107 | 2221 | } else if (Fn->getName().contains("__enzyme_augmentfwd")) {
|
2108 | 2222 | enableEnzyme = true;
|
2109 |
| - mode = DerivativeMode::ReverseModePrimal; |
| 2223 | + derivativeMode = DerivativeMode::ReverseModePrimal; |
2110 | 2224 | } else if (Fn->getName().contains("__enzyme_augmentsize")) {
|
2111 | 2225 | enableEnzyme = true;
|
2112 | 2226 | sizeOnly = true;
|
2113 |
| - mode = DerivativeMode::ReverseModePrimal; |
| 2227 | + derivativeMode = DerivativeMode::ReverseModePrimal; |
2114 | 2228 | } else if (Fn->getName().contains("__enzyme_reverse")) {
|
2115 | 2229 | enableEnzyme = true;
|
2116 |
| - mode = DerivativeMode::ReverseModeGradient; |
| 2230 | + derivativeMode = DerivativeMode::ReverseModeGradient; |
2117 | 2231 | } else if (Fn->getName().contains("__enzyme_virtualreverse")) {
|
2118 | 2232 | enableEnzyme = true;
|
2119 | 2233 | virtualCall = true;
|
2120 |
| - mode = DerivativeMode::ReverseModeCombined; |
| 2234 | + derivativeMode = DerivativeMode::ReverseModeCombined; |
2121 | 2235 | } else if (Fn->getName().contains("__enzyme_batch")) {
|
2122 | 2236 | enableEnzyme = true;
|
2123 | 2237 | batch = true;
|
| 2238 | + } else if (Fn->getName().contains("__enzyme_trace")) { |
| 2239 | + enableEnzyme = true; |
| 2240 | + probProgMode = ProbProgMode::Trace; |
| 2241 | + probProg = true; |
| 2242 | + } else if (Fn->getName().contains("__enzyme_condition")) { |
| 2243 | + enableEnzyme = true; |
| 2244 | + probProgMode = ProbProgMode::Condition; |
| 2245 | + probProg = true; |
2124 | 2246 | }
|
2125 | 2247 |
|
2126 | 2248 | if (enableEnzyme) {
|
@@ -2161,13 +2283,15 @@ class EnzymeBase {
|
2161 | 2283 | goto retry;
|
2162 | 2284 | }
|
2163 | 2285 | if (virtualCall)
|
2164 |
| - toVirtual[CI] = mode; |
| 2286 | + toVirtual[CI] = derivativeMode; |
2165 | 2287 | else if (sizeOnly)
|
2166 |
| - toSize[CI] = mode; |
| 2288 | + toSize[CI] = derivativeMode; |
2167 | 2289 | else if (batch)
|
2168 | 2290 | toBatch.push_back(CI);
|
2169 |
| - else |
2170 |
| - toLower[CI] = mode; |
| 2291 | + else if (probProg) { |
| 2292 | + toProbProg[CI] = probProgMode; |
| 2293 | + } else |
| 2294 | + toLower[CI] = derivativeMode; |
2171 | 2295 |
|
2172 | 2296 | if (auto dc = dyn_cast<Function>(fn)) {
|
2173 | 2297 | // Force postopt on any inner functions in the nested
|
@@ -2254,6 +2378,10 @@ class EnzymeBase {
|
2254 | 2378 | HandleBatch(call);
|
2255 | 2379 | }
|
2256 | 2380 |
|
| 2381 | + for (auto &&[call, mode] : toProbProg) { |
| 2382 | + HandleProbProg(call, mode); |
| 2383 | + } |
| 2384 | + |
2257 | 2385 | if (Changed && EnzymeAttributor) {
|
2258 | 2386 | // TODO consider enabling when attributor does not delete
|
2259 | 2387 | // dead internal functions, which invalidates Enzyme's cache
|
@@ -2453,6 +2581,91 @@ class EnzymeBase {
|
2453 | 2581 | changed = true;
|
2454 | 2582 | }
|
2455 | 2583 |
|
| 2584 | + SmallPtrSet<CallInst *, 4> sample_calls; |
| 2585 | + for (auto &&func : M) { |
| 2586 | + for (auto &&BB : func) { |
| 2587 | + for (auto &&Inst : BB) { |
| 2588 | + if (auto CI = dyn_cast<CallInst>(&Inst)) { |
| 2589 | + Function *enzyme_sample = CI->getCalledFunction(); |
| 2590 | + if (enzyme_sample && |
| 2591 | + enzyme_sample->getName().startswith("__enzyme_sample")) { |
| 2592 | + if (CI->getNumOperands() < 3) { |
| 2593 | + EmitFailure( |
| 2594 | + "IllegalNumberOfArguments", CI->getDebugLoc(), CI, |
| 2595 | + "Not enough arguments passed to call to __enzyme_sample"); |
| 2596 | + } |
| 2597 | + Function *samplefn = GetFunctionFromValue(CI->getOperand(0)); |
| 2598 | + unsigned expected = |
| 2599 | + samplefn->getFunctionType()->getNumParams() + 3; |
| 2600 | +#if LLVM_VERSION_MAJOR >= 14 |
| 2601 | + unsigned actual = CI->arg_size(); |
| 2602 | +#else |
| 2603 | + unsigned actual = CI->getNumArgOperands(); |
| 2604 | +#endif |
| 2605 | + if (actual - 3 != samplefn->getFunctionType()->getNumParams()) { |
| 2606 | + EmitFailure("IllegalNumberOfArguments", CI->getDebugLoc(), CI, |
| 2607 | + "Illegal number of arguments passed to call to " |
| 2608 | + "__enzyme_sample.", |
| 2609 | + " Expected: ", expected, " got: ", actual); |
| 2610 | + } |
| 2611 | + Function *pdf = GetFunctionFromValue(CI->getArgOperand(1)); |
| 2612 | + |
| 2613 | + for (unsigned i = 0; |
| 2614 | + i < samplefn->getFunctionType()->getNumParams(); ++i) { |
| 2615 | + Value *ci_arg = CI->getArgOperand(i + 3); |
| 2616 | + Value *sample_arg = samplefn->arg_begin() + i; |
| 2617 | + Value *pdf_arg = pdf->arg_begin() + i; |
| 2618 | + |
| 2619 | + if (ci_arg->getType() != sample_arg->getType()) { |
| 2620 | + EmitFailure( |
| 2621 | + "IllegalSampleType", CI->getDebugLoc(), CI, |
| 2622 | + "Type of: ", *ci_arg, " (", *ci_arg->getType(), ")", |
| 2623 | + " does not match the argument type of the sample " |
| 2624 | + "function: ", |
| 2625 | + *samplefn, " at: ", i, " (", *sample_arg->getType(), ")"); |
| 2626 | + } |
| 2627 | + if (ci_arg->getType() != pdf_arg->getType()) { |
| 2628 | + EmitFailure("IllegalSampleType", CI->getDebugLoc(), CI, |
| 2629 | + "Type of: ", *ci_arg, " (", *ci_arg->getType(), |
| 2630 | + ")", |
| 2631 | + " does not match the argument type of the " |
| 2632 | + "density function: ", |
| 2633 | + *pdf, " at: ", i, " (", *pdf_arg->getType(), ")"); |
| 2634 | + } |
| 2635 | + } |
| 2636 | + |
| 2637 | + if ((pdf->arg_end() - 1)->getType() != |
| 2638 | + samplefn->getReturnType()) { |
| 2639 | + EmitFailure( |
| 2640 | + "IllegalSampleType", CI->getDebugLoc(), CI, |
| 2641 | + "Return type of ", *samplefn, " (", |
| 2642 | + *samplefn->getReturnType(), ")", |
| 2643 | + " does not match the last argument type of the density " |
| 2644 | + "function: ", |
| 2645 | + *pdf, " (", *(pdf->arg_end() - 1)->getType(), ")"); |
| 2646 | + } |
| 2647 | + sample_calls.insert(CI); |
| 2648 | + } |
| 2649 | + } |
| 2650 | + } |
| 2651 | + } |
| 2652 | + } |
| 2653 | + |
| 2654 | + // Replace calls to __enzyme_sample with the actual sample calls after |
| 2655 | + // running prob prog |
| 2656 | + for (auto call : sample_calls) { |
| 2657 | + Function *samplefn = GetFunctionFromValue(call->getArgOperand(0)); |
| 2658 | + |
| 2659 | + SmallVector<Value *, 2> args; |
| 2660 | + for (auto it = call->arg_begin() + 3; it != call->arg_end(); it++) { |
| 2661 | + args.push_back(*it); |
| 2662 | + } |
| 2663 | + CallInst *choice = |
| 2664 | + CallInst::Create(samplefn->getFunctionType(), samplefn, args); |
| 2665 | + |
| 2666 | + ReplaceInstWithInst(call, choice); |
| 2667 | + } |
| 2668 | + |
2456 | 2669 | for (const auto &pair : Logic.PPC.cache)
|
2457 | 2670 | pair.second->eraseFromParent();
|
2458 | 2671 | Logic.clear();
|
|
0 commit comments