Skip to content

Commit 817e1c1

Browse files
committed
Extend type checker hack for @_unsafeInheritExecutor functions to methods
With the re-introduction of `@_unsafeInheritExecutor` for `TaskLocal.withValue`, we need to extend the type checker trick with `_unsafeInheritExecutor_`-prefixed functions to work with methods. Do so to make `TaskLocal.withValue` actually work this way.
1 parent 8964436 commit 817e1c1

File tree

4 files changed

+81
-3
lines changed

4 files changed

+81
-3
lines changed

lib/Sema/ConstraintSystem.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,13 @@ LookupResult &ConstraintSystem::lookupMember(Type base, DeclNameRef name,
293293
result = TypeChecker::lookupMember(DC, base, name, loc,
294294
defaultMemberLookupOptions);
295295

296+
// If we are in an @_unsafeInheritExecutor context, swap out
297+
// declarations for their _unsafeInheritExecutor_ counterparts if they
298+
// exist.
299+
if (enclosingUnsafeInheritsExecutor(DC)) {
300+
introduceUnsafeInheritExecutorReplacements(DC, base, loc, *result);
301+
}
302+
296303
// If we aren't performing dynamic lookup, we're done.
297304
if (!*result || !base->isAnyObject())
298305
return *result;

lib/Sema/TypeCheckConcurrency.cpp

Lines changed: 57 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2111,17 +2111,18 @@ void swift::introduceUnsafeInheritExecutorReplacements(
21112111
// Make sure at least some of the entries are functions in the _Concurrency
21122112
// module.
21132113
ModuleDecl *concurrencyModule = nullptr;
2114+
DeclBaseName baseName;
21142115
for (auto decl: decls) {
21152116
if (isReplaceable(decl)) {
21162117
concurrencyModule = decl->getDeclContext()->getParentModule();
2118+
baseName = decl->getName().getBaseName();
21172119
break;
21182120
}
21192121
}
21202122
if (!concurrencyModule)
21212123
return;
21222124

2123-
// Dig out the name.
2124-
auto baseName = decls.front()->getName().getBaseName();
2125+
// Ignore anything with a special name.
21252126
if (baseName.isSpecial())
21262127
return;
21272128

@@ -2149,6 +2150,60 @@ void swift::introduceUnsafeInheritExecutorReplacements(
21492150
}
21502151
}
21512152

2153+
void swift::introduceUnsafeInheritExecutorReplacements(
2154+
const DeclContext *dc, Type base, SourceLoc loc, LookupResult &lookup) {
2155+
if (lookup.empty())
2156+
return;
2157+
2158+
auto baseNominal = base->getAnyNominal();
2159+
if (!baseNominal || !inConcurrencyModule(baseNominal))
2160+
return;
2161+
2162+
auto isReplaceable = [&](ValueDecl *decl) {
2163+
return isa<FuncDecl>(decl) && inConcurrencyModule(decl->getDeclContext());
2164+
};
2165+
2166+
// Make sure at least some of the entries are functions in the _Concurrency
2167+
// module.
2168+
ModuleDecl *concurrencyModule = nullptr;
2169+
DeclBaseName baseName;
2170+
for (auto &result: lookup) {
2171+
auto decl = result.getValueDecl();
2172+
if (isReplaceable(decl)) {
2173+
concurrencyModule = decl->getDeclContext()->getParentModule();
2174+
baseName = decl->getBaseName();
2175+
break;
2176+
}
2177+
}
2178+
if (!concurrencyModule)
2179+
return;
2180+
2181+
// Ignore anything with a special name.
2182+
if (baseName.isSpecial())
2183+
return;
2184+
2185+
// Look for entities with the _unsafeInheritExecutor_ prefix on the name.
2186+
ASTContext &ctx = base->getASTContext();
2187+
Identifier newIdentifier = ctx.getIdentifier(
2188+
("_unsafeInheritExecutor_" + baseName.getIdentifier().str()).str());
2189+
2190+
LookupResult replacementLookup = TypeChecker::lookupMember(
2191+
const_cast<DeclContext *>(dc), base, DeclNameRef(newIdentifier), loc,
2192+
defaultMemberLookupOptions);
2193+
if (replacementLookup.innerResults().empty())
2194+
return;
2195+
2196+
// Drop all of the _Concurrency entries in favor of the ones found by this
2197+
// lookup.
2198+
lookup.filter([&](const LookupResultEntry &entry, bool) {
2199+
return !isReplaceable(entry.getValueDecl());
2200+
});
2201+
2202+
for (const auto &entry: replacementLookup.innerResults()) {
2203+
lookup.add(entry, /*isOuter=*/false);
2204+
}
2205+
}
2206+
21522207
/// Check if it is safe for the \c globalActor qualifier to be removed from
21532208
/// \c ty, when the function value of that type is isolated to that actor.
21542209
///

lib/Sema/TypeCheckConcurrency.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ class EnumElementDecl;
4444
class Expr;
4545
class FuncDecl;
4646
class Initializer;
47+
class LookupResult;
4748
class PatternBindingDecl;
4849
class ProtocolConformance;
4950
class TopLevelCodeDecl;
@@ -668,6 +669,18 @@ void replaceUnsafeInheritExecutorWithDefaultedIsolationParam(
668669
void introduceUnsafeInheritExecutorReplacements(
669670
const DeclContext *dc, SourceLoc loc, SmallVectorImpl<ValueDecl *> &decls);
670671

672+
/// Replace any functions in this list that were found in the _Concurrency
673+
/// module as a member on "base" and have _unsafeInheritExecutor_-prefixed
674+
/// versions with those _unsafeInheritExecutor_-prefixed versions.
675+
///
676+
/// This function is an egregious hack that allows us to introduce the
677+
/// #isolation-based versions of functions into the concurrency library
678+
/// without breaking clients that use @_unsafeInheritExecutor. Since those
679+
/// clients can't use #isolation (it doesn't work with @_unsafeInheritExecutor),
680+
/// we route them to the @_unsafeInheritExecutor versions implicitly.
681+
void introduceUnsafeInheritExecutorReplacements(
682+
const DeclContext *dc, Type base, SourceLoc loc, LookupResult &result);
683+
671684
} // end namespace swift
672685

673686
namespace llvm {

test/Concurrency/unsafe_inherit_executor.swift

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,10 @@ func unsafeCallerAvoidsNewLoop() async throws {
113113
} onCancel: {
114114
}
115115

116-
TL.$string.withValue("hello") {
116+
await TL.$string.withValue("hello") {
117117
print(TL.string)
118118
}
119+
120+
func operation() async throws -> Int { 7 }
121+
try await TL.$string.withValue("hello", operation: operation)
119122
}

0 commit comments

Comments
 (0)