Skip to content

Commit a1ad99a

Browse files
committed
Swift handling
1 parent 69a224d commit a1ad99a

File tree

4 files changed

+100
-46
lines changed

4 files changed

+100
-46
lines changed

enzyme/Enzyme/ActivityAnalysis.cpp

Lines changed: 44 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -79,47 +79,49 @@ cl::opt<bool>
7979
#include <unordered_map>
8080

8181
const char *KnownInactiveFunctionsStartingWith[] = {
82-
"_ZN4core3fmt", "_ZN3std2io5stdio6_print", "f90io"};
83-
84-
std::set<std::string> KnownInactiveFunctions = {"__assert_fail",
85-
"__cxa_guard_acquire",
86-
"__cxa_guard_release",
87-
"__cxa_guard_abort",
88-
"printf",
89-
"vprintf",
90-
"puts",
91-
"__enzyme_float",
92-
"__enzyme_double",
93-
"__enzyme_integer",
94-
"__enzyme_pointer",
95-
"__kmpc_for_static_init_4",
96-
"__kmpc_for_static_init_4u",
97-
"__kmpc_for_static_init_8",
98-
"__kmpc_for_static_init_8u",
99-
"__kmpc_for_static_fini",
100-
"__kmpc_dispatch_init_4",
101-
"__kmpc_dispatch_init_4u",
102-
"__kmpc_dispatch_init_8",
103-
"__kmpc_dispatch_init_8u",
104-
"__kmpc_dispatch_next_4",
105-
"__kmpc_dispatch_next_4u",
106-
"__kmpc_dispatch_next_8",
107-
"__kmpc_dispatch_next_8u",
108-
"__kmpc_dispatch_fini_4",
109-
"__kmpc_dispatch_fini_4u",
110-
"__kmpc_dispatch_fini_8",
111-
"__kmpc_dispatch_fini_8u",
112-
"malloc_usable_size",
113-
"malloc_size",
114-
"MPI_Init",
115-
"MPI_Comm_size",
116-
"MPI_Comm_rank",
117-
"MPI_Get_processor_name",
118-
"MPI_Finalize",
119-
"_msize",
120-
"ftnio_fmt_write64",
121-
"f90_strcmp_klen",
122-
"vprintf"};
82+
"_ZN4core3fmt", "_ZN3std2io5stdio6_print", "f90io", "$ss5print"};
83+
84+
std::set<std::string> KnownInactiveFunctions = {
85+
"__assert_fail",
86+
"__cxa_guard_acquire",
87+
"__cxa_guard_release",
88+
"__cxa_guard_abort",
89+
"printf",
90+
"vprintf",
91+
"puts",
92+
"__enzyme_float",
93+
"__enzyme_double",
94+
"__enzyme_integer",
95+
"__enzyme_pointer",
96+
"__kmpc_for_static_init_4",
97+
"__kmpc_for_static_init_4u",
98+
"__kmpc_for_static_init_8",
99+
"__kmpc_for_static_init_8u",
100+
"__kmpc_for_static_fini",
101+
"__kmpc_dispatch_init_4",
102+
"__kmpc_dispatch_init_4u",
103+
"__kmpc_dispatch_init_8",
104+
"__kmpc_dispatch_init_8u",
105+
"__kmpc_dispatch_next_4",
106+
"__kmpc_dispatch_next_4u",
107+
"__kmpc_dispatch_next_8",
108+
"__kmpc_dispatch_next_8u",
109+
"__kmpc_dispatch_fini_4",
110+
"__kmpc_dispatch_fini_4u",
111+
"__kmpc_dispatch_fini_8",
112+
"__kmpc_dispatch_fini_8u",
113+
"malloc_usable_size",
114+
"malloc_size",
115+
"MPI_Init",
116+
"MPI_Comm_size",
117+
"MPI_Comm_rank",
118+
"MPI_Get_processor_name",
119+
"MPI_Finalize",
120+
"_msize",
121+
"ftnio_fmt_write64",
122+
"f90_strcmp_klen",
123+
"vprintf",
124+
"__swift_instantiateConcreteTypeFromMangledName"};
123125

124126
/// Is the use of value val as an argument of call CI known to be inactive
125127
/// This tool can only be used when in DOWN mode
@@ -2007,4 +2009,4 @@ void ActivityAnalyzer::InsertConstantValue(TypeResults &TR, llvm::Value *V) {
20072009
isConstantInstruction(TR, toeval);
20082010
}
20092011
}
2010-
}
2012+
}

enzyme/Enzyme/GradientUtils.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -624,9 +624,18 @@ class GradientUtils : public CacheUtility {
624624
EmitFailure("SplitGCAllocation", orig->getDebugLoc(), orig,
625625
"Not handling Julia shadow GC allocation in split mode ",
626626
*orig);
627+
return anti;
627628
}
628629
}
629630

631+
if (orig->getCalledFunction()->getName() == "swift_allocObject") {
632+
EmitFailure(
633+
"SwiftShadowAllocation", orig->getDebugLoc(), orig,
634+
"Haven't implemented shadow allocator for `swift_allocObject`",
635+
*orig);
636+
return anti;
637+
}
638+
630639
Value *dst_arg = anti;
631640

