Skip to content

Commit d03aa0f

Browse files
committed
Fix formatting
1 parent b6f9c9b commit d03aa0f

File tree

5 files changed

+94
-65
lines changed

5 files changed

+94
-65
lines changed

enzyme/Enzyme/AdjointGenerator.h

Lines changed: 31 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2964,17 +2964,17 @@ class AdjointGenerator
29642964
funcName = called->getName();
29652965
} else {
29662966
#if LLVM_VERSION_MAJOR >= 11
2967-
if (auto castinst = dyn_cast<ConstantExpr>(orig->getCalledOperand())) {
2967+
if (auto castinst = dyn_cast<ConstantExpr>(orig->getCalledOperand())) {
29682968
#else
2969-
if (auto castinst = dyn_cast<ConstantExpr>(orig->getCalledValue())) {
2969+
if (auto castinst = dyn_cast<ConstantExpr>(orig->getCalledValue())) {
29702970
#endif
2971-
if (castinst->isCast())
2972-
if (auto fn = dyn_cast<Function>(castinst->getOperand(0))) {
2973-
if (fn->hasFnAttribute("enzyme_math"))
2974-
funcName = fn->getFnAttribute("enzyme_math").getValueAsString();
2975-
else
2976-
funcName = fn->getName();
2977-
}
2971+
if (castinst->isCast())
2972+
if (auto fn = dyn_cast<Function>(castinst->getOperand(0))) {
2973+
if (fn->hasFnAttribute("enzyme_math"))
2974+
funcName = fn->getFnAttribute("enzyme_math").getValueAsString();
2975+
else
2976+
funcName = fn->getName();
2977+
}
29782978
}
29792979
}
29802980

@@ -2993,9 +2993,9 @@ class AdjointGenerator
29932993
IRBuilder<> Builder2(call.getParent());
29942994
if (Mode == DerivativeMode::ReverseModeGradient ||
29952995
Mode == DerivativeMode::ReverseModeCombined)
2996-
getReverseBuilder(Builder2);
2996+
getReverseBuilder(Builder2);
29972997

2998-
Value* invertedReturn = nullptr;
2998+
Value *invertedReturn = nullptr;
29992999
bool hasNonReturnUse = false;
30003000
if (gutils->invertedPointers.count(orig)) {
30013001
//! We only need the shadow pointer for non-forward Mode if it is used
@@ -3012,27 +3012,33 @@ class AdjointGenerator
30123012
invertedReturn = cast<PHINode>(gutils->invertedPointers[orig]);
30133013
}
30143014

3015-
Value* normalReturn = subretused ? newF : nullptr;
3015+
Value *normalReturn = subretused ? newF : nullptr;
30163016

3017-
Value* tape = nullptr;
3017+
Value *tape = nullptr;
30183018

30193019
if (Mode == DerivativeMode::ReverseModePrimal ||
30203020
Mode == DerivativeMode::ReverseModeCombined) {
3021-
found->second.first(BuilderZ, orig, *gutils, normalReturn, invertedReturn, tape);
3021+
found->second.first(BuilderZ, orig, *gutils, normalReturn,
3022+
invertedReturn, tape);
30223023
if (tape)
3023-
gutils->cacheForReverse(BuilderZ, tape, getIndex(orig, CacheType::Tape));
3024+
gutils->cacheForReverse(BuilderZ, tape,
3025+
getIndex(orig, CacheType::Tape));
30243026
}
30253027

30263028
if (Mode == DerivativeMode::ReverseModeGradient ||
30273029
Mode == DerivativeMode::ReverseModeCombined) {
30283030
if (Mode == DerivativeMode::ReverseModeGradient &&
3029-
augmentedReturn->tapeIndices.find(std::make_pair(orig, CacheType::Tape)) != augmentedReturn->tapeIndices.end()) {
3031+
augmentedReturn->tapeIndices.find(std::make_pair(
3032+
orig, CacheType::Tape)) != augmentedReturn->tapeIndices.end()) {
30303033
tape = Builder2.CreatePHI(Type::getInt32Ty(orig->getContext()), 0);
3031-
tape = gutils->cacheForReverse(Builder2, (Value*)0x01, getIndex(orig, CacheType::Tape), /*ignoreType*/true);
3034+
tape = gutils->cacheForReverse(Builder2, (Value *)0x01,
3035+
getIndex(orig, CacheType::Tape),
3036+
/*ignoreType*/ true);
30323037
}
30333038
if (tape)
30343039
tape = gutils->lookupM(tape, Builder2);
3035-
found->second.second(Builder2, orig, *(DiffeGradientUtils*)gutils, tape);
3040+
found->second.second(Builder2, orig, *(DiffeGradientUtils *)gutils,
3041+
tape);
30363042
}
30373043

30383044
if (gutils->invertedPointers.count(orig)) {
@@ -3055,10 +3061,11 @@ class AdjointGenerator
30553061
assert(invertedReturn->getType() == orig->getType());
30563062
placeholder->replaceAllUsesWith(invertedReturn);
30573063
gutils->erase(placeholder);
3058-
} else invertedReturn = placeholder;
3064+
} else
3065+
invertedReturn = placeholder;
30593066

3060-
invertedReturn = gutils->cacheForReverse(BuilderZ, invertedReturn,
3061-
getIndex(orig, CacheType::Shadow));
3067+
invertedReturn = gutils->cacheForReverse(
3068+
BuilderZ, invertedReturn, getIndex(orig, CacheType::Shadow));
30623069

30633070
gutils->invertedPointers[orig] = invertedReturn;
30643071
}
@@ -3071,9 +3078,10 @@ class AdjointGenerator
30713078
BuilderZ.SetInsertPoint(newF->getNextNode());
30723079
gutils->erase(newF);
30733080
}
3074-
normalReturn = gutils->cacheForReverse(BuilderZ, normalReturn, getIndex(orig, CacheType::Self));
3081+
normalReturn = gutils->cacheForReverse(BuilderZ, normalReturn,
3082+
getIndex(orig, CacheType::Self));
30753083
} else {
3076-
eraseIfUnused(*orig, /*erase*/true, /*check*/false);
3084+
eraseIfUnused(*orig, /*erase*/ true, /*check*/ false);
30773085
}
30783086
return;
30793087
}

