Skip to content

Commit b6f9c9b

Browse files
committed
More powerful C API
1 parent ac6bb8d commit b6f9c9b

File tree

3 files changed

+18
-12
lines changed

3 files changed

+18
-12
lines changed

enzyme/Enzyme/AdjointGenerator.h

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3017,17 +3017,22 @@ class AdjointGenerator
30173017
Value* tape = nullptr;
30183018

30193019
if (Mode == DerivativeMode::ReverseModePrimal ||
3020-
Mode == DerivativeMode::ReverseModeCombined)
3020+
Mode == DerivativeMode::ReverseModeCombined) {
30213021
found->second.first(BuilderZ, orig, *gutils, normalReturn, invertedReturn, tape);
3022+
if (tape)
3023+
gutils->cacheForReverse(BuilderZ, tape, getIndex(orig, CacheType::Tape));
3024+
}
30223025

30233026
if (Mode == DerivativeMode::ReverseModeGradient ||
3024-
Mode == DerivativeMode::ReverseModeCombined)
3027+
Mode == DerivativeMode::ReverseModeCombined) {
3028+
if (Mode == DerivativeMode::ReverseModeGradient &&
3029+
augmentedReturn->tapeIndices.find(std::make_pair(orig, CacheType::Tape)) != augmentedReturn->tapeIndices.end()) {
3030+
tape = Builder2.CreatePHI(Type::getInt32Ty(orig->getContext()), 0);
3031+
tape = gutils->cacheForReverse(Builder2, (Value*)0x01, getIndex(orig, CacheType::Tape), /*ignoreType*/true);
3032+
}
3033+
if (tape)
3034+
tape = gutils->lookupM(tape, Builder2);
30253035
found->second.second(Builder2, orig, *(DiffeGradientUtils*)gutils, tape);
3026-
3027-
assert(!tape && "Tape mechanism not implemented for custom yet");
3028-
3029-
if (Mode == DerivativeMode::ReverseModePrimal && tape) {
3030-
gutils->cacheForReverse(BuilderZ, tape, getIndex(orig, CacheType::Tape));
30313036
}
30323037

30333038
if (gutils->invertedPointers.count(orig)) {

enzyme/Enzyme/GradientUtils.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1157,7 +1157,7 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
11571157
}
11581158

11591159
Value *GradientUtils::cacheForReverse(IRBuilder<> &BuilderQ, Value *malloc,
1160-
int idx) {
1160+
int idx, bool ignoreType) {
11611161
assert(malloc);
11621162
assert(BuilderQ.GetInsertBlock()->getParent() == newFunc);
11631163
if (mode == DerivativeMode::ReverseModeCombined) {
@@ -1268,7 +1268,7 @@ Value *GradientUtils::cacheForReverse(IRBuilder<> &BuilderQ, Value *malloc,
12681268
innerType != ret->getType()) {
12691269
assert(innerType == Type::getInt8Ty(malloc->getContext()));
12701270
} else {
1271-
if (innerType != malloc->getType()) {
1271+
if (!ignoreType && innerType != malloc->getType()) {
12721272
llvm::errs() << *oldFunc << "\n";
12731273
llvm::errs() << *newFunc << "\n";
12741274
llvm::errs() << "innerType: " << *innerType << "\n";
@@ -1289,7 +1289,7 @@ Value *GradientUtils::cacheForReverse(IRBuilder<> &BuilderQ, Value *malloc,
12891289

12901290
auto v = lookupValueFromCache(/*forwardPass*/ true, BuilderQ, lctx, cache,
12911291
isi1);
1292-
if (malloc) {
1292+
if (!ignoreType && malloc) {
12931293
assert(v->getType() == malloc->getType());
12941294
}
12951295
insert_or_assign(scopeMap, v, std::make_pair(cache, ctx));
@@ -1462,7 +1462,8 @@ Value *GradientUtils::cacheForReverse(IRBuilder<> &BuilderQ, Value *malloc,
14621462
}
14631463
}
14641464
// llvm::errs() << "replacing " << *malloc << " with " << *ret << "\n";
1465-
cast<Instruction>(malloc)->replaceAllUsesWith(ret);
1465+
if (!ignoreType)
1466+
cast<Instruction>(malloc)->replaceAllUsesWith(ret);
14661467
std::string n = malloc->getName().str();
14671468
erase(cast<Instruction>(malloc));
14681469
ret->setName(n);

enzyme/Enzyme/GradientUtils.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -723,7 +723,7 @@ class GradientUtils : public CacheUtility {
723723
}
724724
}
725725

726-
Value *cacheForReverse(IRBuilder<> &BuilderQ, Value *malloc, int idx);
726+
Value *cacheForReverse(IRBuilder<> &BuilderQ, Value *malloc, int idx, bool ignoreType=false);
727727

728728
const SmallVectorImpl<Value *> &getTapeValues() const {
729729
return addedTapeVals;

0 commit comments

Comments
 (0)