632641
dst_arg = bb.CreateBitCast(

enzyme/Enzyme/LibraryFuncs.h

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ static inline bool isAllocationFunction(const llvm::Function &F,
4444
const llvm::TargetLibraryInfo &TLI) {
4545
if (F.getName() == "calloc")
4646
return true;
47+
if (F.getName() == "swift_allocObject")
48+
return true;
4749
if (F.getName() == "__rust_alloc" || F.getName() == "__rust_alloc_zeroed")
4850
return true;
4951
if (F.getName() == "julia.gc_alloc_obj")
@@ -125,6 +127,8 @@ static inline bool isDeallocationFunction(const llvm::Function &F,
125127
return true;
126128
if (F.getName() == "__rust_dealloc")
127129
return true;
130+
if (F.getName() == "swift_release")
131+
return true;
128132
return false;
129133
}
130134

@@ -209,6 +213,41 @@ freeKnownAllocation(llvm::IRBuilder<> &builder, llvm::Value *tofree,
209213
if (allocationfn.getName() == "julia.gc_alloc_obj")
210214
return nullptr;
211215

216+
if (allocationfn.getName() == "swift_allocObject") {
217+
Type *VoidTy = Type::getVoidTy(tofree->getContext());
218+
Type *IntPtrTy = Type::getInt8PtrTy(tofree->getContext());
219+
220+
auto FT = FunctionType::get(VoidTy, {IntPtrTy}, false);
221+
#if LLVM_VERSION_MAJOR >= 9
222+
Value *freevalue = allocationfn.getParent()
223+
->getOrInsertFunction("swift_release", FT)
224+
.getCallee();
225+
#else
226+
Value *freevalue =
227+
allocationfn.getParent()->getOrInsertFunction("swift_release", FT);
228+
#endif
229+
CallInst *freecall = cast<CallInst>(
230+
#if LLVM_VERSION_MAJOR >= 8
231+
CallInst::Create(FT, freevalue,
232+
{builder.CreatePointerCast(tofree, IntPtrTy)},
233+
#else
234+
CallInst::Create(freevalue,
235+
{builder.CreatePointerCast(tofree, IntPtrTy)},
236+
#endif
237+
"", builder.GetInsertBlock()));
238+
freecall->setTailCall();
239+
if (isa<CallInst>(tofree) &&
240+
cast<CallInst>(tofree)->getAttributes().hasAttribute(
241+
AttributeList::ReturnIndex, Attribute::NonNull)) {
242+
freecall->addAttribute(AttributeList::FirstArgIndex, Attribute::NonNull);
243+
}
244+
if (Function *F = dyn_cast<Function>(freevalue))
245+
freecall->setCallingConv(F->getCallingConv());
246+
if (freecall->getParent() == nullptr)
247+
builder.Insert(freecall);
248+
return freecall;
249+
}
250+
212251
if (shadowErasers.find(allocationfn.getName().str()) != shadowErasers.end()) {
213252
return shadowErasers[allocationfn.getName().str()](builder, tofree,
214253
&allocationfn);

enzyme/Enzyme/Utils.h

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -424,7 +424,8 @@ static inline bool isCertainMallocOrFree(llvm::Function *called) {
424424
if (called->getName() == "printf" || called->getName() == "puts" ||
425425
called->getName() == "malloc" || called->getName() == "_Znwm" ||
426426
called->getName() == "_ZdlPv" || called->getName() == "_ZdlPvm" ||
427-
called->getName() == "free" ||
427+
called->getName() == "free" || called->getName() == "swift_allocObject" ||
428+
called->getName() == "swift_release" ||
428429
shadowHandlers.find(called->getName().str()) != shadowHandlers.end())
429430
return true;
430431
switch (called->getIntrinsicID()) {
@@ -455,7 +456,8 @@ static inline bool isCertainPrintOrFree(llvm::Function *called) {
455456
called->getName().startswith("_ZN3std2io5stdio6_print") ||
456457
called->getName().startswith("_ZN4core3fmt") ||
457458
called->getName() == "vprintf" || called->getName() == "_ZdlPv" ||
458-
called->getName() == "_ZdlPvm" || called->getName() == "free")
459+
called->getName() == "_ZdlPvm" || called->getName() == "free" ||
460+
called->getName() == "swift_release")
459461
return true;
460462
switch (called->getIntrinsicID()) {
461463
case llvm::Intrinsic::dbg_declare:
@@ -484,8 +486,10 @@ static inline bool isCertainPrintMallocOrFree(llvm::Function *called) {
484486
called->getName().startswith("_ZN3std2io5stdio6_print") ||
485487
called->getName().startswith("_ZN4core3fmt") ||
486488
called->getName() == "vprintf" || called->getName() == "malloc" ||
487-
called->getName() == "_Znwm" || called->getName() == "_ZdlPv" ||
488-
called->getName() == "_ZdlPvm" || called->getName() == "free" ||
489+
called->getName() == "swift_allocObject" ||
490+
called->getName() == "swift_release" || called->getName() == "_Znwm" ||
491+
called->getName() == "_ZdlPv" || called->getName() == "_ZdlPvm" ||
492+
called->getName() == "free" ||
489493
shadowHandlers.find(called->getName().str()) != shadowHandlers.end())
490494
return true;
491495
switch (called->getIntrinsicID()) {

0 commit comments

Comments
 (0)