Skip to content

Commit 64b691f

Browse files
authored
Merge pull request #29005 from DougGregor/functional-pattern-checking
[Type checker] Make typeCheckPattern() a functional request
2 parents ad79c2b + 6c74b33 commit 64b691f

File tree

14 files changed

+453
-197
lines changed

14 files changed

+453
-197
lines changed

include/swift/AST/Pattern.h

Lines changed: 87 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -756,7 +756,93 @@ inline Pattern *Pattern::getSemanticsProvidingPattern() {
756756
return vp->getSubPattern()->getSemanticsProvidingPattern();
757757
return this;
758758
}
759-
759+
760+
/// Describes a pattern and the context in which it occurs.
761+
class ContextualPattern {
762+
/// The pattern and whether this is the top-level pattern.
763+
llvm::PointerIntPair<Pattern *, 1, bool> patternAndTopLevel;
764+
765+
/// Either the declaration context or the enclosing pattern binding
766+
/// declaration.
767+
llvm::PointerUnion<PatternBindingDecl *, DeclContext *> declOrContext;
768+
769+
/// Index into the pattern binding declaration, when there is one.
770+
unsigned index = 0;
771+
772+
ContextualPattern(
773+
Pattern *pattern, bool topLevel,
774+
llvm::PointerUnion<PatternBindingDecl *, DeclContext *> declOrContext,
775+
unsigned index
776+
) : patternAndTopLevel(pattern, topLevel),
777+
declOrContext(declOrContext),
778+
index(index) { }
779+
780+
public:
781+
/// Produce a contextual pattern for a pattern binding declaration entry.
782+
static ContextualPattern forPatternBindingDecl(
783+
PatternBindingDecl *pbd, unsigned index);
784+
785+
/// Produce a contextual pattern for a raw pattern that always allows
786+
/// inference.
787+
static ContextualPattern forRawPattern(Pattern *pattern, DeclContext *dc) {
788+
return ContextualPattern(pattern, /*topLevel=*/true, dc, /*index=*/0);
789+
}
790+
791+
/// Retrieve a contextual pattern for the given subpattern.
792+
ContextualPattern forSubPattern(
793+
Pattern *subpattern, bool retainTopLevel) const {
794+
return ContextualPattern(
795+
subpattern, isTopLevel() && retainTopLevel, declOrContext, index);
796+
}
797+
798+
/// Retrieve the pattern.
799+
Pattern *getPattern() const {
800+
return patternAndTopLevel.getPointer();
801+
}
802+
803+
/// Whether this is the top-level pattern in this context.
804+
bool isTopLevel() const {
805+
return patternAndTopLevel.getInt();
806+
}
807+
808+
/// Retrieve the declaration context of the pattern.
809+
DeclContext *getDeclContext() const;
810+
811+
/// Retrieve the pattern binding declaration that owns this pattern, if
812+
/// there is one.
813+
PatternBindingDecl *getPatternBindingDecl() const;
814+
815+
/// Retrieve the index into the pattern binding declaration for the top-level
816+
/// pattern.
817+
unsigned getPatternBindingIndex() const {
818+
assert(getPatternBindingDecl() != nullptr);
819+
return index;
820+
}
821+
822+
/// Whether this pattern allows type inference, e.g., from an initializer
823+
/// expression.
824+
bool allowsInference() const;
825+
826+
friend llvm::hash_code hash_value(const ContextualPattern &pattern) {
827+
return llvm::hash_combine(pattern.getPattern(),
828+
pattern.isTopLevel(),
829+
pattern.declOrContext);
830+
}
831+
832+
friend bool operator==(const ContextualPattern &lhs,
833+
const ContextualPattern &rhs) {
834+
return lhs.patternAndTopLevel == rhs.patternAndTopLevel &&
835+
lhs.declOrContext == rhs.declOrContext;
836+
}
837+
838+
friend bool operator!=(const ContextualPattern &lhs,
839+
const ContextualPattern &rhs) {
840+
return !(lhs == rhs);
841+
}
842+
};
843+
844+
void simple_display(llvm::raw_ostream &out, const ContextualPattern &pattern);
845+
760846
} // end namespace swift
761847

762848
#endif

include/swift/AST/TypeCheckRequests.h

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "swift/AST/GenericSignature.h"
2121
#include "swift/AST/Type.h"
2222
#include "swift/AST/Evaluator.h"
23+
#include "swift/AST/Pattern.h"
2324
#include "swift/AST/SimpleRequest.h"
2425
#include "swift/AST/TypeResolutionStage.h"
2526
#include "swift/Basic/AnyValue.h"
@@ -33,6 +34,7 @@ namespace swift {
3334
class AbstractStorageDecl;
3435
class AccessorDecl;
3536
enum class AccessorKind;
37+
class ContextualPattern;
3638
class DefaultArgumentExpr;
3739
class GenericParamList;
3840
class PrecedenceGroupDecl;
@@ -1991,6 +1993,32 @@ class HasDynamicMemberLookupAttributeRequest
19911993
}
19921994
};
19931995

