Skip to content

Commit 56c72c7

Browse files
committed
[ORC] Add a public unsafe-operations helper for SymbolStringPtr.
SymbolStringPoolEntryUnsafe provides unsafe access to SymbolStringPtr objects, allowing clients to manually retain and release pool entries, or consume or create SymbolStringPtr instances without affecting an entry's ref-count. This can be useful when writing C APIs that need to handle SymbolStringPtrs. As part of this patch the LLVM-C API implementation is updated to use the new utility, rather than the old, private OrcV2CAPIHelper utility.
1 parent 7138fab commit 56c72c7

File tree

3 files changed

+112
-64
lines changed

3 files changed

+112
-64
lines changed

llvm/include/llvm/ExecutionEngine/Orc/SymbolStringPool.h

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ class NonOwningSymbolStringPtr;
3232
class SymbolStringPool {
3333
friend class SymbolStringPoolTest;
3434
friend class SymbolStringPtrBase;
35+
friend class SymbolStringPoolEntryUnsafe;
3536

3637
// Implemented in DebugUtils.h.
3738
friend raw_ostream &operator<<(raw_ostream &OS, const SymbolStringPool &SSP);
@@ -134,8 +135,8 @@ class SymbolStringPtrBase {
134135

135136
/// Pointer to a pooled string representing a symbol name.
136137
class SymbolStringPtr : public SymbolStringPtrBase {
137-
friend class OrcV2CAPIHelper;
138138
friend class SymbolStringPool;
139+
friend class SymbolStringPoolEntryUnsafe;
139140
friend struct DenseMapInfo<SymbolStringPtr>;
140141

141142
public:
@@ -189,6 +190,47 @@ class SymbolStringPtr : public SymbolStringPtrBase {
189190
}
190191
};
191192

193+
/// Provides unsafe access to ownership operations on SymbolStringPtr.
194+
/// This class can be used to manage SymbolStringPtr instances from C.
195+
class SymbolStringPoolEntryUnsafe {
196+
public:
197+
using PoolEntry = SymbolStringPool::PoolMapEntry;
198+
199+
SymbolStringPoolEntryUnsafe(PoolEntry *E) : E(E) {}
200+
201+
/// Create an unsafe pool entry ref without changing the ref-count.
202+
static SymbolStringPoolEntryUnsafe from(const SymbolStringPtr &S) {
203+
return S.S;
204+
}
205+
206+
/// Consumes the given SymbolStringPtr without releasing the pool entry.
207+
static SymbolStringPoolEntryUnsafe take(SymbolStringPtr &&S) {
208+
PoolEntry *E = nullptr;
209+
std::swap(E, S.S);
210+
return E;
211+
}
212+
213+
PoolEntry *rawPtr() { return E; }
214+
215+
/// Creates a SymbolStringPtr for this entry, with the SymbolStringPtr
216+
/// retaining the entry as usual.
217+
SymbolStringPtr copyToSymbolStringPtr() { return SymbolStringPtr(E); }
218+
219+
/// Creates a SymbolStringPtr for this entry *without* performing a retain
220+
/// operation during construction.
221+
SymbolStringPtr moveToSymbolStringPtr() {
222+
SymbolStringPtr S;
223+
std::swap(S.S, E);
224+
return S;
225+
}
226+
227+
void retain() { ++E->getValue(); }
228+
void release() { --E->getValue(); }
229+
230+
private:
231+
PoolEntry *E = nullptr;
232+
};
233+
192234
/// Non-owning SymbolStringPool entry pointer. Instances are comparable with
193235
/// SymbolStringPtr instances and guaranteed to have the same hash, but do not
194236
/// affect the ref-count of the pooled string (and are therefore cheaper to

llvm/lib/ExecutionEngine/Orc/OrcV2CBindings.cpp

Lines changed: 31 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -27,42 +27,6 @@ class InProgressLookupState;
2727

2828
class OrcV2CAPIHelper {
2929
public:
30-
using PoolEntry = SymbolStringPtr::PoolEntry;
31-
using PoolEntryPtr = SymbolStringPtr::PoolEntryPtr;
32-
33-
// Move from SymbolStringPtr to PoolEntryPtr (no change in ref count).
34-
static PoolEntryPtr moveFromSymbolStringPtr(SymbolStringPtr S) {
35-
PoolEntryPtr Result = nullptr;
36-
std::swap(Result, S.S);
37-
return Result;
38-
}
39-
40-
// Move from a PoolEntryPtr to a SymbolStringPtr (no change in ref count).
41-
static SymbolStringPtr moveToSymbolStringPtr(PoolEntryPtr P) {
42-
SymbolStringPtr S;
43-
S.S = P;
44-
return S;
45-
}
46-
47-
// Copy a pool entry to a SymbolStringPtr (increments ref count).
48-
static SymbolStringPtr copyToSymbolStringPtr(PoolEntryPtr P) {
49-
return SymbolStringPtr(P);
50-
}
51-
52-
static PoolEntryPtr getRawPoolEntryPtr(const SymbolStringPtr &S) {
53-
return S.S;
54-
}
55-
56-
static void retainPoolEntry(PoolEntryPtr P) {
57-
SymbolStringPtr S(P);
58-
S.S = nullptr;
59-
}
60-
61-
static void releasePoolEntry(PoolEntryPtr P) {
62-
SymbolStringPtr S;
63-
S.S = P;
64-
}
65-
6630
static InProgressLookupState *extractLookupState(LookupState &LS) {
6731
return LS.IPLS.release();
6832
}
@@ -75,10 +39,16 @@ class OrcV2CAPIHelper {
7539
} // namespace orc
7640
} // namespace llvm
7741

42+
inline LLVMOrcSymbolStringPoolEntryRef wrap(SymbolStringPoolEntryUnsafe E) {
43+
return reinterpret_cast<LLVMOrcSymbolStringPoolEntryRef>(E.rawPtr());
44+
}
45+
46+
inline SymbolStringPoolEntryUnsafe unwrap(LLVMOrcSymbolStringPoolEntryRef E) {
47+
return reinterpret_cast<SymbolStringPoolEntryUnsafe::PoolEntry *>(E);
48+
}
49+
7850
DEFINE_SIMPLE_CONVERSION_FUNCTIONS(ExecutionSession, LLVMOrcExecutionSessionRef)
7951
DEFINE_SIMPLE_CONVERSION_FUNCTIONS(SymbolStringPool, LLVMOrcSymbolStringPoolRef)
80-
DEFINE_SIMPLE_CONVERSION_FUNCTIONS(OrcV2CAPIHelper::PoolEntry,
81-
LLVMOrcSymbolStringPoolEntryRef)
8252
DEFINE_SIMPLE_CONVERSION_FUNCTIONS(MaterializationUnit,
8353
LLVMOrcMaterializationUnitRef)
8454
DEFINE_SIMPLE_CONVERSION_FUNCTIONS(MaterializationResponsibility,
@@ -136,7 +106,7 @@ class OrcCAPIMaterializationUnit : public llvm::orc::MaterializationUnit {
136106

137107
private:
138108
void discard(const JITDylib &JD, const SymbolStringPtr &Name) override {
139-
Discard(Ctx, wrap(&JD), wrap(OrcV2CAPIHelper::getRawPoolEntryPtr(Name)));
109+
Discard(Ctx, wrap(&JD), wrap(SymbolStringPoolEntryUnsafe::from(Name)));
140110
}
141111

142112
std::string Name;
@@ -184,7 +154,7 @@ static SymbolMap toSymbolMap(LLVMOrcCSymbolMapPairs Syms, size_t NumPairs) {
184154
SymbolMap SM;
185155
for (size_t I = 0; I != NumPairs; ++I) {
186156
JITSymbolFlags Flags = toJITSymbolFlags(Syms[I].Sym.Flags);
187-
SM[OrcV2CAPIHelper::moveToSymbolStringPtr(unwrap(Syms[I].Name))] = {
157+
SM[unwrap(Syms[I].Name).moveToSymbolStringPtr()] = {
188158
ExecutorAddr(Syms[I].Sym.Address), Flags};
189159
}
190160
return SM;
@@ -199,7 +169,7 @@ toSymbolDependenceMap(LLVMOrcCDependenceMapPairs Pairs, size_t NumPairs) {
199169

200170
for (size_t J = 0; J != Pairs[I].Names.Length; ++J) {
201171
auto Sym = Pairs[I].Names.Symbols[J];
202-
Names.insert(OrcV2CAPIHelper::moveToSymbolStringPtr(unwrap(Sym)));
172+
Names.insert(unwrap(Sym).moveToSymbolStringPtr());
203173
}
204174
SDM[JD] = Names;
205175
}
@@ -309,7 +279,7 @@ class CAPIDefinitionGenerator final : public DefinitionGenerator {
309279
CLookupSet.reserve(LookupSet.size());
310280
for (auto &KV : LookupSet) {
311281
LLVMOrcSymbolStringPoolEntryRef Name =
312-
::wrap(OrcV2CAPIHelper::getRawPoolEntryPtr(KV.first));
282+
::wrap(SymbolStringPoolEntryUnsafe::from(KV.first));
313283
LLVMOrcSymbolLookupFlags SLF = fromSymbolLookupFlags(KV.second);
314284
CLookupSet.push_back({Name, SLF});
315285
}
@@ -353,8 +323,7 @@ void LLVMOrcSymbolStringPoolClearDeadEntries(LLVMOrcSymbolStringPoolRef SSP) {
353323

354324
LLVMOrcSymbolStringPoolEntryRef
355325
LLVMOrcExecutionSessionIntern(LLVMOrcExecutionSessionRef ES, const char *Name) {
356-
return wrap(
357-
OrcV2CAPIHelper::moveFromSymbolStringPtr(unwrap(ES)->intern(Name)));
326+
return wrap(SymbolStringPoolEntryUnsafe::take(unwrap(ES)->intern(Name)));
358327
}
359328

360329
void LLVMOrcExecutionSessionLookup(
@@ -374,7 +343,7 @@ void LLVMOrcExecutionSessionLookup(
374343

375344
SymbolLookupSet SLS;
376345
for (size_t I = 0; I != SymbolsSize; ++I)
377-
SLS.add(OrcV2CAPIHelper::moveToSymbolStringPtr(unwrap(Symbols[I].Name)),
346+
SLS.add(unwrap(Symbols[I].Name).moveToSymbolStringPtr(),
378347
toSymbolLookupFlags(Symbols[I].LookupFlags));
379348

380349
unwrap(ES)->lookup(
@@ -384,7 +353,7 @@ void LLVMOrcExecutionSessionLookup(
384353
SmallVector<LLVMOrcCSymbolMapPair> CResult;
385354
for (auto &KV : *Result)
386355
CResult.push_back(LLVMOrcCSymbolMapPair{
387-
wrap(OrcV2CAPIHelper::getRawPoolEntryPtr(KV.first)),
356+
wrap(SymbolStringPoolEntryUnsafe::from(KV.first)),
388357
fromExecutorSymbolDef(KV.second)});
389358
HandleResult(LLVMErrorSuccess, CResult.data(), CResult.size(), Ctx);
390359
} else
@@ -394,15 +363,15 @@ void LLVMOrcExecutionSessionLookup(
394363
}
395364

396365
void LLVMOrcRetainSymbolStringPoolEntry(LLVMOrcSymbolStringPoolEntryRef S) {
397-
OrcV2CAPIHelper::retainPoolEntry(unwrap(S));
366+
unwrap(S).retain();
398367
}
399368

400369
void LLVMOrcReleaseSymbolStringPoolEntry(LLVMOrcSymbolStringPoolEntryRef S) {
401-
OrcV2CAPIHelper::releasePoolEntry(unwrap(S));
370+
unwrap(S).release();
402371
}
403372

404373
const char *LLVMOrcSymbolStringPoolEntryStr(LLVMOrcSymbolStringPoolEntryRef S) {
405-
return unwrap(S)->getKey().data();
374+
return unwrap(S).rawPtr()->getKey().data();
406375
}
407376

408377
LLVMOrcResourceTrackerRef
@@ -452,10 +421,10 @@ LLVMOrcMaterializationUnitRef LLVMOrcCreateCustomMaterializationUnit(
452421
LLVMOrcMaterializationUnitDestroyFunction Destroy) {
453422
SymbolFlagsMap SFM;
454423
for (size_t I = 0; I != NumSyms; ++I)
455-
SFM[OrcV2CAPIHelper::moveToSymbolStringPtr(unwrap(Syms[I].Name))] =
424+
SFM[unwrap(Syms[I].Name).moveToSymbolStringPtr()] =
456425
toJITSymbolFlags(Syms[I].Flags);
457426

458-
auto IS = OrcV2CAPIHelper::moveToSymbolStringPtr(unwrap(InitSym));
427+
auto IS = unwrap(InitSym).moveToSymbolStringPtr();
459428

460429
return wrap(new OrcCAPIMaterializationUnit(
461430
Name, std::move(SFM), std::move(IS), Ctx, Materialize, Discard, Destroy));
@@ -476,9 +445,8 @@ LLVMOrcMaterializationUnitRef LLVMOrcLazyReexports(
476445
for (size_t I = 0; I != NumPairs; ++I) {
477446
auto pair = CallableAliases[I];
478447
JITSymbolFlags Flags = toJITSymbolFlags(pair.Entry.Flags);
479-
SymbolStringPtr Name =
480-
OrcV2CAPIHelper::moveToSymbolStringPtr(unwrap(pair.Entry.Name));
481-
SAM[OrcV2CAPIHelper::moveToSymbolStringPtr(unwrap(pair.Name))] =
448+
SymbolStringPtr Name = unwrap(pair.Entry.Name).moveToSymbolStringPtr();
449+
SAM[unwrap(pair.Name).moveToSymbolStringPtr()] =
482450
SymbolAliasMapEntry(Name, Flags);
483451
}
484452

@@ -511,7 +479,7 @@ LLVMOrcCSymbolFlagsMapPairs LLVMOrcMaterializationResponsibilityGetSymbols(
511479
safe_malloc(Symbols.size() * sizeof(LLVMOrcCSymbolFlagsMapPair)));
512480
size_t I = 0;
513481
for (auto const &pair : Symbols) {
514-
auto Name = wrap(OrcV2CAPIHelper::getRawPoolEntryPtr(pair.first));
482+
auto Name = wrap(SymbolStringPoolEntryUnsafe::from(pair.first));
515483
auto Flags = pair.second;
516484
Result[I] = {Name, fromJITSymbolFlags(Flags)};
517485
I++;
@@ -528,7 +496,7 @@ LLVMOrcSymbolStringPoolEntryRef
528496
LLVMOrcMaterializationResponsibilityGetInitializerSymbol(
529497
LLVMOrcMaterializationResponsibilityRef MR) {
530498
auto Sym = unwrap(MR)->getInitializerSymbol();
531-
return wrap(OrcV2CAPIHelper::getRawPoolEntryPtr(Sym));
499+
return wrap(SymbolStringPoolEntryUnsafe::from(Sym));
532500
}
533501

534502
LLVMOrcSymbolStringPoolEntryRef *
@@ -541,7 +509,7 @@ LLVMOrcMaterializationResponsibilityGetRequestedSymbols(
541509
Symbols.size() * sizeof(LLVMOrcSymbolStringPoolEntryRef)));
542510
size_t I = 0;
543511
for (auto &Name : Symbols) {
544-
Result[I] = wrap(OrcV2CAPIHelper::getRawPoolEntryPtr(Name));
512+
Result[I] = wrap(SymbolStringPoolEntryUnsafe::from(Name));
545513
I++;
546514
}
547515
*NumSymbols = Symbols.size();
@@ -569,7 +537,7 @@ LLVMErrorRef LLVMOrcMaterializationResponsibilityDefineMaterializing(
569537
LLVMOrcCSymbolFlagsMapPairs Syms, size_t NumSyms) {
570538
SymbolFlagsMap SFM;
571539
for (size_t I = 0; I != NumSyms; ++I)
572-
SFM[OrcV2CAPIHelper::moveToSymbolStringPtr(unwrap(Syms[I].Name))] =
540+
SFM[unwrap(Syms[I].Name).moveToSymbolStringPtr()] =
573541
toJITSymbolFlags(Syms[I].Flags);
574542

575543
return wrap(unwrap(MR)->defineMaterializing(std::move(SFM)));
@@ -588,7 +556,7 @@ LLVMErrorRef LLVMOrcMaterializationResponsibilityDelegate(
588556
LLVMOrcMaterializationResponsibilityRef *Result) {
589557
SymbolNameSet Syms;
590558
for (size_t I = 0; I != NumSymbols; I++) {
591-
Syms.insert(OrcV2CAPIHelper::moveToSymbolStringPtr(unwrap(Symbols[I])));
559+
Syms.insert(unwrap(Symbols[I]).moveToSymbolStringPtr());
592560
}
593561
auto OtherMR = unwrap(MR)->delegate(Syms);
594562

@@ -605,7 +573,7 @@ void LLVMOrcMaterializationResponsibilityAddDependencies(
605573
LLVMOrcCDependenceMapPairs Dependencies, size_t NumPairs) {
606574

607575
SymbolDependenceMap SDM = toSymbolDependenceMap(Dependencies, NumPairs);
608-
auto Sym = OrcV2CAPIHelper::moveToSymbolStringPtr(unwrap(Name));
576+
auto Sym = unwrap(Name).moveToSymbolStringPtr();
609577
unwrap(MR)->addDependencies(Sym, SDM);
610578
}
611579

@@ -698,7 +666,7 @@ LLVMErrorRef LLVMOrcCreateDynamicLibrarySearchGeneratorForProcess(
698666
DynamicLibrarySearchGenerator::SymbolPredicate Pred;
699667
if (Filter)
700668
Pred = [=](const SymbolStringPtr &Name) -> bool {
701-
return Filter(FilterCtx, wrap(OrcV2CAPIHelper::getRawPoolEntryPtr(Name)));
669+
return Filter(FilterCtx, wrap(SymbolStringPoolEntryUnsafe::from(Name)));
702670
};
703671

704672
auto ProcessSymsGenerator =
@@ -724,7 +692,7 @@ LLVMErrorRef LLVMOrcCreateDynamicLibrarySearchGeneratorForPath(
724692
DynamicLibrarySearchGenerator::SymbolPredicate Pred;
725693
if (Filter)
726694
Pred = [=](const SymbolStringPtr &Name) -> bool {
727-
return Filter(FilterCtx, wrap(OrcV2CAPIHelper::getRawPoolEntryPtr(Name)));
695+
return Filter(FilterCtx, wrap(SymbolStringPoolEntryUnsafe::from(Name)));
728696
};
729697

730698
auto LibrarySymsGenerator =
@@ -992,7 +960,7 @@ char LLVMOrcLLJITGetGlobalPrefix(LLVMOrcLLJITRef J) {
992960

993961
LLVMOrcSymbolStringPoolEntryRef
994962
LLVMOrcLLJITMangleAndIntern(LLVMOrcLLJITRef J, const char *UnmangledName) {
995-
return wrap(OrcV2CAPIHelper::moveFromSymbolStringPtr(
963+
return wrap(SymbolStringPoolEntryUnsafe::take(
996964
unwrap(J)->mangleAndIntern(UnmangledName)));
997965
}
998966

llvm/unittests/ExecutionEngine/Orc/SymbolStringPoolTest.cpp

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,4 +142,42 @@ TEST_F(SymbolStringPoolTest, NonOwningPointerRefCounts) {
142142
<< "Copy-assignment of NonOwningSymbolStringPtr changed ref-count";
143143
}
144144
}
145+
146+
TEST_F(SymbolStringPoolTest, SymbolStringPoolEntryUnsafe) {
147+
148+
auto A = SP.intern("a");
149+
EXPECT_EQ(getRefCount(A), 1U);
150+
151+
{
152+
// Try creating an unsafe pool entry ref from the given SymbolStringPtr.
153+
// This should not affect the ref-count.
154+
auto AUnsafe = SymbolStringPoolEntryUnsafe::from(A);
155+
EXPECT_EQ(getRefCount(A), 1U);
156+
157+
// Create a new SymbolStringPtr from the unsafe ref. This should increment
158+
// the ref-count.
159+
auto ACopy = AUnsafe.copyToSymbolStringPtr();
160+
EXPECT_EQ(getRefCount(A), 2U);
161+
}
162+
163+
{
164+
// Create a copy of the original string. Move it into an unsafe ref, and
165+
// then move it back. None of these operations should affect the ref-count.
166+
auto ACopy = A;
167+
EXPECT_EQ(getRefCount(A), 2U);
168+
auto AUnsafe = SymbolStringPoolEntryUnsafe::take(std::move(ACopy));
169+
EXPECT_EQ(getRefCount(A), 2U);
170+
ACopy = AUnsafe.moveToSymbolStringPtr();
171+
EXPECT_EQ(getRefCount(A), 2U);
172+
}
173+
174+
// Test manual retain / release.
175+
auto AUnsafe = SymbolStringPoolEntryUnsafe::from(A);
176+
EXPECT_EQ(getRefCount(A), 1U);
177+
AUnsafe.retain();
178+
EXPECT_EQ(getRefCount(A), 2U);
179+
AUnsafe.release();
180+
EXPECT_EQ(getRefCount(A), 1U);
181+
}
182+
145183
} // namespace

0 commit comments

Comments
 (0)