Skip to content

Commit e0a2944

Browse files
authored
Speed up invert original map (rust-lang#250)
1 parent d6a5b1a commit e0a2944

File tree

4 files changed

+151
-72
lines changed

4 files changed

+151
-72
lines changed

enzyme/Enzyme/AdjointGenerator.h

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5178,11 +5178,12 @@ class AdjointGenerator
51785178
: BuilderZ.CreateExtractValue(
51795179
augmentcall, {(unsigned)returnIdx.getValue()});
51805180
gutils->originalToNewFn[orig] = dcall;
5181+
gutils->newToOriginalFn.erase(newCall);
5182+
gutils->newToOriginalFn[dcall] = orig;
51815183
assert(dcall->getType() == orig->getType());
51825184
assert(dcall);
51835185

51845186
if (!gutils->isConstantValue(orig)) {
5185-
gutils->originalToNewFn[orig] = dcall;
51865187
if (!orig->getType()->isFPOrFPVectorTy() &&
51875188
TR.query(orig).Inner0().isPossiblePointer()) {
51885189
} else if (Mode != DerivativeMode::ReverseModePrimal) {
@@ -5211,6 +5212,7 @@ class AdjointGenerator
52115212
BuilderZ.SetInsertPoint(BuilderZ.GetInsertPoint()->getNextNode());
52125213
eraseIfUnused(*orig, /*erase*/ true, /*check*/ false);
52135214
gutils->originalToNewFn[orig] = augmentcall;
5215+
gutils->newToOriginalFn[augmentcall] = orig;
52145216
}
52155217

52165218
} else {
@@ -5566,6 +5568,8 @@ class AdjointGenerator
55665568
}
55675569

55685570
gutils->originalToNewFn[orig] = retval ? retval : diffes;
5571+
gutils->newToOriginalFn.erase(newCall);
5572+
gutils->newToOriginalFn[retval ? retval : diffes] = orig;
55695573

55705574
// llvm::errs() << "newFunc postrep: " << *gutils->newFunc << "\n";
55715575

@@ -5585,6 +5589,8 @@ class AdjointGenerator
55855589

55865590
if (!gutils->isConstantValue(orig)) {
55875591
gutils->originalToNewFn[orig] = dcall;
5592+
gutils->newToOriginalFn.erase(newCall);
5593+
gutils->newToOriginalFn[dcall] = orig;
55885594
if (!orig->getType()->isFPOrFPVectorTy() &&
55895595
TR.query(orig).Inner0().isPossiblePointer()) {
55905596
} else {
@@ -5603,6 +5609,8 @@ class AdjointGenerator
56035609
eraseIfUnused(*orig, /*erase*/ true, /*check*/ false);
56045610
if (augmentcall) {
56055611
gutils->originalToNewFn[orig] = augmentcall;
5612+
gutils->newToOriginalFn.erase(newCall);
5613+
gutils->newToOriginalFn[augmentcall] = orig;
56065614
}
56075615
}
56085616
}

enzyme/Enzyme/EnzymeLogic.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1634,6 +1634,8 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
16341634

16351635
auto newri = ib.CreateRet(rt);
16361636
gutils->originalToNewFn[orig_ri] = newri;
1637+
gutils->newToOriginalFn.erase(ri);
1638+
gutils->newToOriginalFn[newri] = orig_ri;
16371639
gutils->erase(ri);
16381640
}
16391641
}
@@ -2049,7 +2051,7 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
20492051
ib.CreateRetVoid();
20502052
else
20512053
ib.CreateRet(ib.CreateLoad(ret));
2052-
gutils->erase(cast<Instruction>(VMap[ri]));
2054+
cast<Instruction>(VMap[ri])->eraseFromParent();
20532055
}
20542056
}
20552057

enzyme/Enzyme/GradientUtils.cpp

Lines changed: 78 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -161,15 +161,31 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
161161

