Skip to content

Commit 72e4224

Browse files
authored
Merge pull request #13655 from CodaFi/ace-attorney
[SE-0194] Implement deriving collections of enum cases
2 parents 1e4f55d + dac0689 commit 72e4224

17 files changed

+327
-13
lines changed

include/swift/AST/Decl.h

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -543,12 +543,16 @@ class alignas(1 << DeclAlignInBits) Decl {
543543
HasUnreferenceableStorage : 1
544544
);
545545

546-
SWIFT_INLINE_BITFIELD(EnumDecl, NominalTypeDecl, 2+2,
546+
SWIFT_INLINE_BITFIELD(EnumDecl, NominalTypeDecl, 2+2+1,
547547
/// The stage of the raw type circularity check for this class.
548548
Circularity : 2,
549549

550550
/// True if the enum has cases and at least one case has associated values.
551-
HasAssociatedValues : 2
551+
HasAssociatedValues : 2,
552+
/// True if the enum has at least one case that has some availability
553+
/// attribute. A single bit because it's lazily computed along with the
554+
/// HasAssociatedValues bit.
555+
HasAnyUnavailableValues : 1
552556
);
553557

554558
SWIFT_INLINE_BITFIELD(PrecedenceGroupDecl, Decl, 1+2,
@@ -3220,6 +3224,11 @@ class EnumDecl final : public NominalTypeDecl {
32203224
/// Note that this is true for enums with absolutely no cases.
32213225
bool hasOnlyCasesWithoutAssociatedValues() const;
32223226

3227+
/// True if any of the enum cases have availability annotations.
3228+
///
3229+
/// Note that this is false for enums with absolutely no cases.
3230+
bool hasPotentiallyUnavailableCaseValue() const;
3231+
32233232
/// True if the enum has cases.
32243233
bool hasCases() const {
32253234
return !getAllElements().empty();

include/swift/AST/DiagnosticsSema.def

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2096,10 +2096,11 @@ NOTE(construct_raw_representable_from_unwrapped_value,none,
20962096
"construct %0 from unwrapped %1 value", (Type, Type))
20972097

20982098
// Derived conformances
2099-
21002099
ERROR(cannot_synthesize_in_extension,none,
21012100
"implementation of %0 cannot be automatically synthesized in an extension", (Type))
21022101

2102+
ERROR(broken_case_iterable_requirement,none,
2103+
"CaseIterable protocol is broken: unexpected requirement", ())
21032104
ERROR(broken_raw_representable_requirement,none,
21042105
"RawRepresentable protocol is broken: unexpected requirement", ())
21052106
ERROR(broken_equatable_requirement,none,

include/swift/AST/KnownIdentifiers.def

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
#define IDENTIFIER(name) IDENTIFIER_WITH_NAME(name, #name)
2323
#define IDENTIFIER_(name) IDENTIFIER_WITH_NAME(name, "_" #name)
2424

25+
IDENTIFIER(AllCases)
26+
IDENTIFIER(allCases)
2527
IDENTIFIER(alloc)
2628
IDENTIFIER(allocWithZone)
2729
IDENTIFIER(allZeros)

include/swift/AST/KnownProtocols.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ PROTOCOL(Comparable)
5757
PROTOCOL(Error)
5858
PROTOCOL_(ErrorCodeProtocol)
5959
PROTOCOL(OptionSet)
60+
PROTOCOL(CaseIterable)
6061

6162
PROTOCOL_(BridgedNSError)
6263
PROTOCOL_(BridgedStoredNSError)

lib/AST/Decl.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2689,6 +2689,8 @@ EnumDecl::EnumDecl(SourceLoc EnumLoc,
26892689
= static_cast<unsigned>(CircularityCheck::Unchecked);
26902690
Bits.EnumDecl.HasAssociatedValues
26912691
= static_cast<unsigned>(AssociatedValueCheck::Unchecked);
2692+
Bits.EnumDecl.HasAnyUnavailableValues
2693+
= false;
26922694
}
26932695

26942696
StructDecl::StructDecl(SourceLoc StructLoc, Identifier Name, SourceLoc NameLoc,
@@ -3041,6 +3043,17 @@ EnumElementDecl *EnumDecl::getElement(Identifier Name) const {
30413043
return nullptr;
30423044
}
30433045

3046+
bool EnumDecl::hasPotentiallyUnavailableCaseValue() const {
3047+
switch (static_cast<AssociatedValueCheck>(Bits.EnumDecl.HasAssociatedValues)) {
3048+
case AssociatedValueCheck::Unchecked:
3049+
// Compute below
3050+
this->hasOnlyCasesWithoutAssociatedValues();
3051+
LLVM_FALLTHROUGH;
3052+
default:
3053+
return static_cast<bool>(Bits.EnumDecl.HasAnyUnavailableValues);
3054+
}
3055+
}
3056+
30443057
bool EnumDecl::hasOnlyCasesWithoutAssociatedValues() const {
30453058
// Check whether we already have a cached answer.
30463059
switch (static_cast<AssociatedValueCheck>(
@@ -3056,6 +3069,15 @@ bool EnumDecl::hasOnlyCasesWithoutAssociatedValues() const {
30563069
return false;
30573070
}
30583071
for (auto elt : getAllElements()) {
3072+
for (auto Attr : elt->getAttrs()) {
3073+
if (auto AvAttr = dyn_cast<AvailableAttr>(Attr)) {
3074+
if (!AvAttr->isInvalid()) {
3075+
const_cast<EnumDecl*>(this)->Bits.EnumDecl.HasAnyUnavailableValues
3076+
= true;
3077+
}
3078+
}
3079+
}
3080+
30593081
if (elt->hasAssociatedValues()) {
30603082
const_cast<EnumDecl*>(this)->Bits.EnumDecl.HasAssociatedValues
30613083
= static_cast<unsigned>(AssociatedValueCheck::HasAssociatedValues);

lib/IRGen/GenMeta.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5856,6 +5856,7 @@ SpecialProtocol irgen::getSpecialProtocolID(ProtocolDecl *P) {
58565856
case KnownProtocolKind::RawRepresentable:
58575857
case KnownProtocolKind::Equatable:
58585858
case KnownProtocolKind::Hashable:
5859+
case KnownProtocolKind::CaseIterable:
58595860
case KnownProtocolKind::Comparable:
58605861
case KnownProtocolKind::ObjectiveCBridgeable:
58615862
case KnownProtocolKind::DestructorSafeContainer:

lib/Sema/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ add_swift_library(swiftSema STATIC
1717
ConstraintGraph.cpp
1818
ConstraintLocator.cpp
1919
ConstraintSystem.cpp
20+
DerivedConformanceCaseIterable.cpp
2021
DerivedConformanceCodable.cpp
2122
DerivedConformanceCodingKey.cpp
2223
DerivedConformanceEquatableHashable.cpp
Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
//
2+
// This source file is part of the Swift.org open source project
3+
//
4+
// Copyright (c) 2016 Apple Inc. and the Swift project authors
5+
// Licensed under Apache License v2.0 with Runtime Library Exception
6+
//
7+
// See http://swift.org/LICENSE.txt for license information
8+
// See http://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
9+
//
10+
//===----------------------------------------------------------------------===//
11+
//
12+
// This file implements implicit derivation of the CaseIterable protocol.
13+
//
14+
//===----------------------------------------------------------------------===//
15+
16+
#include "TypeChecker.h"
17+
#include "swift/AST/Decl.h"
18+
#include "swift/AST/Stmt.h"
19+
#include "swift/AST/Expr.h"
20+
#include "swift/AST/Types.h"
21+
#include "llvm/Support/raw_ostream.h"
22+
#include "DerivedConformances.h"
23+
24+
using namespace swift;
25+
using namespace DerivedConformance;
26+
27+
/// Common preconditions for CaseIterable.
28+
static bool canDeriveConformance(NominalTypeDecl *type) {
29+
// The type must be an enum.
30+
auto enumDecl = dyn_cast<EnumDecl>(type);
31+
if (!enumDecl)
32+
return false;
33+
34+
// "Simple" enums without availability attributes can derive
35+
// a CaseIterable conformance.
36+
//
37+
// FIXME: Lift the availability restriction.
38+
return !enumDecl->hasPotentiallyUnavailableCaseValue()
39+
&& enumDecl->hasOnlyCasesWithoutAssociatedValues();
40+
}
41+
42+
/// Derive the implementation of allCases for a "simple" no-payload enum.
43+
void deriveCaseIterable_enum_getter(AbstractFunctionDecl *funcDecl) {
44+
auto *parentDC = funcDecl->getDeclContext();
45+
auto *parentEnum = parentDC->getAsEnumOrEnumExtensionContext();
46+
auto enumTy = parentEnum->getDeclaredTypeInContext();
47+
auto &C = parentDC->getASTContext();
48+
49+
SmallVector<Expr *, 8> elExprs;
50+
for (EnumElementDecl *elt : parentEnum->getAllElements()) {
51+
auto *ref = new (C) DeclRefExpr(elt, DeclNameLoc(), /*implicit*/true);
52+
auto *base = TypeExpr::createImplicit(enumTy, C);
53+
auto *apply = new (C) DotSyntaxCallExpr(ref, SourceLoc(), base);
54+
elExprs.push_back(apply);
55+
}
56+
auto *arrayExpr = ArrayExpr::create(C, SourceLoc(), elExprs, {}, SourceLoc());
57+
58+
auto *returnStmt = new (C) ReturnStmt(SourceLoc(), arrayExpr);
59+
auto *body = BraceStmt::create(C, SourceLoc(), ASTNode(returnStmt),
60+
SourceLoc());
61+
funcDecl->setBody(body);
62+
}
63+
64+
static ArraySliceType *computeAllCasesType(NominalTypeDecl *enumType) {
65+
auto metaTy = enumType->getDeclaredInterfaceType();
66+
if (!metaTy || metaTy->hasError())
67+
return nullptr;
68+
69+
return ArraySliceType::get(metaTy->getRValueInstanceType());
70+
}
71+
72+
static Type deriveCaseIterable_AllCases(TypeChecker &tc, Decl *parentDecl,
73+
EnumDecl *enumDecl) {
74+
// enum SomeEnum : CaseIterable {
75+
// @derived
76+
// typealias AllCases = [SomeEnum]
77+
// }
78+
auto *rawInterfaceType = computeAllCasesType(enumDecl);
79+
return cast<DeclContext>(parentDecl)->mapTypeIntoContext(rawInterfaceType);
80+
}
81+
82+
ValueDecl *DerivedConformance::deriveCaseIterable(TypeChecker &tc,
83+
Decl *parentDecl,
84+
NominalTypeDecl *targetDecl,
85+
ValueDecl *requirement) {
86+
// Conformance can't be synthesized in an extension.
87+
auto caseIterableProto
88+
= tc.Context.getProtocol(KnownProtocolKind::CaseIterable);
89+
auto caseIterableType = caseIterableProto->getDeclaredType();
90+
if (targetDecl != parentDecl) {
91+
tc.diagnose(parentDecl->getLoc(), diag::cannot_synthesize_in_extension,
92+
caseIterableType);
93+
return nullptr;
94+
}
95+
96+
// Check that we can actually derive CaseIterable for this type.
97+
if (!canDeriveConformance(targetDecl))
98+
return nullptr;
99+
100+
// Build the necessary decl.
101+
if (requirement->getBaseName() != tc.Context.Id_allCases) {
102+
tc.diagnose(requirement->getLoc(),
103+
diag::broken_case_iterable_requirement);
104+
return nullptr;
105+
}
106+
107+
auto enumDecl = cast<EnumDecl>(targetDecl);
108+
ASTContext &C = tc.Context;
109+
110+
111+
// Define the property.
112+
auto *returnTy = computeAllCasesType(targetDecl);
113+
114+
VarDecl *propDecl;
115+
PatternBindingDecl *pbDecl;
116+
std::tie(propDecl, pbDecl)
117+
= declareDerivedProperty(tc, parentDecl, enumDecl, C.Id_allCases,
118+
returnTy, returnTy,
119+
/*isStatic=*/true, /*isFinal=*/true);
120+
121+
// Define the getter.
122+
auto *getterDecl = addGetterToReadOnlyDerivedProperty(tc, propDecl, returnTy);
123+
124+
getterDecl->setBodySynthesizer(&deriveCaseIterable_enum_getter);
125+
126+
auto dc = cast<IterableDeclContext>(parentDecl);
127+
dc->addMember(getterDecl);
128+
dc->addMember(propDecl);
129+
dc->addMember(pbDecl);
130+
131+
return propDecl;
132+
}
133+
134+
Type DerivedConformance::deriveCaseIterable(TypeChecker &tc, Decl *parentDecl,
135+
NominalTypeDecl *targetDecl,
136+
AssociatedTypeDecl *assocType) {
137+
// Conformance can't be synthesized in an extension.
138+
auto caseIterableProto
139+
= tc.Context.getProtocol(KnownProtocolKind::CaseIterable);
140+
auto caseIterableType = caseIterableProto->getDeclaredType();
141+
if (targetDecl != parentDecl) {
142+
tc.diagnose(parentDecl->getLoc(), diag::cannot_synthesize_in_extension,
143+
caseIterableType);
144+
return nullptr;
145+
}
146+
147+
// We can only synthesize CaseIterable for enums.
148+
auto enumDecl = dyn_cast<EnumDecl>(targetDecl);
149+
if (!enumDecl)
150+
return nullptr;
151+
152+
// Check that we can actually derive CaseIterable for this type.
153+
if (!canDeriveConformance(targetDecl))
154+
return nullptr;
155+
156+
if (assocType->getName() == tc.Context.Id_AllCases) {
157+
return deriveCaseIterable_AllCases(tc, parentDecl, enumDecl);
158+
}
159+
160+
tc.diagnose(assocType->getLoc(),
161+
diag::broken_case_iterable_requirement);
162+
return nullptr;
163+
}
164+

lib/Sema/DerivedConformances.cpp

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,18 @@ bool DerivedConformance::derivesProtocolConformance(TypeChecker &tc,
3939
return enumDecl->hasRawType();
4040

4141
// Enums without associated values can implicitly derive Equatable and
42-
// Hashable conformance.
42+
// Hashable conformances.
4343
case KnownProtocolKind::Equatable:
4444
return canDeriveEquatable(tc, enumDecl, protocol);
4545
case KnownProtocolKind::Hashable:
4646
return canDeriveHashable(tc, enumDecl, protocol);
47+
// "Simple" enums without availability attributes can explicitly derive
48+
// a CaseIterable conformance.
49+
//
50+
// FIXME: Lift the availability restriction.
51+
case KnownProtocolKind::CaseIterable:
52+
return !enumDecl->hasPotentiallyUnavailableCaseValue()
53+
&& enumDecl->hasOnlyCasesWithoutAssociatedValues();
4754

4855
// @objc enums can explicitly derive their _BridgedNSError conformance.
4956
case KnownProtocolKind::BridgedNSError:
@@ -135,6 +142,10 @@ ValueDecl *DerivedConformance::getDerivableRequirement(TypeChecker &tc,
135142
if (name.isSimpleName(ctx.Id_hashValue))
136143
return getRequirement(KnownProtocolKind::Hashable);
137144

145+
// CaseIterable.allValues
146+
if (name.isSimpleName(ctx.Id_allCases))
147+
return getRequirement(KnownProtocolKind::CaseIterable);
148+
138149
// _BridgedNSError._nsErrorDomain
139150
if (name.isSimpleName(ctx.Id_nsErrorDomain))
140151
return getRequirement(KnownProtocolKind::BridgedNSError);
@@ -192,6 +203,10 @@ ValueDecl *DerivedConformance::getDerivableRequirement(TypeChecker &tc,
192203
if (name.isSimpleName(ctx.Id_RawValue))
193204
return getRequirement(KnownProtocolKind::RawRepresentable);
194205

206+
// CaseIterable.AllCases
207+
if (name.isSimpleName(ctx.Id_AllCases))
208+
return getRequirement(KnownProtocolKind::CaseIterable);
209+
195210
return nullptr;
196211
}
197212

lib/Sema/DerivedConformances.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,25 @@ ValueDecl *getDerivableRequirement(TypeChecker &tc,
7171
NominalTypeDecl *nominal,
7272
ValueDecl *requirement);
7373

74+
75+
/// Derive a CaseIterable requirement for an enum if it has no associated
76+
/// values for any of its cases.
77+
///
78+
/// \returns the derived member, which will also be added to the type.
79+
ValueDecl *deriveCaseIterable(TypeChecker &tc,
80+
Decl *parentDecl,
81+
NominalTypeDecl *type,
82+
ValueDecl *requirement);
83+
84+
/// Derive a CaseIterable type witness for an enum if it has no associated
85+
/// values for any of its cases.
86+
///
87+
/// \returns the derived member, which will also be added to the type.
88+
Type deriveCaseIterable(TypeChecker &tc,
89+
Decl *parentDecl,
90+
NominalTypeDecl *type,
91+
AssociatedTypeDecl *assocType);
92+
7493
/// Derive a RawRepresentable requirement for an enum, if it has a valid
7594
/// raw type and raw values for all of its cases.
7695
///

lib/Sema/TypeCheckDecl.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9012,6 +9012,13 @@ void TypeChecker::synthesizeMemberForLookup(NominalTypeDecl *target,
90129012
auto *encodableProto = Context.getProtocol(KnownProtocolKind::Encodable);
90139013
if (!evaluateTargetConformanceTo(decodableProto))
90149014
(void)evaluateTargetConformanceTo(encodableProto);
9015+
} else if (baseName.getIdentifier() == Context.Id_allCases ||
9016+
baseName.getIdentifier() == Context.Id_AllCases) {
9017+
// If the target should conform to the CaseIterable protocol, check the
9018+
// conformance here to attempt synthesis.
9019+
auto *caseIterableProto
9020+
= Context.getProtocol(KnownProtocolKind::CaseIterable);
9021+
(void)evaluateTargetConformanceTo(caseIterableProto);
90159022
}
90169023
} else {
90179024
auto argumentNames = member.getArgumentNames();

lib/Sema/TypeCheckProtocol.cpp

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4712,11 +4712,17 @@ ValueDecl *TypeChecker::deriveProtocolRequirement(DeclContext *DC,
47124712
return DerivedConformance::deriveRawRepresentable(*this, Decl,
47134713
TypeDecl, Requirement);
47144714

4715+
case KnownProtocolKind::CaseIterable:
4716+
return DerivedConformance::deriveCaseIterable(*this, Decl,
4717+
TypeDecl, Requirement);
4718+
47154719
case KnownProtocolKind::Equatable:
4716-
return DerivedConformance::deriveEquatable(*this, Decl, TypeDecl, Requirement);
4720+
return DerivedConformance::deriveEquatable(*this, Decl, TypeDecl,
4721+
Requirement);
47174722

47184723
case KnownProtocolKind::Hashable:
4719-
return DerivedConformance::deriveHashable(*this, Decl, TypeDecl, Requirement);
4724+
return DerivedConformance::deriveHashable(*this, Decl, TypeDecl,
4725+
Requirement);
47204726

47214727
case KnownProtocolKind::BridgedNSError:
47224728
return DerivedConformance::deriveBridgedNSError(*this, Decl, TypeDecl,
@@ -4752,7 +4758,9 @@ Type TypeChecker::deriveTypeWitness(DeclContext *DC,
47524758
case KnownProtocolKind::RawRepresentable:
47534759
return DerivedConformance::deriveRawRepresentable(*this, Decl,
47544760
TypeDecl, AssocType);
4755-
4761+
case KnownProtocolKind::CaseIterable:
4762+
return DerivedConformance::deriveCaseIterable(*this, Decl,
4763+
TypeDecl, AssocType);
47564764
default:
47574765
return nullptr;
47584766
}

0 commit comments

Comments
 (0)