@@ -161,15 +161,31 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
161
161
162
162
std::pair<Value *, BasicBlock *> idx = std::make_pair (val, scope);
163
163
// assert(!val->getName().startswith("$tapeload"));
164
- if (permitCache && unwrap_cache[BuilderM.GetInsertBlock ()].find (idx) !=
165
- unwrap_cache[BuilderM.GetInsertBlock ()].end ()) {
166
- auto cachedValue = unwrap_cache[BuilderM.GetInsertBlock ()][idx];
167
- if (cachedValue->getType () != val->getType ()) {
168
- llvm::errs () << " val: " << *val << " \n " ;
169
- llvm::errs () << " unwrap_cache[cidx]: " << *cachedValue << " \n " ;
164
+ if (permitCache) {
165
+ auto found0 = unwrap_cache.find (BuilderM.GetInsertBlock ());
166
+ if (found0 != unwrap_cache.end ()) {
167
+ auto found1 = found0->second .find (idx.first );
168
+ if (found1 != found0->second .end ()) {
169
+ auto found2 = found1->second .find (idx.second );
170
+ if (found2 != found1->second .end ()) {
171
+
172
+ auto cachedValue = found2->second ;
173
+ if (cachedValue == nullptr ) {
174
+ found1->second .erase (idx.second );
175
+ if (found1->second .size () == 0 ) {
176
+ found0->second .erase (idx.first );
177
+ }
178
+ } else {
179
+ if (cachedValue->getType () != val->getType ()) {
180
+ llvm::errs () << " val: " << *val << " \n " ;
181
+ llvm::errs () << " unwrap_cache[cidx]: " << *cachedValue << " \n " ;
182
+ }
183
+ assert (cachedValue->getType () == val->getType ());
184
+ return cachedValue;
185
+ }
186
+ }
187
+ }
170
188
}
171
- assert (cachedValue->getType () == val->getType ());
172
- return cachedValue;
173
189
}
174
190
175
191
if (this ->mode == DerivativeMode::ReverseModeGradient)
@@ -192,7 +208,8 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
192
208
val->getType (), 0 , val->getName () + " _krcLFUreplacement" );
193
209
assert (permitCache);
194
210
unwrappedLoads[placeholder] = inst;
195
- return unwrap_cache[BuilderM.GetInsertBlock ()][idx] = placeholder;
211
+ return unwrap_cache[BuilderM.GetInsertBlock ()][idx.first ]
212
+ [idx.second ] = placeholder;
196
213
}
197
214
}
198
215
} else if (mode == UnwrapMode::AttemptFullUnwrapWithLookup) {
@@ -222,7 +239,8 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
222
239
val->getType (), 0 , val->getName () + " _krcAFUWLreplacement" );
223
240
assert (permitCache);
224
241
unwrappedLoads[placeholder] = inst;
225
- return unwrap_cache[BuilderM.GetInsertBlock ()][idx] = placeholder;
242
+ return unwrap_cache[BuilderM.GetInsertBlock ()][idx.first ]
243
+ [idx.second ] = placeholder;
226
244
}
227
245
}
228
246
}
@@ -327,7 +345,7 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
327
345
unwrappedLoads[newi] = val;
328
346
}
329
347
if (permitCache)
330
- unwrap_cache[BuilderM.GetInsertBlock ()][idx] = toreturn;
348
+ unwrap_cache[BuilderM.GetInsertBlock ()][idx. first ][idx. second ] = toreturn;
331
349
assert (val->getType () == toreturn->getType ());
332
350
return toreturn;
333
351
#endif
@@ -344,7 +362,7 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
344
362
newi->setDebugLoc (nullptr );
345
363
}
346
364
if (permitCache)
347
- unwrap_cache[BuilderM.GetInsertBlock ()][idx] = toreturn;
365
+ unwrap_cache[BuilderM.GetInsertBlock ()][idx. first ][idx. second ] = toreturn;
348
366
assert (val->getType () == toreturn->getType ());
349
367
return toreturn;
350
368
} else if (auto op = dyn_cast<ExtractValueInst>(val)) {
@@ -354,7 +372,7 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
354
372
auto toreturn = BuilderM.CreateExtractValue (op0, op->getIndices (),
355
373
op->getName () + " _unwrap" );
356
374
if (permitCache)
357
- unwrap_cache[BuilderM.GetInsertBlock ()][idx] = toreturn;
375
+ unwrap_cache[BuilderM.GetInsertBlock ()][idx. first ][idx. second ] = toreturn;
358
376
if (auto newi = dyn_cast<Instruction>(toreturn)) {
359
377
newi->copyIRFlags (op);
360
378
unwrappedLoads[newi] = val;
@@ -373,7 +391,7 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
373
391
auto toreturn = BuilderM.CreateInsertValue (op0, op1, op->getIndices (),
374
392
op->getName () + " _unwrap" );
375
393
if (permitCache)
376
- unwrap_cache[BuilderM.GetInsertBlock ()][idx] = toreturn;
394
+ unwrap_cache[BuilderM.GetInsertBlock ()][idx. first ][idx. second ] = toreturn;
377
395
if (auto newi = dyn_cast<Instruction>(toreturn)) {
378
396
newi->copyIRFlags (op);
379
397
unwrappedLoads[newi] = val;
@@ -392,7 +410,7 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
392
410
auto toreturn =
393
411
BuilderM.CreateExtractElement (op0, op1, op->getName () + " _unwrap" );
394
412
if (permitCache)
395
- unwrap_cache[BuilderM.GetInsertBlock ()][idx] = toreturn;
413
+ unwrap_cache[BuilderM.GetInsertBlock ()][idx. first ][idx. second ] = toreturn;
396
414
if (auto newi = dyn_cast<Instruction>(toreturn)) {
397
415
newi->copyIRFlags (op);
398
416
unwrappedLoads[newi] = val;
@@ -414,7 +432,7 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
414
432
auto toreturn =
415
433
BuilderM.CreateInsertElement (op0, op1, op2, op->getName () + " _unwrap" );
416
434
if (permitCache)
417
- unwrap_cache[BuilderM.GetInsertBlock ()][idx] = toreturn;
435
+ unwrap_cache[BuilderM.GetInsertBlock ()][idx. first ][idx. second ] = toreturn;
418
436
if (auto newi = dyn_cast<Instruction>(toreturn)) {
419
437
newi->copyIRFlags (op);
420
438
unwrappedLoads[newi] = val;
@@ -438,7 +456,7 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
438
456
op->getName () + " '_unwrap" );
439
457
#endif
440
458
if (permitCache)
441
- unwrap_cache[BuilderM.GetInsertBlock ()][idx] = toreturn;
459
+ unwrap_cache[BuilderM.GetInsertBlock ()][idx. first ][idx. second ] = toreturn;
442
460
if (auto newi = dyn_cast<Instruction>(toreturn)) {
443
461
newi->copyIRFlags (op);
444
462
unwrappedLoads[newi] = val;
@@ -469,7 +487,7 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
469
487
newi->setDebugLoc (nullptr );
470
488
}
471
489
if (permitCache)
472
- unwrap_cache[BuilderM.GetInsertBlock ()][idx] = toreturn;
490
+ unwrap_cache[BuilderM.GetInsertBlock ()][idx. first ][idx. second ] = toreturn;
473
491
assert (val->getType () == toreturn->getType ());
474
492
return toreturn;
475
493
} else if (auto op = dyn_cast<ICmpInst>(val)) {
@@ -488,7 +506,7 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
488
506
newi->setDebugLoc (nullptr );
489
507
}
490
508
if (permitCache)
491
- unwrap_cache[BuilderM.GetInsertBlock ()][idx] = toreturn;
509
+ unwrap_cache[BuilderM.GetInsertBlock ()][idx. first ][idx. second ] = toreturn;
492
510
assert (val->getType () == toreturn->getType ());
493
511
return toreturn;
494
512
} else if (auto op = dyn_cast<FCmpInst>(val)) {
@@ -507,7 +525,7 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
507
525
newi->setDebugLoc (nullptr );
508
526
}
509
527
if (permitCache)
510
- unwrap_cache[BuilderM.GetInsertBlock ()][idx] = toreturn;
528
+ unwrap_cache[BuilderM.GetInsertBlock ()][idx. first ][idx. second ] = toreturn;
511
529
assert (val->getType () == toreturn->getType ());
512
530
return toreturn;
513
531
#if LLVM_VERSION_MAJOR >= 9
@@ -526,7 +544,7 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
526
544
newi->setDebugLoc (nullptr );
527
545
}
528
546
if (permitCache)
529
- unwrap_cache[BuilderM.GetInsertBlock ()][idx] = toreturn;
547
+ unwrap_cache[BuilderM.GetInsertBlock ()][idx. first ][idx. second ] = toreturn;
530
548
assert (val->getType () == toreturn->getType ());
531
549
return toreturn;
532
550
#endif
@@ -549,7 +567,7 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
549
567
newi->setDebugLoc (nullptr );
550
568
}
551
569
if (permitCache)
552
- unwrap_cache[BuilderM.GetInsertBlock ()][idx] = toreturn;
570
+ unwrap_cache[BuilderM.GetInsertBlock ()][idx. first ][idx. second ] = toreturn;
553
571
assert (val->getType () == toreturn->getType ());
554
572
return toreturn;
555
573
} else if (auto inst = dyn_cast<GetElementPtrInst>(val)) {
@@ -576,7 +594,7 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
576
594
newi->setDebugLoc (nullptr );
577
595
}
578
596
if (permitCache)
579
- unwrap_cache[BuilderM.GetInsertBlock ()][idx] = toreturn;
597
+ unwrap_cache[BuilderM.GetInsertBlock ()][idx. first ][idx. second ] = toreturn;
580
598
assert (val->getType () == toreturn->getType ());
581
599
return toreturn;
582
600
} else if (auto load = dyn_cast<LoadInst>(val)) {
@@ -642,7 +660,7 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
642
660
load->getMetadata (LLVMContext::MD_invariant_group));
643
661
// TODO adding to cache only legal if no alias of any future writes
644
662
if (permitCache)
645
- unwrap_cache[BuilderM.GetInsertBlock ()][idx] = toreturn;
663
+ unwrap_cache[BuilderM.GetInsertBlock ()][idx. first ][idx. second ] = toreturn;
646
664
assert (val->getType () == toreturn->getType ());
647
665
return toreturn;
648
666
} else if (auto op = dyn_cast<CallInst>(val)) {
@@ -681,7 +699,7 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
681
699
else
682
700
toreturn->setDebugLoc (nullptr );
683
701
if (permitCache)
684
- unwrap_cache[BuilderM.GetInsertBlock ()][idx] = toreturn;
702
+ unwrap_cache[BuilderM.GetInsertBlock ()][idx. first ][idx. second ] = toreturn;
685
703
unwrappedLoads[toreturn] = val;
686
704
return toreturn;
687
705
} else if (auto phi = dyn_cast<PHINode>(val)) {
@@ -741,7 +759,8 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
741
759
dli->getMetadata (LLVMContext::MD_invariant_group));
742
760
// TODO adding to cache only legal if no alias of any future writes
743
761
if (permitCache)
744
- unwrap_cache[BuilderM.GetInsertBlock ()][idx] = toreturn;
762
+ unwrap_cache[BuilderM.GetInsertBlock ()][idx.first ][idx.second ] =
763
+ toreturn;
745
764
assert (val->getType () == toreturn->getType ());
746
765
return toreturn;
747
766
}
@@ -790,7 +809,8 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
790
809
if (!lc.dynamic ) {
791
810
Value *lim = getOp (lc.trueLimit );
792
811
if (lim) {
793
- unwrap_cache[BuilderM.GetInsertBlock ()][idx] = lim;
812
+ unwrap_cache[BuilderM.GetInsertBlock ()][idx.first ][idx.second ] =
813
+ lim;
794
814
return lim;
795
815
}
796
816
} else if (mode == UnwrapMode::AttemptFullUnwrapWithLookup &&
@@ -804,7 +824,7 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
804
824
/* forwardPass*/ false , BuilderM, lctx,
805
825
getDynamicLoopLimit (LI.getLoopFor (lc.header )),
806
826
/* isi1*/ false );
807
- unwrap_cache[BuilderM.GetInsertBlock ()][idx] = lim;
827
+ unwrap_cache[BuilderM.GetInsertBlock ()][idx. first ][idx. second ] = lim;
808
828
return lim;
809
829
}
810
830
}
@@ -1076,7 +1096,7 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
1076
1096
toret->addIncoming (vals[i], endingBlocks[i]);
1077
1097
assert (val->getType () == toret->getType ());
1078
1098
if (permitCache) {
1079
- unwrap_cache[bret][idx] = toret;
1099
+ unwrap_cache[bret][idx. first ][idx. second ] = toret;
1080
1100
}
1081
1101
unwrappedLoads[toret] = val;
1082
1102
unwrap_cache[bret] = unwrap_cache[oldB];
@@ -1249,7 +1269,8 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
1249
1269
Value *toret = BuilderM.CreateSelect (cond, vals[0 ], vals[1 ],
1250
1270
phi->getName () + " _unwrap" );
1251
1271
if (permitCache) {
1252
- unwrap_cache[BuilderM.GetInsertBlock ()][idx] = toret;
1272
+ unwrap_cache[BuilderM.GetInsertBlock ()][idx.first ][idx.second ] =
1273
+ toret;
1253
1274
}
1254
1275
if (auto instRet = dyn_cast<Instruction>(toret))
1255
1276
unwrappedLoads[instRet] = val;
@@ -1276,7 +1297,7 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
1276
1297
toret->addIncoming (vals[i], endingBlocks[i]);
1277
1298
assert (val->getType () == toret->getType ());
1278
1299
if (permitCache) {
1279
- unwrap_cache[bret][idx] = toret;
1300
+ unwrap_cache[bret][idx. first ][idx. second ] = toret;
1280
1301
}
1281
1302
unwrap_cache[bret] = unwrap_cache[oldB];
1282
1303
lookup_cache[bret] = lookup_cache[oldB];
@@ -1510,9 +1531,15 @@ Value *GradientUtils::cacheForReverse(IRBuilder<> &BuilderQ, Value *malloc,
1510
1531
assert (malloc->getType () == ret->getType ());
1511
1532
}
1512
1533
1513
- if (replace)
1514
- if (auto orig = isOriginal (malloc))
1534
+ if (replace) {
1535
+ auto found = newToOriginalFn.find (malloc);
1536
+ if (found != newToOriginalFn.end ()) {
1537
+ Value *orig = found->second ;
1515
1538
originalToNewFn[orig] = ret;
1539
+ newToOriginalFn.erase (malloc);
1540
+ newToOriginalFn[ret] = orig;
1541
+ }
1542
+ }
1516
1543
1517
1544
if (auto found = findInMap (scopeMap, malloc)) {
1518
1545
// There already exists an alloaction for this, we should fully remove
@@ -3129,11 +3156,15 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM,
3129
3156
if (lookup_cache[BuilderM.GetInsertBlock ()].find (val) !=
3130
3157
lookup_cache[BuilderM.GetInsertBlock ()].end ()) {
3131
3158
auto result = lookup_cache[BuilderM.GetInsertBlock ()][val];
3132
- assert (result);
3133
- assert (result->getType ());
3134
- result = BuilderM.CreateBitCast (result, val->getType ());
3135
- assert (result->getType () == inst->getType ());
3136
- return result;
3159
+ if (result == nullptr ) {
3160
+ lookup_cache[BuilderM.GetInsertBlock ()].erase (val);
3161
+ } else {
3162
+ assert (result);
3163
+ assert (result->getType ());
3164
+ result = BuilderM.CreateBitCast (result, val->getType ());
3165
+ assert (result->getType () == inst->getType ());
3166
+ return result;
3167
+ }
3137
3168
}
3138
3169
3139
3170
ValueToValueMapTy available;
@@ -3244,11 +3275,15 @@ Value *GradientUtils::lookupM(Value *val, IRBuilder<> &BuilderM,
3244
3275
if (lookup_cache[BuilderM.GetInsertBlock ()].find (val) !=
3245
3276
lookup_cache[BuilderM.GetInsertBlock ()].end ()) {
3246
3277
auto result = lookup_cache[BuilderM.GetInsertBlock ()][val];
3247
- assert (result);
3248
- assert (result->getType ());
3249
- result = BuilderM.CreateBitCast (result, val->getType ());
3250
- assert (result->getType () == inst->getType ());
3251
- return result;
3278
+ if (result == nullptr ) {
3279
+ lookup_cache[BuilderM.GetInsertBlock ()].erase (val);
3280
+ } else {
3281
+ assert (result);
3282
+ assert (result->getType ());
3283
+ result = BuilderM.CreateBitCast (result, val->getType ());
3284
+ assert (result->getType () == inst->getType ());
3285
+ return result;
3286
+ }
3252
3287
}
3253
3288
3254
3289
// TODO consider call as part of
0 commit comments