@@ -781,10 +781,18 @@ void calculateUnusedValuesInFunction(
781
781
return UseReq::Recur;
782
782
}
783
783
}
784
- if (auto ai = dyn_cast<AllocaInst>(at)) {
784
+ bool newMemory = false ;
785
+ if (isa<AllocaInst>(at))
786
+ newMemory = true ;
787
+ else if (auto CI = dyn_cast<CallInst>(at))
788
+ if (Function *F = getFunctionFromCall (CI))
789
+ if (isAllocationFunction (*F, TLI))
790
+ newMemory = true ;
791
+ if (newMemory) {
785
792
bool foundStore = false ;
786
793
allInstructionsBetween (
787
- gutils->OrigLI , ai, const_cast <MemTransferInst *>(mti),
794
+ gutils->OrigLI , cast<Instruction>(at),
795
+ const_cast <MemTransferInst *>(mti),
788
796
[&](Instruction *I) -> bool {
789
797
if (!I->mayWriteToMemory ())
790
798
return /* earlyBreak*/ false ;
@@ -880,7 +888,7 @@ void calculateUnusedStoresInFunction(
880
888
Function &func,
881
889
llvm::SmallPtrSetImpl<const Instruction *> &unnecessaryStores,
882
890
const llvm::SmallPtrSetImpl<const Instruction *> &unnecessaryInstructions,
883
- GradientUtils *gutils) {
891
+ GradientUtils *gutils, TargetLibraryInfo &TLI ) {
884
892
calculateUnusedStores (func, unnecessaryStores, [&](const Instruction *inst) {
885
893
if (auto si = dyn_cast<StoreInst>(inst)) {
886
894
if (isa<UndefValue>(si->getValueOperand ()))
@@ -891,15 +899,22 @@ void calculateUnusedStoresInFunction(
891
899
#if LLVM_VERSION_MAJOR >= 12
892
900
auto at = getUnderlyingObject (mti->getArgOperand (1 ), 100 );
893
901
#else
894
- auto at = GetUnderlyingObject (
895
- mti->getArgOperand (1 ),
896
- func.getParent ()->getDataLayout (), 100 );
902
+ auto at = GetUnderlyingObject (
903
+ mti->getArgOperand (1 ),
904
+ func.getParent ()->getDataLayout (), 100 );
897
905
#endif
898
- if (auto ai = dyn_cast<AllocaInst>(at)) {
906
+ bool newMemory = false ;
907
+ if (isa<AllocaInst>(at))
908
+ newMemory = true ;
909
+ else if (auto CI = dyn_cast<CallInst>(at))
910
+ if (Function *F = getFunctionFromCall (CI))
911
+ if (isAllocationFunction (*F, TLI))
912
+ newMemory = true ;
913
+ if (newMemory) {
899
914
bool foundStore = false ;
900
915
allInstructionsBetween (
901
- gutils->OrigLI , ai, const_cast <MemTransferInst *>(mti ),
902
- [&](Instruction *I) -> bool {
916
+ gutils->OrigLI , cast<Instruction>(at ),
917
+ const_cast <MemTransferInst *>(mti), [&](Instruction *I) -> bool {
903
918
if (!I->mayWriteToMemory ())
904
919
return /* earlyBreak*/ false ;
905
920
if (unnecessaryInstructions.count (I))
@@ -1923,7 +1938,7 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
1923
1938
1924
1939
SmallPtrSet<const Instruction *, 4 > unnecessaryStores;
1925
1940
calculateUnusedStoresInFunction (*gutils->oldFunc , unnecessaryStores,
1926
- unnecessaryInstructions, gutils);
1941
+ unnecessaryInstructions, gutils, TLI );
1927
1942
1928
1943
insert_or_assign (AugmentedCachedFunctions, tup,
1929
1944
AugmentedReturn (gutils->newFunc , nullptr , {}, returnMapping,
@@ -3463,7 +3478,7 @@ Function *EnzymeLogic::CreatePrimalAndGradient(
3463
3478
3464
3479
SmallPtrSet<const Instruction *, 4 > unnecessaryStores;
3465
3480
calculateUnusedStoresInFunction (*gutils->oldFunc , unnecessaryStores,
3466
- unnecessaryInstructions, gutils);
3481
+ unnecessaryInstructions, gutils, TLI );
3467
3482
3468
3483
Value *additionalValue = nullptr ;
3469
3484
if (key.additionalType ) {
@@ -4057,7 +4072,7 @@ Function *EnzymeLogic::CreateForwardDiff(
4057
4072
4058
4073
SmallPtrSet<const Instruction *, 4 > unnecessaryStores;
4059
4074
calculateUnusedStoresInFunction (*gutils->oldFunc , unnecessaryStores,
4060
- unnecessaryInstructions, gutils);
4075
+ unnecessaryInstructions, gutils, TLI );
4061
4076
4062
4077
// set derivative of function arguments
4063
4078
auto newArgs = gutils->newFunc ->arg_begin ();
0 commit comments