enzyme/Enzyme/CApi.cpp

Lines changed: 37 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,13 @@
2222
//
2323
//===----------------------------------------------------------------------===//
2424
#include "CApi.h"
25-
#include "SCEV/TargetLibraryInfo.h"
26-
#include "GradientUtils.h"
25+
#include "SCEV/ScalarEvolution.h"
26+
#include "SCEV/ScalarEvolutionExpander.h"
27+
2728
#include "EnzymeLogic.h"
29+
#include "GradientUtils.h"
2830
#include "LibraryFuncs.h"
31+
#include "SCEV/TargetLibraryInfo.h"
2932

3033
#include "llvm/ADT/Triple.h"
3134
#include "llvm/Analysis/CallGraph.h"
@@ -239,7 +242,7 @@ void EnzymeRegisterAllocationHandler(char *Name, CustomShadowAlloc AHandle,
239242
}
240243

241244
void EnzymeRegisterFunctionHandler(char *Name, CustomShadowAlloc AHandle,
242-
CustomShadowFree FHandle) {
245+
CustomShadowFree FHandle) {
243246
shadowHandlers[std::string(Name)] =
244247
[=](IRBuilder<> &B, CallInst *CI,
245248
ArrayRef<Value *> Args) -> llvm::Value * {
@@ -256,60 +259,70 @@ void EnzymeRegisterFunctionHandler(char *Name, CustomShadowAlloc AHandle,
256259
}
257260

258261
void EnzymeRegisterCallHandler(char *Name, CustomFunctionForward FwdHandle,
259-
CustomFunctionReverse RevHandle) {
262+
CustomFunctionReverse RevHandle) {
260263
auto &pair = customCallHandlers[std::string(Name)];
261-
pair.first = [=](IRBuilder<> &B, CallInst *CI, GradientUtils& gutils,
262-
Value*& normalReturn, Value*& shadowReturn, Value*& tape) {
263-
LLVMValueRef normalR = wrap(normalReturn);
264-
LLVMValueRef shadowR = wrap(shadowReturn);
265-
LLVMValueRef tapeR = wrap(tape);
266-
FwdHandle(wrap(&B), wrap(CI), &gutils, &normalR, &shadowR, &tapeR);
267-
normalReturn = unwrap(normalR);
268-
shadowReturn = unwrap(shadowR);
269-
tape = unwrap(tapeR);
264+
pair.first = [=](IRBuilder<> &B, CallInst *CI, GradientUtils &gutils,
265+
Value *&normalReturn, Value *&shadowReturn, Value *&tape) {
266+
LLVMValueRef normalR = wrap(normalReturn);
267+
LLVMValueRef shadowR = wrap(shadowReturn);
268+
LLVMValueRef tapeR = wrap(tape);
269+
FwdHandle(wrap(&B), wrap(CI), &gutils, &normalR, &shadowR, &tapeR);
270+
normalReturn = unwrap(normalR);
271+
shadowReturn = unwrap(shadowR);
272+
tape = unwrap(tapeR);
270273
};
271-
pair.second = [=](IRBuilder<> &B, CallInst *CI, DiffeGradientUtils& gutils, Value* tape) {
274+
pair.second = [=](IRBuilder<> &B, CallInst *CI, DiffeGradientUtils &gutils,
275+
Value *tape) {
272276
RevHandle(wrap(&B), wrap(CI), &gutils, wrap(tape));
273277
};
274278
}
275279

276-
LLVMValueRef EnzymeGradientUtilsNewFromOriginal(GradientUtils* gutils, LLVMValueRef val) {
280+
LLVMValueRef EnzymeGradientUtilsNewFromOriginal(GradientUtils *gutils,
281+
LLVMValueRef val) {
277282
return wrap(gutils->getNewFromOriginal(unwrap(val)));
278283
}
279284

280-
LLVMValueRef EnzymeGradientUtilsLookup(GradientUtils* gutils, LLVMValueRef val, LLVMBuilderRef B) {
285+
LLVMValueRef EnzymeGradientUtilsLookup(GradientUtils *gutils, LLVMValueRef val,
286+
LLVMBuilderRef B) {
281287
return wrap(gutils->lookupM(unwrap(val), *unwrap(B)));
282288
}
283289

284-
LLVMValueRef EnzymeGradientUtilsInvertPointer(GradientUtils* gutils, LLVMValueRef val, LLVMBuilderRef B) {
290+
LLVMValueRef EnzymeGradientUtilsInvertPointer(GradientUtils *gutils,
291+
LLVMValueRef val,
292+
LLVMBuilderRef B) {
285293
return wrap(gutils->invertPointerM(unwrap(val), *unwrap(B)));
286294
}
287295

288-
LLVMValueRef EnzymeGradientUtilsDiffe(DiffeGradientUtils* gutils, LLVMValueRef val, LLVMBuilderRef B) {
296+
LLVMValueRef EnzymeGradientUtilsDiffe(DiffeGradientUtils *gutils,
297+
LLVMValueRef val, LLVMBuilderRef B) {
289298
return wrap(gutils->diffe(unwrap(val), *unwrap(B)));
290299
}
291300

292-
void EnzymeGradientUtilsAddToDiffe(DiffeGradientUtils* gutils, LLVMValueRef val, LLVMValueRef diffe, LLVMBuilderRef B, LLVMTypeRef T) {
301+
void EnzymeGradientUtilsAddToDiffe(DiffeGradientUtils *gutils, LLVMValueRef val,
302+
LLVMValueRef diffe, LLVMBuilderRef B,
303+
LLVMTypeRef T) {
293304
gutils->addToDiffe(unwrap(val), unwrap(diffe), *unwrap(B), unwrap(T));
294305
}
295306

296-
void EnzymeGradientUtilsSetDiffe(DiffeGradientUtils* gutils, LLVMValueRef val, LLVMValueRef diffe, LLVMBuilderRef B) {
307+
void EnzymeGradientUtilsSetDiffe(DiffeGradientUtils *gutils, LLVMValueRef val,
308+
LLVMValueRef diffe, LLVMBuilderRef B) {
297309
gutils->setDiffe(unwrap(val), unwrap(diffe), *unwrap(B));
298310
}
299311

300-
uint8_t EnzymeGradientUtilsIsConstantValue(GradientUtils* gutils, LLVMValueRef val) {
312+
uint8_t EnzymeGradientUtilsIsConstantValue(GradientUtils *gutils,
313+
LLVMValueRef val) {
301314
return gutils->isConstantValue(unwrap(val));
302315
}
303316

304-
uint8_t EnzymeGradientUtilsIsConstantInstruction(GradientUtils* gutils, LLVMValueRef val) {
317+
uint8_t EnzymeGradientUtilsIsConstantInstruction(GradientUtils *gutils,
318+
LLVMValueRef val) {
305319
return gutils->isConstantInstruction(cast<Instruction>(unwrap(val)));
306320
}
307321

308-
LLVMBasicBlockRef EnzymeGradientUtilsAllocationBlock(GradientUtils* gutils) {
322+
LLVMBasicBlockRef EnzymeGradientUtilsAllocationBlock(GradientUtils *gutils) {
309323
return wrap(gutils->inversionAllocs);
310324
}
311325

312-
313326
LLVMValueRef EnzymeCreatePrimalAndGradient(
314327
EnzymeLogicRef Logic, LLVMValueRef todiff, CDIFFE_TYPE retType,
315328
CDIFFE_TYPE *constant_args, size_t constant_args_size,

enzyme/Enzyme/CApi.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -165,10 +165,11 @@ class GradientUtils;
165165
class DiffeGradientUtils;
166166

167167
typedef void (*CustomFunctionForward)(LLVMBuilderRef, LLVMValueRef,
168-
GradientUtils*, LLVMValueRef*, LLVMValueRef*, LLVMValueRef*);
168+
GradientUtils *, LLVMValueRef *,
169+
LLVMValueRef *, LLVMValueRef *);
169170

170171
typedef void (*CustomFunctionReverse)(LLVMBuilderRef, LLVMValueRef,
171-
DiffeGradientUtils*, LLVMValueRef);
172+
DiffeGradientUtils *, LLVMValueRef);
172173

173174
#ifdef __cplusplus
174175
}

enzyme/Enzyme/GradientUtils.cpp

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -51,12 +51,13 @@ std::map<std::string,
5151
std::function<llvm::CallInst *(IRBuilder<> &, Value *, Function *)>>
5252
shadowErasers;
5353

54-
std::map<std::string,
55-
std::pair<
56-
std::function<void(IRBuilder<> &, CallInst *, GradientUtils&, Value*&, Value*&, Value*&)>,
57-
std::function<void(IRBuilder<> &, CallInst *, DiffeGradientUtils&, Value*)>
58-
>
59-
> customCallHandlers;
54+
std::map<
55+
std::string,
56+
std::pair<std::function<void(IRBuilder<> &, CallInst *, GradientUtils &,
57+
Value *&, Value *&, Value *&)>,
58+
std::function<void(IRBuilder<> &, CallInst *,
59+
DiffeGradientUtils &, Value *)>>>
60+
customCallHandlers;
6061

6162
extern "C" {
6263
llvm::cl::opt<bool>

enzyme/Enzyme/GradientUtils.h

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -79,12 +79,14 @@ extern std::map<std::string, std::function<llvm::Value *(
7979

8080
class GradientUtils;
8181
class DiffeGradientUtils;
82-
extern std::map<std::string,
83-
std::pair<
84-
std::function<void(llvm::IRBuilder<> &, llvm::CallInst *, GradientUtils&, llvm::Value*&, llvm::Value*&, llvm::Value*&)>,
85-
std::function<void(llvm::IRBuilder<> &, llvm::CallInst *, DiffeGradientUtils&, llvm::Value*)>
86-
>
87-
> customCallHandlers;
82+
extern std::map<
83+
std::string,
84+
std::pair<std::function<void(llvm::IRBuilder<> &, llvm::CallInst *,
85+
GradientUtils &, llvm::Value *&,
86+
llvm::Value *&, llvm::Value *&)>,
87+
std::function<void(llvm::IRBuilder<> &, llvm::CallInst *,
88+
DiffeGradientUtils &, llvm::Value *)>>>
89+
customCallHandlers;
8890

8991
extern "C" {
9092
extern llvm::cl::opt<bool> EnzymeInactiveDynamic;
@@ -627,9 +629,12 @@ class GradientUtils : public CacheUtility {
627629

628630
if (tape == nullptr) {
629631
if (orig->getCalledFunction()->getName() == "julia.gc_alloc_obj") {
630-
Type *tys[] = { PointerType::get(StructType::get(orig->getContext()), 10) };
631-
FunctionType* FT = FunctionType::get(Type::getVoidTy(orig->getContext()), tys, true);
632-
bb.CreateCall(oldFunc->getParent()->getOrInsertFunction("julia.write_barrier", FT),
632+
Type *tys[] = {
633+
PointerType::get(StructType::get(orig->getContext()), 10)};
634+
FunctionType *FT =
635+
FunctionType::get(Type::getVoidTy(orig->getContext()), tys, true);
636+
bb.CreateCall(oldFunc->getParent()->getOrInsertFunction(
637+
"julia.write_barrier", FT),
633638
anti);
634639
if (mode != DerivativeMode::ReverseModeCombined) {
635640
EmitFailure("SplitGCAllocation", orig->getDebugLoc(), orig,
@@ -723,7 +728,8 @@ class GradientUtils : public CacheUtility {
723728
}
724729
}
725730

726-
Value *cacheForReverse(IRBuilder<> &BuilderQ, Value *malloc, int idx, bool ignoreType=false);
731+
Value *cacheForReverse(IRBuilder<> &BuilderQ, Value *malloc, int idx,
732+
bool ignoreType = false);
727733

728734
const SmallVectorImpl<Value *> &getTapeValues() const {
729735
return addedTapeVals;

0 commit comments

Comments
 (0)