Skip to content

Commit f724d1f

Browse files
authored
[SE-0280] Enum cases as protocol witnesses (#28916)
* [Typechecker] Allow enum cases without payload to witness a static get-only property with Self type protocol requirement * [SIL] Add support for payload cases as well * [SILGen] Clean up comment * [Typechecker] Re-enable some previously disabled witness matching code Also properly handle the matching in some cases * [Test] Update typechecker tests with payload enum test cases * [Test] Update SILGen test * [SIL] Add two FIXME's to address soon * [SIL] Emit the enum case constructor unconditionally when an enum case is used as a witness Also, tweak SILDeclRef::getLinkage to update the 'limit' to 'OnDemand' if we have an enum declaration * [SILGen] Properly handle a enum witness in addMethodImplementation Also remove a FIXME and code added to workaround the original bug * [TBDGen] Handle enum case witness * [Typechecker] Fix conflicts * [Test] Fix tests * [AST] Fix indentation in diagnostics def file
1 parent bbf94fb commit f724d1f

File tree

11 files changed

+322
-32
lines changed

11 files changed

+322
-32
lines changed

include/swift/AST/DiagnosticsSema.def

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1889,8 +1889,9 @@ ERROR(requirement_restricts_self,none,
18891889
"'Self'",
18901890
(DescriptiveDeclKind, DeclName, StringRef, unsigned, StringRef))
18911891
ERROR(witness_argument_name_mismatch,none,
1892-
"%select{method|initializer}0 %1 has different argument labels from those "
1893-
"required by protocol %2 (%3)", (bool, DeclName, Type, DeclName))
1892+
"%0 %1 has different argument labels "
1893+
"from those required by protocol %2 (%3)",
1894+
(DescriptiveDeclKind, DeclName, Type, DeclName))
18941895
ERROR(witness_initializer_not_required,none,
18951896
"initializer requirement %0 can only be satisfied by a 'required' "
18961897
"initializer in%select{| the definition of}1 non-final class %2",
@@ -2116,6 +2117,9 @@ NOTE(protocol_witness_throws_conflict,none,
21162117
"candidate throws, but protocol does not allow it", ())
21172118
NOTE(protocol_witness_not_objc,none,
21182119
"candidate is explicitly '@nonobjc'", ())
2120+
NOTE(protocol_witness_enum_case_payload, none,
2121+
"candidate is an enum case with associated values, "
2122+
"but protocol does not allow it", ())
21192123

21202124
NOTE(protocol_witness_type,none,
21212125
"possibly intended match", ())

include/swift/SILOptimizer/Utils/InstOptUtils.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -499,6 +499,10 @@ bool calleesAreStaticallyKnowable(SILModule &module, SILDeclRef decl);
499499
/// be reached by calling the function represented by Decl?
500500
bool calleesAreStaticallyKnowable(SILModule &module, AbstractFunctionDecl *afd);
501501

502+
/// Do we have enough information to determine all callees that could
503+
/// be reached by calling the function represented by Decl?
504+
bool calleesAreStaticallyKnowable(SILModule &module, EnumElementDecl *eed);
505+
502506
// Attempt to get the instance for , whose static type is the same as
503507
// its exact dynamic type, returning a null SILValue() if we cannot find it.
504508
// The information that a static type is the same as the exact dynamic,

lib/SIL/SILDeclRef.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -334,7 +334,7 @@ SILLinkage SILDeclRef::getLinkage(ForDefinition_t forDefinition) const {
334334
limit = Limit::OnDemand;
335335
}
336336
}
337-
337+
338338
auto effectiveAccess = d->getEffectiveAccess();
339339

340340
// Private setter implementations for an internal storage declaration should