162162
std::pair<Value *, BasicBlock *> idx = std::make_pair(val, scope);
163163
// assert(!val->getName().startswith("$tapeload"));
164-
if (permitCache && unwrap_cache[BuilderM.GetInsertBlock()].find(idx) !=
165-
unwrap_cache[BuilderM.GetInsertBlock()].end()) {
166-
auto cachedValue = unwrap_cache[BuilderM.GetInsertBlock()][idx];
167-
if (cachedValue->getType() != val->getType()) {
168-
llvm::errs() << "val: " << *val << "\n";
169-
llvm::errs() << "unwrap_cache[cidx]: " << *cachedValue << "\n";
164+
if (permitCache) {
165+
auto found0 = unwrap_cache.find(BuilderM.GetInsertBlock());
166+
if (found0 != unwrap_cache.end()) {
167+
auto found1 = found0->second.find(idx.first);
168+
if (found1 != found0->second.end()) {
169+
auto found2 = found1->second.find(idx.second);
170+
if (found2 != found1->second.end()) {
171+
172+
auto cachedValue = found2->second;
173+
if (cachedValue == nullptr) {
174+
found1->second.erase(idx.second);
175+
if (found1->second.size() == 0) {
176+
found0->second.erase(idx.first);
177+
}
178+
} else {
179+
if (cachedValue->getType() != val->getType()) {
180+
llvm::errs() << "val: " << *val << "\n";
181+
llvm::errs() << "unwrap_cache[cidx]: " << *cachedValue << "\n";
182+
}
183+
assert(cachedValue->getType() == val->getType());
184+
return cachedValue;
185+
}
186+
}
187+
}
170188
}
171-
assert(cachedValue->getType() == val->getType());
172-
return cachedValue;
173189
}
174190

175191
if (this->mode == DerivativeMode::ReverseModeGradient)
@@ -192,7 +208,8 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
192208
val->getType(), 0, val->getName() + "_krcLFUreplacement");
193209
assert(permitCache);
194210
unwrappedLoads[placeholder] = inst;
195-
return unwrap_cache[BuilderM.GetInsertBlock()][idx] = placeholder;
211+
return unwrap_cache[BuilderM.GetInsertBlock()][idx.first]
212+
[idx.second] = placeholder;
196213
}
197214
}
198215
} else if (mode == UnwrapMode::AttemptFullUnwrapWithLookup) {
@@ -222,7 +239,8 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
222239
val->getType(), 0, val->getName() + "_krcAFUWLreplacement");
223240
assert(permitCache);
224241
unwrappedLoads[placeholder] = inst;
225-
return unwrap_cache[BuilderM.GetInsertBlock()][idx] = placeholder;
242+
return unwrap_cache[BuilderM.GetInsertBlock()][idx.first]
243+
[idx.second] = placeholder;
226244
}
227245
}
228246
}
@@ -327,7 +345,7 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
327345
unwrappedLoads[newi] = val;
328346
}
329347
if (permitCache)
330-
unwrap_cache[BuilderM.GetInsertBlock()][idx] = toreturn;
348+
unwrap_cache[BuilderM.GetInsertBlock()][idx.first][idx.second] = toreturn;
331349
assert(val->getType() == toreturn->getType());
332350
return toreturn;
333351
#endif
@@ -344,7 +362,7 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
344362
newi->setDebugLoc(nullptr);
345363
}
346364
if (permitCache)
347-
unwrap_cache[BuilderM.GetInsertBlock()][idx] = toreturn;
365+
unwrap_cache[BuilderM.GetInsertBlock()][idx.first][idx.second] = toreturn;
348366
assert(val->getType() == toreturn->getType());
349367
return toreturn;
350368
} else if (auto op = dyn_cast<ExtractValueInst>(val)) {
@@ -354,7 +372,7 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
354372
auto toreturn = BuilderM.CreateExtractValue(op0, op->getIndices(),
355373
op->getName() + "_unwrap");
356374
if (permitCache)
357-
unwrap_cache[BuilderM.GetInsertBlock()][idx] = toreturn;
375+
unwrap_cache[BuilderM.GetInsertBlock()][idx.first][idx.second] = toreturn;
358376
if (auto newi = dyn_cast<Instruction>(toreturn)) {
359377
newi->copyIRFlags(op);
360378
unwrappedLoads[newi] = val;
@@ -373,7 +391,7 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
373391
auto toreturn = BuilderM.CreateInsertValue(op0, op1, op->getIndices(),
374392
op->getName() + "_unwrap");
375393
if (permitCache)
376-
unwrap_cache[BuilderM.GetInsertBlock()][idx] = toreturn;
394+
unwrap_cache[BuilderM.GetInsertBlock()][idx.first][idx.second] = toreturn;
377395
if (auto newi = dyn_cast<Instruction>(toreturn)) {
378396
newi->copyIRFlags(op);
379397
unwrappedLoads[newi] = val;
@@ -392,7 +410,7 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
392410
auto toreturn =
393411
BuilderM.CreateExtractElement(op0, op1, op->getName() + "_unwrap");
394412
if (permitCache)
395-
unwrap_cache[BuilderM.GetInsertBlock()][idx] = toreturn;
413+
unwrap_cache[BuilderM.GetInsertBlock()][idx.first][idx.second] = toreturn;
396414
if (auto newi = dyn_cast<Instruction>(toreturn)) {
397415
newi->copyIRFlags(op);
398416
unwrappedLoads[newi] = val;
@@ -414,7 +432,7 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
414432
auto toreturn =
415433
BuilderM.CreateInsertElement(op0, op1, op2, op->getName() + "_unwrap");
416434
if (permitCache)
417-
unwrap_cache[BuilderM.GetInsertBlock()][idx] = toreturn;
435+
unwrap_cache[BuilderM.GetInsertBlock()][idx.first][idx.second] = toreturn;
418436
if (auto newi = dyn_cast<Instruction>(toreturn)) {
419437
newi->copyIRFlags(op);
420438
unwrappedLoads[newi] = val;
@@ -438,7 +456,7 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
438456
op->getName() + "'_unwrap");
439457
#endif
440458
if (permitCache)
441-
unwrap_cache[BuilderM.GetInsertBlock()][idx] = toreturn;
459+
unwrap_cache[BuilderM.GetInsertBlock()][idx.first][idx.second] = toreturn;
442460
if (auto newi = dyn_cast<Instruction>(toreturn)) {
443461
newi->copyIRFlags(op);
444462
unwrappedLoads[newi] = val;
@@ -469,7 +487,7 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
469487
newi->setDebugLoc(nullptr);
470488
}
471489
if (permitCache)
472-
unwrap_cache[BuilderM.GetInsertBlock()][idx] = toreturn;
490+
unwrap_cache[BuilderM.GetInsertBlock()][idx.first][idx.second] = toreturn;
473491
assert(val->getType() == toreturn->getType());
474492
return toreturn;
475493
} else if (auto op = dyn_cast<ICmpInst>(val)) {
@@ -488,7 +506,7 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
488506
newi->setDebugLoc(nullptr);
489507
}
490508
if (permitCache)
491-
unwrap_cache[BuilderM.GetInsertBlock()][idx] = toreturn;
509+
unwrap_cache[BuilderM.GetInsertBlock()][idx.first][idx.second] = toreturn;
492510
assert(val->getType() == toreturn->getType());
493511
return toreturn;
494512
} else if (auto op = dyn_cast<FCmpInst>(val)) {
@@ -507,7 +525,7 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
507525
newi->setDebugLoc(nullptr);
508526
}
509527
if (permitCache)
510-
unwrap_cache[BuilderM.GetInsertBlock()][idx] = toreturn;
528+
unwrap_cache[BuilderM.GetInsertBlock()][idx.first][idx.second] = toreturn;
511529
assert(val->getType() == toreturn->getType());
512530
return toreturn;
513531
#if LLVM_VERSION_MAJOR >= 9
@@ -526,7 +544,7 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
526544
newi->setDebugLoc(nullptr);
527545
}
528546
if (permitCache)
529-
unwrap_cache[BuilderM.GetInsertBlock()][idx] = toreturn;
547+
unwrap_cache[BuilderM.GetInsertBlock()][idx.first][idx.second] = toreturn;
530548
assert(val->getType() == toreturn->getType());
531549
return toreturn;
532550
#endif
@@ -549,7 +567,7 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
549567
newi->setDebugLoc(nullptr);
550568
}
551569
if (permitCache)
552-
unwrap_cache[BuilderM.GetInsertBlock()][idx] = toreturn;
570+
unwrap_cache[BuilderM.GetInsertBlock()][idx.first][idx.second] = toreturn;
553571
assert(val->getType() == toreturn->getType());
554572
return toreturn;
555573
} else if (auto inst = dyn_cast<GetElementPtrInst>(val)) {
@@ -576,7 +594,7 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
576594
newi->setDebugLoc(nullptr);
577595
}
578596
if (permitCache)
579-
unwrap_cache[BuilderM.GetInsertBlock()][idx] = toreturn;
597+
unwrap_cache[BuilderM.GetInsertBlock()][idx.first][idx.second] = toreturn;
580598
assert(val->getType() == toreturn->getType());
581599
return toreturn;
582600
} else if (auto load = dyn_cast<LoadInst>(val)) {
@@ -642,7 +660,7 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
642660
load->getMetadata(LLVMContext::MD_invariant_group));
643661
// TODO adding to cache only legal if no alias of any future writes
644662
if (permitCache)
645-
unwrap_cache[BuilderM.GetInsertBlock()][idx] = toreturn;
663+
unwrap_cache[BuilderM.GetInsertBlock()][idx.first][idx.second] = toreturn;
646664
assert(val->getType() == toreturn->getType());
647665
return toreturn;
648666
} else if (auto op = dyn_cast<CallInst>(val)) {
@@ -681,7 +699,7 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
681699
else
682700
toreturn->setDebugLoc(nullptr);
683701
if (permitCache)
684-
unwrap_cache[BuilderM.GetInsertBlock()][idx] = toreturn;
702+
unwrap_cache[BuilderM.GetInsertBlock()][idx.first][idx.second] = toreturn;
685703
unwrappedLoads[toreturn] = val;
686704
return toreturn;
687705
} else if (auto phi = dyn_cast<PHINode>(val)) {
@@ -741,7 +759,8 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
741759
dli->getMetadata(LLVMContext::MD_invariant_group));
742760
// TODO adding to cache only legal if no alias of any future writes
743761
if (permitCache)
744-
unwrap_cache[BuilderM.GetInsertBlock()][idx] = toreturn;
762+
unwrap_cache[BuilderM.GetInsertBlock()][idx.first][idx.second] =
763+
toreturn;
745764
assert(val->getType() == toreturn->getType());
746765
return toreturn;
747766
}
@@ -790,7 +809,8 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
790809
if (!lc.dynamic) {
791810
Value *lim = getOp(lc.trueLimit);
792811
if (lim) {
793-
unwrap_cache[BuilderM.GetInsertBlock()][idx] = lim;
812+
unwrap_cache[BuilderM.GetInsertBlock()][idx.first][idx.second] =
813+
lim;
794814
return lim;
795815
}
796816
} else if (mode == UnwrapMode::AttemptFullUnwrapWithLookup &&
@@ -804,7 +824,7 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
804824
/*forwardPass*/ false, BuilderM, lctx,
805825
getDynamicLoopLimit(LI.getLoopFor(lc.header)),
806826
/*isi1*/ false);
807-
unwrap_cache[BuilderM.GetInsertBlock()][idx] = lim;
827+
unwrap_cache[BuilderM.GetInsertBlock()][idx.first][idx.second] = lim;
808828
return lim;
809829
}
810830
}
@@ -1076,7 +1096,7 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
10761096
toret->addIncoming(vals[i], endingBlocks[i]);
10771097
assert(val->getType() == toret->getType());
10781098
if (permitCache) {
1079-
unwrap_cache[bret][idx] = toret;
1099+
unwrap_cache[bret][idx.first][idx.second] = toret;
10801100
}
10811101
unwrappedLoads[toret] = val;
10821102
unwrap_cache[bret] = unwrap_cache[oldB];
@@ -1249,7 +1269,8 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
12491269
Value *toret = BuilderM.CreateSelect(cond, vals[0], vals[1],
12501270
phi->getName() + "_unwrap");
12511271
if (permitCache) {
1252-
unwrap_cache[BuilderM.GetInsertBlock()][idx] = toret;
1272+
unwrap_cache[BuilderM.GetInsertBlock()][idx.first][idx.second] =
1273+
toret;
12531274
}
12541275
if (auto instRet = dyn_cast<Instruction>(toret))
12551276
unwrappedLoads[instRet] = val;
@@ -1276,7 +1297,7 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
12761297
toret->addIncoming(vals[i], endingBlocks[i]);
12771298
assert(val->getType() == toret->getType());
12781299
if (permitCache) {
1279-
unwrap_cache[bret][idx] = toret;
1300+
unwrap_cache[bret][idx.first][idx.second] = toret;
12801301
}
12811302
unwrap_cache[bret] = unwrap_cache[oldB];
12821303
lookup_cache[bret] = lookup_cache[oldB];
@@ -1510,9 +1531,15 @@ Value *GradientUtils::cacheForReverse(IRBuilder<> &BuilderQ, Value *malloc,
15101531
assert(malloc->getType() == ret->getType());
15111532
}
15121533