1996+
/// Determines the type of a given pattern.
1997+
///
1998+
/// Note that this returns the "raw" pattern type, which can involve
1999+
/// unresolved types and unbound generic types where type inference is
2000+
/// allowed.
2001+
class PatternTypeRequest
2002+
: public SimpleRequest<PatternTypeRequest, Type(ContextualPattern),
2003+
CacheKind::Cached> {
2004+
public:
2005+
using SimpleRequest::SimpleRequest;
2006+
2007+
private:
2008+
friend SimpleRequest;
2009+
2010+
// Evaluation.
2011+
llvm::Expected<Type> evaluate(
2012+
Evaluator &evaluator, ContextualPattern pattern) const;
2013+
2014+
public:
2015+
bool isCached() const { return true; }
2016+
2017+
SourceLoc getNearestLoc() const {
2018+
return std::get<0>(getStorage()).getPattern()->getLoc();
2019+
}
2020+
};
2021+
19942022
// Allow AnyValue to compare two Type values, even though Type doesn't
19952023
// support ==.
19962024
template<>

include/swift/AST/TypeCheckerTypeIDZone.def

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,3 +216,6 @@ SWIFT_REQUEST(TypeChecker, TypeWitnessRequest,
216216
SWIFT_REQUEST(TypeChecker, ValueWitnessRequest,
217217
Witness(NormalProtocolConformance *, ValueDecl *),
218218
SeparatelyCached, NoLocationInfo)
219+
SWIFT_REQUEST(TypeChecker, PatternTypeRequest,
220+
Type(ContextualPattern),
221+
Cached, HasNearestLocation)

lib/AST/Pattern.cpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -496,3 +496,33 @@ const UnifiedStatsReporter::TraceFormatter*
496496
FrontendStatsTracer::getTraceFormatter<const Pattern *>() {
497497
return &TF;
498498
}
499+
500+
501+
ContextualPattern ContextualPattern::forPatternBindingDecl(
502+
PatternBindingDecl *pbd, unsigned index) {
503+
return ContextualPattern(
504+
pbd->getPattern(index), /*isTopLevel=*/true, pbd, index);
505+
}
506+
507+
DeclContext *ContextualPattern::getDeclContext() const {
508+
if (auto pbd = getPatternBindingDecl())
509+
return pbd->getDeclContext();
510+
511+
return declOrContext.get<DeclContext *>();
512+
}
513+
514+
PatternBindingDecl *ContextualPattern::getPatternBindingDecl() const {
515+
return declOrContext.dyn_cast<PatternBindingDecl *>();
516+
}
517+
518+
bool ContextualPattern::allowsInference() const {
519+
if (auto pbd = getPatternBindingDecl())
520+
return pbd->isInitialized(index);
521+
522+
return true;
523+
}
524+
525+
void swift::simple_display(llvm::raw_ostream &out,
526+
const ContextualPattern &pattern) {
527+
out << "(pattern @ " << pattern.getPattern() << ")";
528+
}

lib/Sema/CSGen.cpp

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2158,10 +2158,11 @@ namespace {
21582158
}
21592159

21602160
case PatternKind::Typed: {
2161-
auto typedPattern = cast<TypedPattern>(pattern);
21622161
// FIXME: Need a better locator for a pattern as a base.
2163-
Type openedType = CS.openUnboundGenericType(typedPattern->getType(),
2164-
locator);
2162+
auto contextualPattern =
2163+
ContextualPattern::forRawPattern(pattern, CurDC);
2164+
Type type = TypeChecker::typeCheckPattern(contextualPattern);
2165+
Type openedType = CS.openUnboundGenericType(type, locator);
21652166

21662167
// For a typed pattern, simply return the opened type of the pattern.
21672168
// FIXME: Error recovery if the type is an error type?
@@ -2258,6 +2259,7 @@ namespace {
22582259
// or exhaustive catches.
22592260
class FindInnerThrows : public ASTWalker {
22602261
ConstraintSystem &CS;
2262+
DeclContext *DC;
22612263
bool FoundThrow = false;
22622264

22632265
std::pair<bool, Expr *> walkToExprPre(Expr *expr) override {
@@ -2343,12 +2345,12 @@ namespace {
23432345
Type exnType = CS.getASTContext().getErrorDecl()->getDeclaredType();
23442346
if (!exnType)
23452347
return false;
2346-
if (TypeChecker::coercePatternToType(pattern,
2347-
TypeResolution::forContextual(CS.DC),
2348-
exnType,
2349-
TypeResolverContext::InExpression)) {
2348+
auto contextualPattern =
2349+
ContextualPattern::forRawPattern(pattern, DC);
2350+
pattern = TypeChecker::coercePatternToType(
2351+
contextualPattern, exnType, TypeResolverContext::InExpression);
2352+
if (!pattern)
23502353
return false;
2351-
}
23522354

23532355
clause->setErrorPattern(pattern);
23542356
return clause->isSyntacticallyExhaustive();
@@ -2384,7 +2386,8 @@ namespace {
23842386
}
23852387

23862388
public:
2387-
FindInnerThrows(ConstraintSystem &cs) : CS(cs) {}
2389+
FindInnerThrows(ConstraintSystem &cs, DeclContext *dc)
2390+
: CS(cs), DC(dc) {}
23882391

23892392
bool foundThrow() { return FoundThrow; }
23902393
};
@@ -2397,7 +2400,7 @@ namespace {
23972400
if (!body)
23982401
return false;
23992402

2400-
auto tryFinder = FindInnerThrows(CS);
2403+
auto tryFinder = FindInnerThrows(CS, expr);
24012404
body->walk(tryFinder);
24022405
return tryFinder.foundThrow();
24032406
}

0 commit comments

Comments
 (0)