Skip to content

Commit d7f4967

Browse files
committed
unwrap of cached
1 parent 49bf6b5 commit d7f4967

File tree

2 files changed

+27
-1
lines changed

2 files changed

+27
-1
lines changed

enzyme/Enzyme/EnzymeLogic.cpp

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2965,6 +2965,10 @@ Function *EnzymeLogic::CreatePrimalAndGradient(
29652965
}
29662966

29672967
if (!topLevel) {
2968+
std::map<Value *, std::vector<Value *>> unwrapToOrig;
2969+
for (auto pair : gutils->unwrappedLoads)
2970+
unwrapToOrig[pair.second].push_back(const_cast<Value *>(pair.first));
2971+
std::map<Value *, Value *> newIToNextI;
29682972
for (const auto &m : mapping) {
29692973
if (m.first.second == CacheType::Self && !isa<LoadInst>(m.first.first) &&
29702974
!isa<CallInst>(m.first.first)) {
@@ -2977,7 +2981,14 @@ Function *EnzymeLogic::CreatePrimalAndGradient(
29772981
BuilderZ.SetInsertPoint(
29782982
cast<Instruction>(newi)->getParent()->getFirstNonPHI());
29792983
}
2980-
gutils->cacheForReverse(BuilderZ, newi, m.second);
2984+
Value *nexti = gutils->cacheForReverse(BuilderZ, newi, m.second);
2985+
for (auto V : unwrapToOrig[newi]) {
2986+
ValueToValueMapTy empty;
2987+
IRBuilder<> lb(cast<Instruction>(V));
2988+
V->replaceAllUsesWith(
2989+
gutils->unwrapM(nexti, lb, empty, UnwrapMode::LegalFullUnwrap));
2990+
cast<Instruction>(V)->eraseFromParent();
2991+
}
29812992
}
29822993
}
29832994

enzyme/Enzyme/GradientUtils.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,7 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
245245
newi->copyIRFlags(op);
246246
if (permitCache)
247247
unwrap_cache[BuilderM.GetInsertBlock()][idx] = toreturn;
248+
unwrappedLoads[toreturn] = val;
248249
assert(val->getType() == toreturn->getType());
249250
return toreturn;
250251
} else if (auto op = dyn_cast<ExtractValueInst>(val)) {
@@ -255,6 +256,7 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
255256
op->getName() + "_unwrap");
256257
if (permitCache)
257258
unwrap_cache[BuilderM.GetInsertBlock()][idx] = toreturn;
259+
unwrappedLoads[toreturn] = val;
258260
if (auto newi = dyn_cast<Instruction>(toreturn))
259261
newi->copyIRFlags(op);
260262
assert(val->getType() == toreturn->getType());
@@ -270,6 +272,7 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
270272
op->getName() + "_unwrap");
271273
if (permitCache)
272274
unwrap_cache[BuilderM.GetInsertBlock()][idx] = toreturn;
275+
unwrappedLoads[toreturn] = val;
273276
if (auto newi = dyn_cast<Instruction>(toreturn))
274277
newi->copyIRFlags(op);
275278
assert(val->getType() == toreturn->getType());
@@ -285,6 +288,7 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
285288
BuilderM.CreateExtractElement(op0, op1, op->getName() + "_unwrap");
286289
if (permitCache)
287290
unwrap_cache[BuilderM.GetInsertBlock()][idx] = toreturn;
291+
unwrappedLoads[toreturn] = val;
288292
if (auto newi = dyn_cast<Instruction>(toreturn))
289293
newi->copyIRFlags(op);
290294
assert(val->getType() == toreturn->getType());
@@ -303,6 +307,7 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
303307
BuilderM.CreateInsertElement(op0, op1, op2, op->getName() + "_unwrap");
304308
if (permitCache)
305309
unwrap_cache[BuilderM.GetInsertBlock()][idx] = toreturn;
310+
unwrappedLoads[toreturn] = val;
306311
if (auto newi = dyn_cast<Instruction>(toreturn))
307312
newi->copyIRFlags(op);
308313
assert(val->getType() == toreturn->getType());
@@ -323,6 +328,7 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
323328
#endif
324329
if (permitCache)
325330
unwrap_cache[BuilderM.GetInsertBlock()][idx] = toreturn;
331+
unwrappedLoads[toreturn] = val;
326332
if (auto newi = dyn_cast<Instruction>(toreturn))
327333
newi->copyIRFlags(op);
328334
assert(val->getType() == toreturn->getType());
@@ -342,6 +348,7 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
342348
assert(op0->getType() == op1->getType());
343349
auto toreturn = BuilderM.CreateBinOp(op->getOpcode(), op0, op1,
344350
op->getName() + "_unwrap");
351+
unwrappedLoads[toreturn] = val;
345352
if (auto newi = dyn_cast<Instruction>(toreturn))
346353
newi->copyIRFlags(op);
347354
if (permitCache)
@@ -361,6 +368,7 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
361368
newi->copyIRFlags(op);
362369
if (permitCache)
363370
unwrap_cache[BuilderM.GetInsertBlock()][idx] = toreturn;
371+
unwrappedLoads[toreturn] = val;
364372
assert(val->getType() == toreturn->getType());
365373
return toreturn;
366374
} else if (auto op = dyn_cast<FCmpInst>(val)) {
@@ -376,6 +384,7 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
376384
newi->copyIRFlags(op);
377385
if (permitCache)
378386
unwrap_cache[BuilderM.GetInsertBlock()][idx] = toreturn;
387+
unwrappedLoads[toreturn] = val;
379388
assert(val->getType() == toreturn->getType());
380389
return toreturn;
381390
#if LLVM_VERSION_MAJOR >= 9
@@ -390,6 +399,7 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
390399
newi->copyIRFlags(op);
391400
if (permitCache)
392401
unwrap_cache[BuilderM.GetInsertBlock()][idx] = toreturn;
402+
unwrappedLoads[toreturn] = val;
393403
assert(val->getType() == toreturn->getType());
394404
return toreturn;
395405
#endif
@@ -409,6 +419,7 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
409419
newi->copyIRFlags(op);
410420
if (permitCache)
411421
unwrap_cache[BuilderM.GetInsertBlock()][idx] = toreturn;
422+
unwrappedLoads[toreturn] = val;
412423
assert(val->getType() == toreturn->getType());
413424
return toreturn;
414425
} else if (auto inst = dyn_cast<GetElementPtrInst>(val)) {
@@ -437,6 +448,7 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
437448
newi->copyIRFlags(inst);
438449
if (permitCache)
439450
unwrap_cache[BuilderM.GetInsertBlock()][idx] = toreturn;
451+
unwrappedLoads[toreturn] = val;
440452
assert(val->getType() == toreturn->getType());
441453
return toreturn;
442454
} else if (auto load = dyn_cast<LoadInst>(val)) {
@@ -532,6 +544,7 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
532544
toreturn->setDebugLoc(getNewFromOriginal(op->getDebugLoc()));
533545
if (permitCache)
534546
unwrap_cache[BuilderM.GetInsertBlock()][idx] = toreturn;
547+
unwrappedLoads[toreturn] = val;
535548
return toreturn;
536549
} else if (auto phi = dyn_cast<PHINode>(val)) {
537550
if (phi->getNumIncomingValues() == 0) {
@@ -913,6 +926,7 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
913926
if (permitCache) {
914927
unwrap_cache[bret][idx] = toret;
915928
}
929+
unwrappedLoads[toret] = val;
916930
unwrap_cache[bret] = unwrap_cache[oldB];
917931
lookup_cache[bret] = lookup_cache[oldB];
918932
return toret;
@@ -1077,6 +1091,7 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
10771091
}
10781092
unwrap_cache[bret] = unwrap_cache[oldB];
10791093
lookup_cache[bret] = lookup_cache[oldB];
1094+
unwrappedLoads[toret] = val;
10801095
return toret;
10811096
}
10821097
goto endCheck;

0 commit comments

Comments
 (0)