1513-
if (replace)
1514-
if (auto orig = isOriginal(malloc))
1534+
if (replace) {
1535+
auto found = newToOriginalFn.find(malloc);
1536+
if (found != newToOriginalFn.end()) {
1537+
Value *orig = found->second;
15151538
originalToNewFn[orig] = ret;
1539+
newToOriginalFn.erase(malloc);
1540+
newToOriginalFn[ret] = orig;
1541+
}
1542+
}
15161543

15171544
if (auto found = findInMap(scopeMap, malloc)) {
15181545
// There already exists an allocation for this, we should fully remove
@@ -3129,11 +3156,15 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM,
31293156
if (lookup_cache[BuilderM.GetInsertBlock()].find(val) !=
31303157
lookup_cache[BuilderM.GetInsertBlock()].end()) {
31313158
auto result = lookup_cache[BuilderM.GetInsertBlock()][val];
3132-
assert(result);
3133-
assert(result->getType());
3134-
result = BuilderM.CreateBitCast(result, val->getType());
3135-
assert(result->getType() == inst->getType());
3136-
return result;
3159+
if (result == nullptr) {
3160+
lookup_cache[BuilderM.GetInsertBlock()].erase(val);
3161+
} else {
3162+
assert(result);
3163+
assert(result->getType());
3164+
result = BuilderM.CreateBitCast(result, val->getType());
3165+
assert(result->getType() == inst->getType());
3166+
return result;
3167+
}
31373168
}
31383169

31393170
ValueToValueMapTy available;
@@ -3244,11 +3275,15 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM,
32443275
if (lookup_cache[BuilderM.GetInsertBlock()].find(val) !=
32453276
lookup_cache[BuilderM.GetInsertBlock()].end()) {
32463277
auto result = lookup_cache[BuilderM.GetInsertBlock()][val];
3247-
assert(result);
3248-
assert(result->getType());
3249-
result = BuilderM.CreateBitCast(result, val->getType());
3250-
assert(result->getType() == inst->getType());
3251-
return result;
3278+
if (result == nullptr) {
3279+
lookup_cache[BuilderM.GetInsertBlock()].erase(val);
3280+
} else {
3281+
assert(result);
3282+
assert(result->getType());
3283+
result = BuilderM.CreateBitCast(result, val->getType());
3284+
assert(result->getType() == inst->getType());
3285+
return result;
3286+
}
32523287
}
32533288

32543289
// TODO consider call as part of

0 commit comments

Comments
 (0)