lib/SILGen/SILGenType.cpp

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -387,11 +387,25 @@ template<typename T> class SILGenWitnessTable : public SILWitnessVisitor<T> {
387387

388388
public:
389389
void addMethod(SILDeclRef requirementRef) {
390-
auto reqAccessor = dyn_cast<AccessorDecl>(requirementRef.getDecl());
390+
auto reqDecl = requirementRef.getDecl();
391+
392+
// Static functions can be witnessed by enum cases with payload
393+
if (!(isa<AccessorDecl>(reqDecl) || isa<ConstructorDecl>(reqDecl))) {
394+
auto FD = cast<FuncDecl>(reqDecl);
395+
if (auto witness = asDerived().getWitness(FD)) {
396+
if (auto EED = dyn_cast<EnumElementDecl>(witness.getDecl())) {
397+
return addMethodImplementation(
398+
requirementRef, SILDeclRef(EED, SILDeclRef::Kind::EnumElement),
399+
witness);
400+
}
401+
}
402+
}
403+
404+
auto reqAccessor = dyn_cast<AccessorDecl>(reqDecl);
391405

392406
// If it's not an accessor, just look for the witness.
393407
if (!reqAccessor) {
394-
if (auto witness = asDerived().getWitness(requirementRef.getDecl())) {
408+
if (auto witness = asDerived().getWitness(reqDecl)) {
395409
return addMethodImplementation(
396410
requirementRef, requirementRef.withDecl(witness.getDecl()),
397411
witness);
@@ -406,6 +420,13 @@ template<typename T> class SILGenWitnessTable : public SILWitnessVisitor<T> {
406420
if (!witness)
407421
return asDerived().addMissingMethod(requirementRef);
408422

423+
// Static properties can be witnessed by enum cases without payload
424+
if (auto EED = dyn_cast<EnumElementDecl>(witness.getDecl())) {
425+
return addMethodImplementation(
426+
requirementRef, SILDeclRef(EED, SILDeclRef::Kind::EnumElement),
427+
witness);
428+
}
429+
409430
auto witnessStorage = cast<AbstractStorageDecl>(witness.getDecl());
410431
if (reqAccessor->isSetter() && !witnessStorage->supportsMutation())
411432
return asDerived().addMissingMethod(requirementRef);
@@ -566,6 +587,10 @@ class SILGenConformance : public SILGenWitnessTable<SILGenConformance> {
566587
witnessLinkage = SILLinkage::Shared;
567588
}
568589

590+
if (isa<EnumElementDecl>(witnessRef.getDecl())) {
591+
assert(witnessRef.isEnumElement() && "Witness decl, but different kind?");
592+
}
593+
569594
SILFunction *witnessFn = SGM.emitProtocolWitness(
570595
ProtocolConformanceRef(Conformance), witnessLinkage, witnessSerialized,
571596
requirementRef, witnessRef, isFree, witness);

lib/SILOptimizer/Utils/InstOptUtils.cpp

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1797,6 +1797,11 @@ bool swift::calleesAreStaticallyKnowable(SILModule &module, SILDeclRef decl) {
17971797
if (decl.isForeign)
17981798
return false;
17991799

1800+
if (decl.isEnumElement()) {
1801+
return calleesAreStaticallyKnowable(module,
1802+
cast<EnumElementDecl>(decl.getDecl()));
1803+
}
1804+
18001805
auto *afd = decl.getAbstractFunctionDecl();
18011806
assert(afd && "Expected abstract function decl!");
18021807
return calleesAreStaticallyKnowable(module, afd);
@@ -1845,6 +1850,41 @@ bool swift::calleesAreStaticallyKnowable(SILModule &module,
18451850
llvm_unreachable("Unhandled access level in switch.");
18461851
}
18471852

1853+
/// Are the callees that could be called through Decl statically
1854+
/// knowable based on the Decl and the compilation mode?
1855+
// FIXME: Merge this with calleesAreStaticallyKnowable above
1856+
bool swift::calleesAreStaticallyKnowable(SILModule &module,
1857+
EnumElementDecl *eed) {
1858+
const DeclContext *assocDC = module.getAssociatedContext();
1859+
if (!assocDC)
1860+
return false;
1861+
1862+
// Only handle members defined within the SILModule's associated context.
1863+
if (!eed->isChildContextOf(assocDC))
1864+
return false;
1865+
1866+
if (eed->isDynamic()) {
1867+
return false;
1868+
}
1869+
1870+
if (!eed->hasAccess())
1871+
return false;
1872+
1873+
// Only consider 'private' members, unless we are in whole-module compilation.
1874+
switch (eed->getEffectiveAccess()) {
1875+
case AccessLevel::Open:
1876+
return false;
1877+
case AccessLevel::Public:
1878+
case AccessLevel::Internal:
1879+
return module.isWholeModule();
1880+
case AccessLevel::FilePrivate:
1881+
case AccessLevel::Private:
1882+
return true;
1883+
}
1884+
1885+
llvm_unreachable("Unhandled access level in switch.");
1886+
}
1887+
18481888
Optional<FindLocalApplySitesResult>
18491889
swift::findLocalApplySites(FunctionRefBaseInst *fri) {
18501890
SmallVector<Operand *, 32> worklist(fri->use_begin(), fri->use_end());

lib/Sema/TypeCheckProtocol.cpp

Lines changed: 40 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -484,8 +484,20 @@ swift::matchWitness(
484484
assert(!req->isInvalid() && "Cannot have an invalid requirement here");
485485

486486
/// Make sure the witness is of the same kind as the requirement.
487-
if (req->getKind() != witness->getKind())
488-
return RequirementMatch(witness, MatchKind::KindConflict);
487+
if (req->getKind() != witness->getKind()) {
488+
// An enum case can witness:
489+
// 1. A static get-only property requirement, as long as the property's
490+
// type is `Self` or it matches the type of the enum explicitly.
491+
// 2. A static function requirement, if the enum case has a payload
492+
// and the payload types and labels match the function and the
493+
// function returns `Self` or the type of the enum.
494+
//
495+
// If there are any discrepencies, we'll diagnose it later. For now,
496+
// let's assume the match is valid.
497+
if (!((isa<VarDecl>(req) || isa<FuncDecl>(req)) &&
498+
isa<EnumElementDecl>(witness)))
499+
return RequirementMatch(witness, MatchKind::KindConflict);
500+
}
489501

490502
// If we're currently validating the witness, bail out.
491503
if (witness->isRecursiveValidation())
@@ -502,7 +514,8 @@ swift::matchWitness(
502514
// Perform basic matching of the requirement and witness.
503515
bool decomposeFunctionType = false;
504516
bool ignoreReturnType = false;
505-
if (auto funcReq = dyn_cast<FuncDecl>(req)) {
517+
if (isa<FuncDecl>(req) && isa<FuncDecl>(witness)) {
518+
auto funcReq = cast<FuncDecl>(req);
506519
auto funcWitness = cast<FuncDecl>(witness);
507520

508521
// Either both must be 'static' or neither.
@@ -564,6 +577,18 @@ swift::matchWitness(
564577
} else if (isa<ConstructorDecl>(witness)) {
565578
decomposeFunctionType = true;
566579
ignoreReturnType = true;
580+
} else if (isa<EnumElementDecl>(witness)) {
581+
auto enumCase = cast<EnumElementDecl>(witness);
582+
if (enumCase->hasAssociatedValues() && isa<VarDecl>(req))
583+
return RequirementMatch(witness, MatchKind::EnumCaseWithAssociatedValues);
584+
auto isValid = isa<VarDecl>(req) || isa<FuncDecl>(req);
585+
if (!isValid)
586+
return RequirementMatch(witness, MatchKind::KindConflict);
587+
if (!cast<ValueDecl>(req)->isStatic())
588+
return RequirementMatch(witness, MatchKind::StaticNonStaticConflict);
589+
if (isa<VarDecl>(req) &&
590+
cast<VarDecl>(req)->getParsedAccessor(AccessorKind::Set))
591+
return RequirementMatch(witness, MatchKind::SettableConflict);
567592
}
568593

569594
// If the requirement is @objc, the witness must not be marked with @nonobjc.
@@ -2182,7 +2207,8 @@ diagnoseMatch(ModuleDecl *module, NormalProtocolConformance *conformance,
21822207
if (match.Kind != MatchKind::RenamedMatch &&
21832208
!match.Witness->getAttrs().hasAttribute<ImplementsAttr>() &&
21842209
match.Witness->getFullName() &&
2185-
req->getFullName() != match.Witness->getFullName())
2210+
req->getFullName() != match.Witness->getFullName() &&
2211+
!isa<EnumElementDecl>(match.Witness))
21862212
return;
21872213

21882214
// Form a string describing the associated type deductions.
@@ -2234,7 +2260,7 @@ diagnoseMatch(ModuleDecl *module, NormalProtocolConformance *conformance,
22342260
break;
22352261

22362262
case MatchKind::TypeConflict: {
2237-
if (!isa<TypeDecl>(req)) {
2263+
if (!isa<TypeDecl>(req) && !isa<EnumElementDecl>(match.Witness)) {
22382264
computeFixitsForOverridenDeclaration(match.Witness, req, [&](bool){
22392265
return diags.diagnose(match.Witness,
22402266
diag::protocol_witness_type_conflict,
@@ -2278,6 +2304,8 @@ diagnoseMatch(ModuleDecl *module, NormalProtocolConformance *conformance,
22782304
auto witness = match.Witness;
22792305
auto diag = diags.diagnose(witness, diag::protocol_witness_static_conflict,
22802306
!req->isInstanceMember());
2307+
if (isa<EnumElementDecl>(witness))
2308+
break;
22812309
if (req->isInstanceMember()) {
22822310
SourceLoc loc;
22832311
if (auto FD = dyn_cast<FuncDecl>(witness)) {
@@ -2403,6 +2431,9 @@ diagnoseMatch(ModuleDecl *module, NormalProtocolConformance *conformance,
24032431
}
24042432
break;
24052433
}
2434+
case MatchKind::EnumCaseWithAssociatedValues:
2435+
diags.diagnose(match.Witness, diag::protocol_witness_enum_case_payload);
2436+
break;
24062437
}
24072438
}
24082439

@@ -3325,12 +3356,10 @@ ConformanceChecker::resolveWitnessViaLookup(ValueDecl *requirement) {
33253356
auto &diags = proto->getASTContext().Diags;
33263357
{
33273358
SourceLoc diagLoc = getLocForDiagnosingWitness(conformance,witness);
3328-
auto diag = diags.diagnose(diagLoc,
3329-
diag::witness_argument_name_mismatch,
3330-
isa<ConstructorDecl>(witness),
3331-
witness->getFullName(),
3332-
proto->getDeclaredType(),
3333-
requirement->getFullName());
3359+
auto diag = diags.diagnose(
3360+
diagLoc, diag::witness_argument_name_mismatch,
3361+
witness->getDescriptiveKind(), witness->getFullName(),
3362+
proto->getDeclaredType(), requirement->getFullName());
33343363
if (diagLoc == witness->getLoc()) {
33353364
fixDeclarationName(diag, witness, requirement->getFullName());
33363365
} else {

lib/Sema/TypeCheckProtocol.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,10 @@ enum class MatchKind : uint8_t {
211211

212212
/// The witness is missing a `@differentiable` attribute from the requirement.
213213
MissingDifferentiableAttr,
214+
215+
/// The witness did not match because it is an enum case with
216+
/// associated values.
217+
EnumCaseWithAssociatedValues,
214218
};
215219

216220
/// Describes the kind of optional adjustment performed when
@@ -437,6 +441,7 @@ struct RequirementMatch {
437441
case MatchKind::ThrowsConflict:
438442
case MatchKind::NonObjC:
439443
case MatchKind::MissingDifferentiableAttr:
444+
case MatchKind::EnumCaseWithAssociatedValues:
440445
return false;
441446
}
442447

@@ -467,6 +472,7 @@ struct RequirementMatch {
467472
case MatchKind::ThrowsConflict:
468473
case MatchKind::NonObjC:
469474
case MatchKind::MissingDifferentiableAttr:
475+
case MatchKind::EnumCaseWithAssociatedValues:
470476
return false;
471477
}
472478

lib/TBDGen/TBDGen.cpp

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -482,21 +482,23 @@ void TBDGenVisitor::addConformances(DeclContext *DC) {
482482
}
483483
};
484484

485-
rootConformance->forEachValueWitness(
486-
[&](ValueDecl *valueReq, Witness witness) {
487-
auto witnessDecl = witness.getDecl();
488-
if (isa<AbstractFunctionDecl>(valueReq)) {
489-
addSymbolIfNecessary(valueReq, witnessDecl);
490-
} else if (auto *storage = dyn_cast<AbstractStorageDecl>(valueReq)) {
491-
auto witnessStorage = cast<AbstractStorageDecl>(witnessDecl);
492-
storage->visitOpaqueAccessors([&](AccessorDecl *reqtAccessor) {
493-
auto witnessAccessor =
494-
witnessStorage->getSynthesizedAccessor(
495-
reqtAccessor->getAccessorKind());
496-
addSymbolIfNecessary(reqtAccessor, witnessAccessor);
497-
});
498-
}
499-
});
485+
rootConformance->forEachValueWitness([&](ValueDecl *valueReq,
486+
Witness witness) {
487+
auto witnessDecl = witness.getDecl();
488+
if (isa<AbstractFunctionDecl>(valueReq)) {
489+
addSymbolIfNecessary(valueReq, witnessDecl);
490+
} else if (auto *storage = dyn_cast<AbstractStorageDecl>(valueReq)) {
491+
if (auto witnessStorage = dyn_cast<AbstractStorageDecl>(witnessDecl)) {
492+
storage->visitOpaqueAccessors([&](AccessorDecl *reqtAccessor) {
493+
auto witnessAccessor = witnessStorage->getSynthesizedAccessor(
494+
reqtAccessor->getAccessorKind());
495+
addSymbolIfNecessary(reqtAccessor, witnessAccessor);
496+
});
497+
} else if (isa<EnumElementDecl>(witnessDecl)) {
498+
addSymbolIfNecessary(valueReq, witnessDecl);
499+
}
500+
}
501+
});
500502
}
501503
}
502504

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
// RUN: %target-swift-emit-silgen %s | %FileCheck %s
2+
3+
protocol Foo {
4+
static var button: Self { get }
5+
}
6+
7+
enum Bar: Foo {
8+
case button
9+
}
10+
11+
protocol AnotherFoo {
12+
static func bar(arg: Int) -> Self
13+
}
14+
15+
enum AnotherBar: AnotherFoo {
16+
case bar(arg: Int)
17+
}
18+
19+
// CHECK-LABEL: sil private [transparent] [thunk] [ossa] @$s21protocol_enum_witness3BarOAA3FooA2aDP6buttonxvgZTW : $@convention(witness_method: Foo) (@thick Bar.Type) -> @out Bar {
20+
// CHECK: bb0([[BAR:%.*]] : $*Bar, [[BAR_TYPE:%.*]] : $@thick Bar.Type):
21+
// CHECK-NEXT: [[META_TYPE:%.*]] = metatype $@thin Bar.Type
22+
// CHECK: [[REF:%.*]] = function_ref @$s21protocol_enum_witness3BarO6buttonyA2CmF : $@convention(method) (@thin Bar.Type) -> Bar
23+
// CHECK-NEXT: [[RESULT:%.*]] = apply [[REF]]([[META_TYPE]]) : $@convention(method) (@thin Bar.Type) -> Bar
24+
// CHECK-NEXT: store [[RESULT]] to [trivial] [[BAR]] : $*Bar
25+
// CHECK-NEXT: [[TUPLE:%.*]] = tuple ()
26+
// CHECK-NEXT: return [[TUPLE]] : $()
27+
// CHECK-END: }
28+
29+
// CHECK-LABEL: sil hidden [transparent] [ossa] @$s21protocol_enum_witness3BarO6buttonyA2CmF : $@convention(method) (@thin Bar.Type) -> Bar {
30+
// CHECK: bb0({{%.*}} : $@thin Bar.Type):
31+
// CHECK-NEXT: [[CASE:%.*]] = enum $Bar, #Bar.button!enumelt
32+
// CHECK-NEXT: return [[CASE]] : $Bar
33+
// CHECK-END: }
34+
35+
// CHECK-LABEL: sil private [transparent] [thunk] [ossa] @$s21protocol_enum_witness10AnotherBarOAA0D3FooA2aDP3bar3argxSi_tFZTW : $@convention(witness_method: AnotherFoo) (Int, @thick AnotherBar.Type) -> @out AnotherBar {
36+
// CHECK: bb0([[ANOTHER_BAR:%.*]] : $*AnotherBar, [[INT_ARG:%.*]] : $Int, [[ANOTHER_BAR_TYPE:%.*]] : $@thick AnotherBar.Type):
37+
// CHECK-NEXT: [[META_TYPE:%.*]] = metatype $@thin AnotherBar.Type
38+
// CHECK: [[REF:%.*]] = function_ref @$s21protocol_enum_witness10AnotherBarO3baryACSi_tcACmF : $@convention(method) (Int, @thin AnotherBar.Type) -> AnotherBar
39+
// CHECK-NEXT: [[RESULT:%.*]] = apply [[REF]]([[INT_ARG]], [[META_TYPE]]) : $@convention(method) (Int, @thin AnotherBar.Type) -> AnotherBar
40+
// CHECK-NEXT: store [[RESULT]] to [trivial] [[ANOTHER_BAR]] : $*AnotherBar
41+
// CHECK-NEXT: [[TUPLE:%.*]] = tuple ()
42+
// CHECK-NEXT: return [[TUPLE]] : $()
43+
// CHECK-END: }
44+
45+
// CHECK-LABEL: sil_witness_table hidden Bar: Foo module protocol_enum_witness {
46+
// CHECK: method #Foo.button!getter: <Self where Self : Foo> (Self.Type) -> () -> Self : @$s21protocol_enum_witness3BarOAA3FooA2aDP6buttonxvgZTW
47+
48+
// CHECK-LABEL: sil_witness_table hidden AnotherBar: AnotherFoo module protocol_enum_witness {
49+
// CHECK: method #AnotherFoo.bar: <Self where Self : AnotherFoo> (Self.Type) -> (Int) -> Self : @$s21protocol_enum_witness10AnotherBarOAA0D3FooA2aDP3bar3argxSi_tFZTW

0 commit comments

Comments
 (0)