Skip to content

Commit d9cfcd6

Browse files
authored
mark unnecessary mallocs (rust-lang#662)
* mark unnecessary mallocs * Add test
1 parent 034a18f commit d9cfcd6

File tree

2 files changed

+65
-12
lines changed

2 files changed

+65
-12
lines changed

enzyme/Enzyme/EnzymeLogic.cpp

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -781,10 +781,18 @@ void calculateUnusedValuesInFunction(
781781
return UseReq::Recur;
782782
}
783783
}
784-
if (auto ai = dyn_cast<AllocaInst>(at)) {
784+
bool newMemory = false;
785+
if (isa<AllocaInst>(at))
786+
newMemory = true;
787+
else if (auto CI = dyn_cast<CallInst>(at))
788+
if (Function *F = getFunctionFromCall(CI))
789+
if (isAllocationFunction(*F, TLI))
790+
newMemory = true;
791+
if (newMemory) {
785792
bool foundStore = false;
786793
allInstructionsBetween(
787-
gutils->OrigLI, ai, const_cast<MemTransferInst *>(mti),
794+
gutils->OrigLI, cast<Instruction>(at),
795+
const_cast<MemTransferInst *>(mti),
788796
[&](Instruction *I) -> bool {
789797
if (!I->mayWriteToMemory())
790798
return /*earlyBreak*/ false;
@@ -880,7 +888,7 @@ void calculateUnusedStoresInFunction(
880888
Function &func,
881889
llvm::SmallPtrSetImpl<const Instruction *> &unnecessaryStores,
882890
const llvm::SmallPtrSetImpl<const Instruction *> &unnecessaryInstructions,
883-
GradientUtils *gutils) {
891+
GradientUtils *gutils, TargetLibraryInfo &TLI) {
884892
calculateUnusedStores(func, unnecessaryStores, [&](const Instruction *inst) {
885893
if (auto si = dyn_cast<StoreInst>(inst)) {
886894
if (isa<UndefValue>(si->getValueOperand()))
@@ -891,15 +899,22 @@ void calculateUnusedStoresInFunction(
891899
#if LLVM_VERSION_MAJOR >= 12
892900
auto at = getUnderlyingObject(mti->getArgOperand(1), 100);
893901
#else
894-
auto at = GetUnderlyingObject(
895-
mti->getArgOperand(1),
896-
func.getParent()->getDataLayout(), 100);
902+
auto at = GetUnderlyingObject(
903+
mti->getArgOperand(1),
904+
func.getParent()->getDataLayout(), 100);
897905
#endif
898-
if (auto ai = dyn_cast<AllocaInst>(at)) {
906+
bool newMemory = false;
907+
if (isa<AllocaInst>(at))
908+
newMemory = true;
909+
else if (auto CI = dyn_cast<CallInst>(at))
910+
if (Function *F = getFunctionFromCall(CI))
911+
if (isAllocationFunction(*F, TLI))
912+
newMemory = true;
913+
if (newMemory) {
899914
bool foundStore = false;
900915
allInstructionsBetween(
901-
gutils->OrigLI, ai, const_cast<MemTransferInst *>(mti),
902-
[&](Instruction *I) -> bool {
916+
gutils->OrigLI, cast<Instruction>(at),
917+
const_cast<MemTransferInst *>(mti), [&](Instruction *I) -> bool {
903918
if (!I->mayWriteToMemory())
904919
return /*earlyBreak*/ false;
905920
if (unnecessaryInstructions.count(I))
@@ -1923,7 +1938,7 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
19231938

19241939
SmallPtrSet<const Instruction *, 4> unnecessaryStores;
19251940
calculateUnusedStoresInFunction(*gutils->oldFunc, unnecessaryStores,
1926-
unnecessaryInstructions, gutils);
1941+
unnecessaryInstructions, gutils, TLI);
19271942

19281943
insert_or_assign(AugmentedCachedFunctions, tup,
19291944
AugmentedReturn(gutils->newFunc, nullptr, {}, returnMapping,
@@ -3463,7 +3478,7 @@ Function *EnzymeLogic::CreatePrimalAndGradient(
34633478

34643479
SmallPtrSet<const Instruction *, 4> unnecessaryStores;
34653480
calculateUnusedStoresInFunction(*gutils->oldFunc, unnecessaryStores,
3466-
unnecessaryInstructions, gutils);
3481+
unnecessaryInstructions, gutils, TLI);
34673482

34683483
Value *additionalValue = nullptr;
34693484
if (key.additionalType) {
@@ -4057,7 +4072,7 @@ Function *EnzymeLogic::CreateForwardDiff(
40574072

40584073
SmallPtrSet<const Instruction *, 4> unnecessaryStores;
40594074
calculateUnusedStoresInFunction(*gutils->oldFunc, unnecessaryStores,
4060-
unnecessaryInstructions, gutils);
4075+
unnecessaryInstructions, gutils, TLI);
40614076

40624077
// set derivative of function arguments
40634078
auto newArgs = gutils->newFunc->arg_begin();
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -sroa -instsimplify -S | FileCheck %s
2+
3+
source_filename = "<source>"
4+
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
5+
target triple = "x86_64-unknown-linux-gnu"
6+
7+
@enzyme_dup = dso_local global i32 0, align 4
8+
9+
define dso_local void @_Z6squarePi(i8* %i0) {
10+
%i2 = call noalias i8* @malloc(i64 16)
11+
call void @llvm.memcpy.p0i8.p0i8.i64(i8* %i0, i8* %i2, i64 16, i1 false)
12+
ret void
13+
}
14+
15+
; Function Attrs: nounwind
16+
declare dso_local noalias i8* @malloc(i64) #1
17+
18+
declare void @llvm.memcpy.p0i8.p0i8.i64(i8*, i8*, i64, i1)
19+
20+
define dso_local void @_Z7dsquarePdS_(double* %a0, double* %a1) {
21+
%ed = load i32, i32* @enzyme_dup, align 4
22+
call void @_Z17__enzyme_autodiffPviPdS0_(i8* bitcast (void (i8*)* @_Z6squarePi to i8*), i32 %ed, double* %a0, double* %a1)
23+
ret void
24+
}
25+
26+
declare dso_local void @_Z17__enzyme_autodiffPviPdS0_(i8*, i32, double*, double*)
27+
28+
; CHECK: define internal void @diffe_Z6squarePi(i8* %i0, i8* %"i0'")
29+
; CHECK-NEXT: %i2 = call noalias nonnull dereferenceable(16) dereferenceable_or_null(16) i8* @malloc(i64 16)
30+
; CHECK-NEXT: %"i2'mi" = call noalias nonnull dereferenceable(16) dereferenceable_or_null(16) i8* @malloc(i64 16)
31+
; CHECK-NEXT: call void @llvm.memset.p0i8.i64(i8* nonnull dereferenceable(16) dereferenceable_or_null(16) %"i2'mi", i8 0, i64 16, i1 false)
32+
; CHECK-NEXT: br label %invert
33+
34+
; CHECK: invert:
35+
; CHECK-NEXT: tail call void @free(i8* nonnull %"i2'mi")
36+
; CHECK-NEXT: tail call void @free(i8* nonnull %i2)
37+
; CHECK-NEXT: ret void
38+
; CHECK-NEXT: }

0 commit comments

Comments
 (0)