Skip to content

Commit 0cc5297

Browse files
authored
Merge pull request #70635 from DougGregor/async-sequence-typed-throws
Adopt typed throws in AsyncIteratorProtocol and AsyncSequence
2 parents 8f4adb3 + 4c990dc commit 0cc5297

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

52 files changed

+1312
-87
lines changed

include/swift/AST/ASTContext.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -650,6 +650,9 @@ class ASTContext final {
650650
/// Get AsyncIteratorProtocol.next().
651651
FuncDecl *getAsyncIteratorNext() const;
652652

653+
/// Get AsyncIteratorProtocol.next(actor).
654+
FuncDecl *getAsyncIteratorNextIsolated() const;
655+
653656
/// Check whether the standard library provides all the correct
654657
/// intrinsic support for Optional<T>.
655658
///

include/swift/AST/Effects.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,11 @@ enum class PolymorphicEffectKind : uint8_t {
8787
/// This is the conformance-based 'rethrows' /'reasync' case.
8888
ByConformance,
8989

90+
/// The function is only permitted to be `rethrows` because it depends
91+
/// on a conformance to `AsyncSequence` or `AsyncIteratorProtocol`,
92+
/// which historically were "@rethrows" protocols.
93+
AsyncSequenceRethrows,
94+
9095
/// The function has this effect unconditionally.
9196
///
9297
/// This is a plain old 'throws' / 'async' function.

include/swift/AST/KnownIdentifiers.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ IDENTIFIER(enqueue)
9090
IDENTIFIER(erasing)
9191
IDENTIFIER(error)
9292
IDENTIFIER(errorDomain)
93+
IDENTIFIER(Failure)
9394
IDENTIFIER(first)
9495
IDENTIFIER(forKeyedSubscript)
9596
IDENTIFIER(Foundation)

include/swift/AST/Stmt.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -971,6 +971,7 @@ class ForEachStmt : public LabeledStmt {
971971

972972
// Set by Sema:
973973
ProtocolConformanceRef sequenceConformance = ProtocolConformanceRef();
974+
Type sequenceType;
974975
PatternBindingDecl *iteratorVar = nullptr;
975976
Expr *nextCall = nullptr;
976977
OpaqueValueExpr *elementExpr = nullptr;
@@ -1001,9 +1002,12 @@ class ForEachStmt : public LabeledStmt {
10011002
void setConvertElementExpr(Expr *expr) { convertElementExpr = expr; }
10021003
Expr *getConvertElementExpr() const { return convertElementExpr; }
10031004

1004-
void setSequenceConformance(ProtocolConformanceRef conformance) {
1005+
void setSequenceConformance(Type type,
1006+
ProtocolConformanceRef conformance) {
1007+
sequenceType = type;
10051008
sequenceConformance = conformance;
10061009
}
1010+
Type getSequenceType() const { return sequenceType; }
10071011
ProtocolConformanceRef getSequenceConformance() const {
10081012
return sequenceConformance;
10091013
}

include/swift/SILOptimizer/Analysis/RegionAnalysis.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,7 +299,11 @@ class RegionAnalysisValueMap {
299299
/// This only includes function arguments.
300300
std::vector<TrackableValueID> neverTransferredValueIDs;
301301

302+
SILFunction *fn;
303+
302304
public:
305+
RegionAnalysisValueMap(SILFunction *fn) : fn(fn) { }
306+
303307
/// Returns the value for this instruction if it isn't a fake "represenative
304308
/// value" to inject actor isolatedness. Asserts in such a case.
305309
SILValue getRepresentative(Element trackableValueID) const;

lib/AST/ASTContext.cpp

Lines changed: 38 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,10 @@ struct ASTContext::Implementation {
287287
/// The declaration of 'AsyncIteratorProtocol.next()'.
288288
FuncDecl *AsyncIteratorNext = nullptr;
289289

290+
/// The declaration of 'AsyncIteratorProtocol.next(_:)' that takes
291+
/// an actor isolation.
292+
FuncDecl *AsyncIteratorNextIsolated = nullptr;
293+
290294
/// The declaration of Swift.Optional<T>.Some.
291295
EnumElementDecl *OptionalSomeDecl = nullptr;
292296

@@ -948,21 +952,48 @@ FuncDecl *ASTContext::getIteratorNext() const {
948952
return nullptr;
949953
}
950954

955+
static std::pair<FuncDecl *, FuncDecl *>
956+
getAsyncIteratorNextRequirements(const ASTContext &ctx) {
957+
auto proto = ctx.getProtocol(KnownProtocolKind::AsyncIteratorProtocol);
958+
if (!proto)
959+
return { nullptr, nullptr };
960+
961+
FuncDecl *next = nullptr;
962+
FuncDecl *nextThrowing = nullptr;
963+
for (auto result : proto->lookupDirect(ctx.Id_next)) {
964+
if (result->getDeclContext() != proto)
965+
continue;
966+
967+
if (auto func = dyn_cast<FuncDecl>(result)) {
968+
switch (func->getParameters()->size()) {
969+
case 0: next = func; break;
970+
case 1: nextThrowing = func; break;
971+
default: break;
972+
}
973+
}
974+
}
975+
976+
return { next, nextThrowing };
977+
}
978+
951979
FuncDecl *ASTContext::getAsyncIteratorNext() const {
952980
if (getImpl().AsyncIteratorNext) {
953981
return getImpl().AsyncIteratorNext;
954982
}
955983

956-
auto proto = getProtocol(KnownProtocolKind::AsyncIteratorProtocol);
957-
if (!proto)
958-
return nullptr;
984+
auto next = getAsyncIteratorNextRequirements(*this).first;
985+
getImpl().AsyncIteratorNext = next;
986+
return next;
987+
}
959988

960-
if (auto *func = lookupRequirement(proto, Id_next)) {
961-
getImpl().AsyncIteratorNext = func;
962-
return func;
989+
FuncDecl *ASTContext::getAsyncIteratorNextIsolated() const {
990+
if (getImpl().AsyncIteratorNextIsolated) {
991+
return getImpl().AsyncIteratorNextIsolated;
963992
}
964993

965-
return nullptr;
994+
auto nextThrowing = getAsyncIteratorNextRequirements(*this).second;
995+
getImpl().AsyncIteratorNextIsolated = nextThrowing;
996+
return nextThrowing;
966997
}
967998

968999
namespace {

lib/AST/Availability.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -819,6 +819,7 @@ ASTContext::getSwift5PlusAvailability(llvm::VersionTuple swiftVersion) {
819819
case 7: return getSwift57Availability();
820820
case 8: return getSwift58Availability();
821821
case 9: return getSwift59Availability();
822+
case 11: return getSwift511Availability();
822823
default: break;
823824
}
824825
}

lib/AST/Effects.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,9 @@ void swift::simple_display(llvm::raw_ostream &out,
114114
case PolymorphicEffectKind::ByConformance:
115115
out << "by conformance";
116116
break;
117+
case PolymorphicEffectKind::AsyncSequenceRethrows:
118+
out << "by async sequence implicit @rethrows";
119+
break;
117120
case PolymorphicEffectKind::Always:
118121
out << "always";
119122
break;

lib/IRGen/MetadataRequest.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1641,7 +1641,7 @@ static MetadataResponse emitFunctionTypeMetadataRef(IRGenFunction &IGF,
16411641

16421642
default:
16431643
assert((!params.empty() || type->isDifferentiable() ||
1644-
type->getGlobalActor()) &&
1644+
type->getGlobalActor() || type->getThrownError()) &&
16451645
"0 parameter case should be specialized unless it is a "
16461646
"differentiable function or has a global actor");
16471647

lib/SILOptimizer/Analysis/RegionAnalysis.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2802,8 +2802,9 @@ static bool canComputeRegionsForFunction(SILFunction *fn) {
28022802

28032803
RegionAnalysisFunctionInfo::RegionAnalysisFunctionInfo(
28042804
SILFunction *fn, PostOrderFunctionInfo *pofi)
2805-
: allocator(), fn(fn), translator(), ptrSetFactory(allocator),
2806-
blockStates(), pofi(pofi), solved(false), supportedFunction(true) {
2805+
: allocator(), fn(fn), valueMap(fn), translator(),
2806+
ptrSetFactory(allocator), blockStates(), pofi(pofi), solved(false),
2807+
supportedFunction(true) {
28072808
// Before we do anything, make sure that we support processing this function.
28082809
//
28092810
// NOTE: See documentation on supportedFunction for criteria.
@@ -3005,7 +3006,7 @@ TrackableValue RegionAnalysisValueMap::getTrackableValue(
30053006
}
30063007

30073008
// Otherwise refer to the oracle.
3008-
if (!isNonSendableType(value->getType(), value->getFunction()))
3009+
if (!isNonSendableType(value->getType(), fn))
30093010
iter.first->getSecond().addFlag(TrackableValueFlag::isSendable);
30103011

30113012
// Check if our base is a ref_element_addr from an actor. In such a case,

lib/Sema/AssociatedTypeInference.cpp

Lines changed: 110 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -978,6 +978,14 @@ class AssociatedTypeInference {
978978
llvm::Optional<AbstractTypeWitness>
979979
computeDefaultTypeWitness(AssociatedTypeDecl *assocType) const;
980980

981+
/// Compute type witnesses for the Failure type from the
982+
/// AsyncSequence or AsyncIteratorProtocol
983+
llvm::Optional<AbstractTypeWitness>
984+
computeFailureTypeWitness(
985+
AssociatedTypeDecl *assocType,
986+
ArrayRef<std::pair<ValueDecl *, ValueDecl *>> valueWitnesses
987+
) const;
988+
981989
/// Compute the "derived" type witness for an associated type that is
982990
/// known to the compiler.
983991
std::pair<Type, TypeDecl *>
@@ -1577,6 +1585,25 @@ next_witness:;
15771585
return result;
15781586
}
15791587

1588+
/// Determine whether this is AsyncIteratorProtocol.Failure associated type.
1589+
static bool isAsyncIteratorProtocolFailure(AssociatedTypeDecl *assocType) {
1590+
auto proto = assocType->getProtocol();
1591+
if (!proto->isSpecificProtocol(KnownProtocolKind::AsyncIteratorProtocol))
1592+
return false;
1593+
1594+
return assocType->getName().str().equals("Failure");
1595+
}
1596+
1597+
/// Determine whether this is AsyncIteratorProtocol.next() function.
1598+
static bool isAsyncIteratorProtocolNext(ValueDecl *req) {
1599+
auto proto = dyn_cast<ProtocolDecl>(req->getDeclContext());
1600+
if (!proto ||
1601+
!proto->isSpecificProtocol(KnownProtocolKind::AsyncIteratorProtocol))
1602+
return false;
1603+
1604+
return req->getName().getBaseName() == req->getASTContext().Id_next;
1605+
}
1606+
15801607
InferredAssociatedTypes
15811608
AssociatedTypeInference::inferTypeWitnessesViaValueWitnesses(
15821609
const llvm::SetVector<AssociatedTypeDecl *> &assocTypes) {
@@ -1622,7 +1649,8 @@ AssociatedTypeInference::inferTypeWitnessesViaValueWitnesses(
16221649
TinyPtrVector<AssociatedTypeDecl *>());
16231650
if (llvm::find_if(referenced, [&](AssociatedTypeDecl *const assocType) {
16241651
return assocTypes.count(assocType);
1625-
}) == referenced.end())
1652+
}) == referenced.end() &&
1653+
!isAsyncIteratorProtocolNext(req))
16261654
continue;
16271655
}
16281656

@@ -2057,9 +2085,71 @@ Type AssociatedTypeInference::computeFixedTypeWitness(
20572085
return resultType;
20582086
}
20592087

2088+
llvm::Optional<AbstractTypeWitness>
2089+
AssociatedTypeInference::computeFailureTypeWitness(
2090+
AssociatedTypeDecl *assocType,
2091+
ArrayRef<std::pair<ValueDecl *, ValueDecl *>> valueWitnesses) const {
2092+
// Inference only applies to AsyncIteratorProtocol.Failure.
2093+
if (!isAsyncIteratorProtocolFailure(assocType))
2094+
return llvm::None;
2095+
2096+
// If there is a generic parameter named Failure, don't try to use next()
2097+
// to infer Failure.
2098+
if (auto genericSig = dc->getGenericSignatureOfContext()) {
2099+
for (auto gp : genericSig.getGenericParams()) {
2100+
// Packs cannot witness associated type requirements.
2101+
if (gp->isParameterPack())
2102+
continue;
2103+
2104+
if (gp->getName() == assocType->getName())
2105+
return llvm::None;
2106+
}
2107+
}
2108+
2109+
// Look for AsyncIteratorProtocol.next() and infer the Failure type from
2110+
// it.
2111+
for (const auto &witness : valueWitnesses) {
2112+
if (isAsyncIteratorProtocolNext(witness.first)) {
2113+
if (auto witnessFunc = dyn_cast<AbstractFunctionDecl>(witness.second)) {
2114+
// If it doesn't throw, Failure == Never.
2115+
if (!witnessFunc->hasThrows())
2116+
return AbstractTypeWitness(assocType, ctx.getNeverType());
2117+
2118+
// If it isn't 'rethrows', Failure == any Error.
2119+
if (!witnessFunc->getAttrs().hasAttribute<RethrowsAttr>())
2120+
return AbstractTypeWitness(assocType, ctx.getErrorExistentialType());
2121+
2122+
// Otherwise, we need to derive the Failure type from a type parameter
2123+
// that conforms to AsyncIteratorProtocol or AsyncSequence.
2124+
for (auto req : witnessFunc->getGenericSignature().getRequirements()) {
2125+
if (req.getKind() == RequirementKind::Conformance) {
2126+
auto proto = req.getProtocolDecl();
2127+
if (proto->isSpecificProtocol(KnownProtocolKind::AsyncIteratorProtocol) ||
2128+
proto->isSpecificProtocol(KnownProtocolKind::AsyncSequence)) {
2129+
auto failureAssocType = proto->getAssociatedType(ctx.Id_Failure);
2130+
auto failureType = DependentMemberType::get(req.getFirstType(), failureAssocType);
2131+
return AbstractTypeWitness(assocType, dc->mapTypeIntoContext(failureType));
2132+
}
2133+
}
2134+
}
2135+
2136+
return AbstractTypeWitness(assocType, ctx.getErrorExistentialType());
2137+
}
2138+
2139+
break;
2140+
}
2141+
}
2142+
2143+
return llvm::None;
2144+
}
2145+
20602146
llvm::Optional<AbstractTypeWitness>
20612147
AssociatedTypeInference::computeDefaultTypeWitness(
20622148
AssociatedTypeDecl *assocType) const {
2149+
// Ignore the default for AsyncIteratorProtocol.Failure
2150+
if (isAsyncIteratorProtocolFailure(assocType))
2151+
return llvm::None;
2152+
20632153
// Go find a default definition.
20642154
auto *const defaultedAssocType = findDefaultedAssociatedType(
20652155
dc, dc->getSelfNominalTypeDecl(), assocType);
@@ -2154,7 +2244,11 @@ AssociatedTypeInference::computeAbstractTypeWitness(
21542244

21552245
// If there is a generic parameter of the named type, use that.
21562246
if (auto genericSig = dc->getGenericSignatureOfContext()) {
2157-
for (auto gp : genericSig.getInnermostGenericParams()) {
2247+
bool wantAllGenericParams = isAsyncIteratorProtocolFailure(assocType);
2248+
auto genericParams = wantAllGenericParams
2249+
? genericSig.getGenericParams()
2250+
: genericSig.getInnermostGenericParams();
2251+
for (auto gp : genericParams) {
21582252
// Packs cannot witness associated type requirements.
21592253
if (gp->isParameterPack())
21602254
continue;
@@ -2784,7 +2878,20 @@ void AssociatedTypeInference::findSolutionsRec(
27842878
// Filter out the associated types that remain unresolved.
27852879
SmallVector<AssociatedTypeDecl *, 4> stillUnresolved;
27862880
for (auto *const assocType : unresolvedAssocTypes) {
2787-
const auto typeWitness = typeWitnesses.begin(assocType);
2881+
auto typeWitness = typeWitnesses.begin(assocType);
2882+
2883+
// If we do not have a witness for AsyncIteratorProtocol.Failure,
2884+
// look for the witness to AsyncIteratorProtocol.next(). If it throws,
2885+
// use 'any Error'. Otherwise, use 'Never'.
2886+
if (typeWitness == typeWitnesses.end()) {
2887+
if (auto failureTypeWitness =
2888+
computeFailureTypeWitness(assocType, valueWitnesses)) {
2889+
typeWitnesses.insert(assocType,
2890+
{failureTypeWitness->getType(), reqDepth});
2891+
typeWitness = typeWitnesses.begin(assocType);
2892+
}
2893+
}
2894+
27882895
if (typeWitness == typeWitnesses.end()) {
27892896
stillUnresolved.push_back(assocType);
27902897
} else {

lib/Sema/CSApply.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8234,6 +8234,10 @@ bool ExprRewriter::isDistributedThunk(ConcreteDeclRef ref, Expr *context) {
82348234
return existential.ExistentialValue;
82358235
}
82368236
return nullptr;
8237+
},
8238+
[]() -> VarDecl * {
8239+
// FIXME: Need to communicate this.
8240+
return nullptr;
82378241
});
82388242

82398243
if (!actor)
@@ -9194,7 +9198,7 @@ static llvm::Optional<SequenceIterationInfo> applySolutionToForEachStmt(
91949198
type, sequenceProto);
91959199
assert(!sequenceConformance.isInvalid() &&
91969200
"Couldn't find sequence conformance");
9197-
stmt->setSequenceConformance(sequenceConformance);
9201+
stmt->setSequenceConformance(type, sequenceConformance);
91989202

91999203
// Apply the solution to the filtering condition, if there is one.
92009204
if (auto *whereExpr = stmt->getWhere()) {

lib/Sema/CSGen.cpp

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4604,14 +4604,26 @@ generateForEachStmtConstraints(ConstraintSystem &cs, DeclContext *dc,
46044604
// Now, result type of `.makeIterator()` is used to form a call to
46054605
// `.next()`. `next()` is called on each iteration of the loop.
46064606
{
4607+
FuncDecl *nextFn =
4608+
TypeChecker::getForEachIteratorNextFunction(dc, stmt->getForLoc(), isAsync);
4609+
Identifier nextId = nextFn ? nextFn->getName().getBaseIdentifier()
4610+
: ctx.Id_next;
46074611
auto *nextRef = UnresolvedDotExpr::createImplicit(
46084612
ctx,
46094613
new (ctx) DeclRefExpr(makeIteratorVar, DeclNameLoc(stmt->getForLoc()),
46104614
/*Implicit=*/true),
4611-
ctx.Id_next, /*labels=*/ArrayRef<Identifier>());
4615+
nextId, /*labels=*/ArrayRef<Identifier>());
46124616
nextRef->setFunctionRefKind(FunctionRefKind::SingleApply);
46134617

4614-
Expr *nextCall = CallExpr::createImplicitEmpty(ctx, nextRef);
4618+
ArgumentList *nextArgs;
4619+
if (nextFn && nextFn->getParameters()->size() == 1) {
4620+
auto isolationArg =
4621+
new (ctx) CurrentContextIsolationExpr(stmt->getForLoc(), Type());
4622+
nextArgs = ArgumentList::forImplicitUnlabeled(ctx, { isolationArg });
4623+
} else {
4624+
nextArgs = ArgumentList::createImplicit(ctx, {});
4625+
}
4626+
Expr *nextCall = CallExpr::createImplicit(ctx, nextRef, nextArgs);
46154627

46164628
// `next` is always async but witness might not be throwing
46174629
if (isAsync) {

0 commit comments

Comments
 (0)