Skip to content

Commit 788cc8e

Browse files
authored
EnzymeLogic cleanup (rust-lang#320)
* rename CreateDual to CreateForwardDiff * add clearFunctionAttributes
1 parent bea9d39 commit 788cc8e

File tree

6 files changed

+374
-82
lines changed

6 files changed

+374
-82
lines changed

enzyme/Enzyme/AdjointGenerator.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7581,11 +7581,11 @@ class AdjointGenerator
75817581
}
75827582
}
75837583

7584-
auto newcalled = gutils->Logic.CreatePrimalAndGradient(
7584+
auto newcalled = gutils->Logic.CreateForwardDiff(
75857585
cast<Function>(called), subretType, argsInverted, gutils->TLI,
75867586
TR.analyzer.interprocedural, /*returnValue*/ retUsed,
75877587
/*subdretptr*/ false, DerivativeMode::ForwardMode, nullptr,
7588-
nextTypeInfo, uncacheable_args, nullptr,
7588+
nextTypeInfo, uncacheable_args,
75897589
/*AtomicAdd*/ gutils->AtomicAdd);
75907590

75917591
assert(newcalled);

enzyme/Enzyme/CApi.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,29 @@ LLVMBasicBlockRef EnzymeGradientUtilsAllocationBlock(GradientUtils *gutils) {
323323
return wrap(gutils->inversionAllocs);
324324
}
325325

326+
LLVMValueRef EnzymeCreateForwardDiff(
327+
EnzymeLogicRef Logic, LLVMValueRef todiff, CDIFFE_TYPE retType,
328+
CDIFFE_TYPE *constant_args, size_t constant_args_size,
329+
EnzymeTypeAnalysisRef TA, uint8_t returnValue, uint8_t dretUsed,
330+
CDerivativeMode mode, LLVMTypeRef additionalArg, CFnTypeInfo typeInfo,
331+
uint8_t *_uncacheable_args, size_t uncacheable_args_size, uint8_t AtomicAdd,
332+
uint8_t PostOpt) {
333+
std::vector<DIFFE_TYPE> nconstant_args((DIFFE_TYPE *)constant_args,
334+
(DIFFE_TYPE *)constant_args +
335+
constant_args_size);
336+
std::map<llvm::Argument *, bool> uncacheable_args;
337+
size_t argnum = 0;
338+
for (auto &arg : cast<Function>(unwrap(todiff))->args()) {
339+
assert(argnum < uncacheable_args_size);
340+
uncacheable_args[&arg] = _uncacheable_args[argnum];
341+
argnum++;
342+
}
343+
return wrap(eunwrap(Logic).CreateForwardDiff(
344+
cast<Function>(unwrap(todiff)), (DIFFE_TYPE)retType, nconstant_args,
345+
eunwrap(TA).TLI, eunwrap(TA), returnValue, dretUsed, (DerivativeMode)mode,
346+
unwrap(additionalArg), eunwrap(typeInfo, cast<Function>(unwrap(todiff))),
347+
uncacheable_args, AtomicAdd, PostOpt));
348+
}
326349
LLVMValueRef EnzymeCreatePrimalAndGradient(
327350
EnzymeLogicRef Logic, LLVMValueRef todiff, CDIFFE_TYPE retType,
328351
CDIFFE_TYPE *constant_args, size_t constant_args_size,

enzyme/Enzyme/CApi.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,14 @@ typedef enum {
116116
DEM_ReverseModeCombined = 3,
117117
} CDerivativeMode;
118118

119+
LLVMValueRef EnzymeCreateForwardDiff(
120+
EnzymeLogicRef, LLVMValueRef todiff, CDIFFE_TYPE retType,
121+
CDIFFE_TYPE *constant_args, size_t constant_args_size,
122+
EnzymeTypeAnalysisRef TA, uint8_t returnValue, uint8_t dretUsed,
123+
CDerivativeMode mode, LLVMTypeRef additionalArg,
124+
struct CFnTypeInfo typeInfo, uint8_t *_uncacheable_args,
125+
size_t uncacheable_args_size, uint8_t AtomicAdd, uint8_t PostOpt);
126+
119127
LLVMValueRef EnzymeCreatePrimalAndGradient(
120128
EnzymeLogicRef, LLVMValueRef todiff, CDIFFE_TYPE retType,
121129
CDIFFE_TYPE *constant_args, size_t constant_args_size,

enzyme/Enzyme/Enzyme.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -482,6 +482,11 @@ class Enzyme : public ModulePass {
482482
Type *tapeType = nullptr;
483483
switch (mode) {
484484
case DerivativeMode::ForwardMode:
485+
newFunc = Logic.CreateForwardDiff(
486+
cast<Function>(fn), retType, constants, TLI, TA,
487+
/*should return*/ false, /*dretPtr*/ false, mode,
488+
/*addedType*/ nullptr, type_args, volatile_args, AtomicAdd, PostOpt);
489+
break;
485490
case DerivativeMode::ReverseModeCombined:
486491
newFunc = Logic.CreatePrimalAndGradient(
487492
cast<Function>(fn), retType, constants, TLI, TA,

0 commit comments

Comments
 (0)