Skip to content

Commit 1022d8c

Browse files
committed
[AST/Sema] changes for flattened functions
1 parent 3b015dd commit 1022d8c

File tree

11 files changed

+209
-115
lines changed

11 files changed

+209
-115
lines changed

include/swift/AST/Types.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2297,6 +2297,16 @@ class AnyFunctionType : public TypeBase {
22972297
return getExtInfo().throws();
22982298
}
22992299

2300+
unsigned getCurryLevel() const {
2301+
unsigned Level = 0;
2302+
const AnyFunctionType *function = this;
2303+
while ((function = function->getResult()->getAs<AnyFunctionType>()))
2304+
++Level;
2305+
return Level;
2306+
}
2307+
2308+
AnyFunctionType *getUncurriedFunction();
2309+
23002310
/// Returns a new function type exactly like this one but with the ExtInfo
23012311
/// replaced.
23022312
AnyFunctionType *withExtInfo(ExtInfo info) const;

lib/AST/ASTContext.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3131,6 +3131,32 @@ AnyFunctionType *AnyFunctionType::withExtInfo(ExtInfo info) const {
31313131
llvm_unreachable("unhandled function type");
31323132
}
31333133

3134+
AnyFunctionType *AnyFunctionType::getUncurriedFunction() {
3135+
assert(getCurryLevel() > 0 && "nothing to uncurry");
3136+
3137+
auto innerFunction = getResult()->castTo<AnyFunctionType>();
3138+
SmallVector<TupleTypeElt, 4> params{getInput()->getDesugaredType()};
3139+
3140+
if (auto tuple = innerFunction->getInput()->getAs<TupleType>())
3141+
params.append(tuple->getElements().begin(), tuple->getElements().end());
3142+
else
3143+
params.push_back(innerFunction->getInput()->getDesugaredType());
3144+
3145+
auto inputType = TupleType::get(params, getASTContext());
3146+
auto extInfo =
3147+
innerFunction->getExtInfo().withRepresentation(getRepresentation());
3148+
3149+
if (auto generic = getAs<GenericFunctionType>())
3150+
return GenericFunctionType::get(generic->getGenericSignature(), inputType,
3151+
innerFunction->getResult(), extInfo);
3152+
3153+
if (auto poly = getAs<PolymorphicFunctionType>())
3154+
return PolymorphicFunctionType::get(inputType, innerFunction->getResult(),
3155+
&poly->getGenericParams(), extInfo);
3156+
3157+
return FunctionType::get(inputType, innerFunction->getResult(), extInfo);
3158+
}
3159+
31343160
FunctionType *FunctionType::get(Type Input, Type Result,
31353161
const ExtInfo &Info) {
31363162
auto properties = getFunctionRecursiveProperties(Input, Result);

lib/Sema/CSApply.cpp

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1024,6 +1024,16 @@ namespace {
10241024
Expr *result = new (context) DotSyntaxBaseIgnoredExpr(base, dotLoc,
10251025
ref);
10261026
closeExistential(result, /*force=*/openedExistential);
1027+
1028+
if (!isa<FuncDecl>(member))
1029+
return result;
1030+
1031+
auto newTy = result->getType()
1032+
->castTo<AnyFunctionType>()
1033+
->getUncurriedFunction();
1034+
result = coerceToType(
1035+
result, newTy, locator.withPathElement(ConstraintLocator::Member));
1036+
cast<FunctionConversionExpr>(result)->setFlattening();
10271037
return result;
10281038
} else {
10291039
assert((!baseIsInstance || member->isInstanceMember()) &&
@@ -4637,6 +4647,8 @@ static bool isReferenceToMetatypeMember(Expr *expr) {
46374647
return dotIgnored->getLHS()->getType()->is<AnyMetatypeType>();
46384648
if (auto dotSyntax = dyn_cast<DotSyntaxCallExpr>(expr))
46394649
return dotSyntax->getBase()->getType()->is<AnyMetatypeType>();
4650+
if (auto conversion = dyn_cast<FunctionConversionExpr>(expr))
4651+
return isReferenceToMetatypeMember(conversion->getSubExpr());
46404652
return false;
46414653
}
46424654

@@ -6952,4 +6964,3 @@ Expr *Solution::convertOptionalToBool(Expr *expr,
69526964
isSomeExpr->setType(tc.lookupBoolType(cs.DC));
69536965
return isSomeExpr;
69546966
}
6955-

lib/Sema/CSSimplify.cpp

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -889,6 +889,13 @@ ConstraintSystem::SolutionKind
889889
ConstraintSystem::matchFunctionTypes(FunctionType *func1, FunctionType *func2,
890890
TypeMatchKind kind, unsigned flags,
891891
ConstraintLocatorBuilder locator) {
892+
if (flags & TMF_FlattenFunction) {
893+
// If we want a flattened function, remove a level of currying and continue.
894+
return matchFunctionTypes(
895+
func1->getUncurriedFunction()->castTo<FunctionType>(), func2, kind,
896+
flags & ~TMF_FlattenFunction, locator);
897+
}
898+
892899
// An @autoclosure function type can be a subtype of a
893900
// non-@autoclosure function type.
894901
if (func1->isAutoClosure() != func2->isAutoClosure() &&
@@ -2992,19 +2999,24 @@ performMemberLookup(ConstraintKind constraintKind, DeclName memberName,
29922999
}
29933000
}
29943001

3002+
bool isFlattened =
3003+
isMetatype && cand->isInstanceMember() && isa<FuncDecl>(cand);
3004+
29953005
// If we're looking into an existential type, check whether this
29963006
// result was found via dynamic lookup.
29973007
if (isDynamicLookup) {
29983008
assert(cand->getDeclContext()->isTypeContext() && "Dynamic lookup bug");
29993009

30003010
// We found this declaration via dynamic lookup, record it as such.
3001-
result.addViable(OverloadChoice::getDeclViaDynamic(baseTy, cand));
3011+
result.addViable(
3012+
OverloadChoice::getDeclViaDynamic(baseTy, cand, isFlattened));
30023013
return;
30033014
}
30043015

30053016
// If we have a bridged type, we found this declaration via bridging.
30063017
if (isBridged) {
3007-
result.addViable(OverloadChoice::getDeclViaBridge(bridgedType, cand));
3018+
result.addViable(
3019+
OverloadChoice::getDeclViaBridge(bridgedType, cand, isFlattened));
30083020
return;
30093021
}
30103022

@@ -3015,10 +3027,11 @@ performMemberLookup(ConstraintKind constraintKind, DeclName memberName,
30153027
ovlBaseTy = MetatypeType::get(baseTy->castTo<MetatypeType>()
30163028
->getInstanceType()
30173029
->getAnyOptionalObjectType());
3018-
result.addViable(OverloadChoice::getDeclViaUnwrappedOptional(ovlBaseTy,
3019-
cand));
3030+
result.addViable(OverloadChoice::getDeclViaUnwrappedOptional(
3031+
ovlBaseTy, cand, isFlattened));
30203032
} else {
3021-
result.addViable(OverloadChoice(ovlBaseTy, cand, *this));
3033+
result.addViable(OverloadChoice(ovlBaseTy, cand, *this,
3034+
/*isSpecialized=*/false, isFlattened));
30223035
}
30233036
};
30243037

@@ -4328,9 +4341,10 @@ ConstraintSystem::simplifyConstraint(const Constraint &constraint) {
43284341
return result;
43294342
}
43304343

4331-
return matchTypes(constraint.getFirstType(), constraint.getSecondType(),
4332-
matchKind,
4333-
TMF_None, constraint.getLocator());
4344+
return matchTypes(
4345+
constraint.getFirstType(), constraint.getSecondType(), matchKind,
4346+
constraint.isFunctionFlattening() ? TMF_FlattenFunction : TMF_None,
4347+
constraint.getLocator());
43344348
}
43354349

43364350
case ConstraintKind::ApplicableFunction:

lib/Sema/Constraint.cpp

Lines changed: 24 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -27,24 +27,23 @@ using namespace swift;
2727
using namespace constraints;
2828

2929
Constraint::Constraint(ConstraintKind kind, ArrayRef<Constraint *> constraints,
30-
ConstraintLocator *locator,
30+
ConstraintLocator *locator,
3131
ArrayRef<TypeVariableType *> typeVars)
32-
: Kind(kind), HasRestriction(false), HasFix(false), IsActive(false),
33-
RememberChoice(false), IsFavored(false), NumTypeVariables(typeVars.size()),
34-
Nested(constraints), Locator(locator)
35-
{
32+
: Kind(kind), HasRestriction(false), HasFix(false), IsActive(false),
33+
RememberChoice(false), IsFavored(false), IsFunctionFlattening(false),
34+
NumTypeVariables(typeVars.size()), Nested(constraints), Locator(locator) {
3635
assert(kind == ConstraintKind::Disjunction);
3736
std::uninitialized_copy(typeVars.begin(), typeVars.end(),
3837
getTypeVariablesBuffer().begin());
3938
}
4039

41-
Constraint::Constraint(ConstraintKind Kind, Type First, Type Second,
40+
Constraint::Constraint(ConstraintKind Kind, Type First, Type Second,
4241
DeclName Member, ConstraintLocator *locator,
4342
ArrayRef<TypeVariableType *> typeVars)
44-
: Kind(Kind), HasRestriction(false), HasFix(false), IsActive(false),
45-
RememberChoice(false), IsFavored(false), NumTypeVariables(typeVars.size()),
46-
Types { First, Second, Member }, Locator(locator)
47-
{
43+
: Kind(Kind), HasRestriction(false), HasFix(false), IsActive(false),
44+
RememberChoice(false), IsFavored(false), IsFunctionFlattening(false),
45+
NumTypeVariables(typeVars.size()), Types{First, Second, Member},
46+
Locator(locator) {
4847
switch (Kind) {
4948
case ConstraintKind::Bind:
5049
case ConstraintKind::Equal:
@@ -100,26 +99,24 @@ Constraint::Constraint(ConstraintKind Kind, Type First, Type Second,
10099
std::copy(typeVars.begin(), typeVars.end(), getTypeVariablesBuffer().begin());
101100
}
102101

103-
Constraint::Constraint(Type type, OverloadChoice choice,
102+
Constraint::Constraint(Type type, OverloadChoice choice,
104103
ConstraintLocator *locator,
105104
ArrayRef<TypeVariableType *> typeVars)
106-
: Kind(ConstraintKind::BindOverload),
107-
HasRestriction(false), HasFix(false), IsActive(false),
108-
RememberChoice(false), IsFavored(false), NumTypeVariables(typeVars.size()),
109-
Overload{type, choice}, Locator(locator)
110-
{
105+
: Kind(ConstraintKind::BindOverload), HasRestriction(false), HasFix(false),
106+
IsActive(false), RememberChoice(false), IsFavored(false),
107+
IsFunctionFlattening(false), NumTypeVariables(typeVars.size()),
108+
Overload{type, choice}, Locator(locator) {
111109
std::copy(typeVars.begin(), typeVars.end(), getTypeVariablesBuffer().begin());
112110
}
113111

114-
Constraint::Constraint(ConstraintKind kind,
112+
Constraint::Constraint(ConstraintKind kind,
115113
ConversionRestrictionKind restriction,
116114
Type first, Type second, ConstraintLocator *locator,
117115
ArrayRef<TypeVariableType *> typeVars)
118-
: Kind(kind), Restriction(restriction),
119-
HasRestriction(true), HasFix(false), IsActive(false),
120-
RememberChoice(false), IsFavored(false), NumTypeVariables(typeVars.size()),
121-
Types{ first, second, Identifier() }, Locator(locator)
122-
{
116+
: Kind(kind), Restriction(restriction), HasRestriction(true), HasFix(false),
117+
IsActive(false), RememberChoice(false), IsFavored(false),
118+
IsFunctionFlattening(false), NumTypeVariables(typeVars.size()),
119+
Types{first, second, Identifier()}, Locator(locator) {
123120
assert(!first.isNull());
124121
assert(!second.isNull());
125122
std::copy(typeVars.begin(), typeVars.end(), getTypeVariablesBuffer().begin());
@@ -128,12 +125,11 @@ Constraint::Constraint(ConstraintKind kind,
128125
Constraint::Constraint(ConstraintKind kind, Fix fix,
129126
Type first, Type second, ConstraintLocator *locator,
130127
ArrayRef<TypeVariableType *> typeVars)
131-
: Kind(kind), TheFix(fix.getKind()), FixData(fix.getData()),
132-
HasRestriction(false), HasFix(true),
133-
IsActive(false), RememberChoice(false), IsFavored(false),
134-
NumTypeVariables(typeVars.size()),
135-
Types{ first, second, Identifier() }, Locator(locator)
136-
{
128+
: Kind(kind), TheFix(fix.getKind()), FixData(fix.getData()),
129+
HasRestriction(false), HasFix(true), IsActive(false),
130+
RememberChoice(false), IsFavored(false), IsFunctionFlattening(false),
131+
NumTypeVariables(typeVars.size()), Types{first, second, Identifier()},
132+
Locator(locator) {
137133
assert(!first.isNull());
138134
assert(!second.isNull());
139135
std::copy(typeVars.begin(), typeVars.end(), getTypeVariablesBuffer().begin());

lib/Sema/Constraint.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,8 @@ class Constraint final : public llvm::ilist_node<Constraint>,
319319
/// in its disjunction.
320320
unsigned IsFavored : 1;
321321

322+
unsigned IsFunctionFlattening : 1;
323+
322324
/// The number of type variables referenced by this constraint.
323325
///
324326
/// The type variables themselves are tail-allocated.
@@ -446,6 +448,9 @@ class Constraint final : public llvm::ilist_node<Constraint>,
446448
/// this constraint.
447449
bool shouldRememberChoice() const { return RememberChoice; }
448450

451+
void setFunctionFlattening() { IsFunctionFlattening = true; }
452+
bool isFunctionFlattening() const { return IsFunctionFlattening; }
453+
449454
/// Retrieve the set of type variables referenced by this constraint.
450455
ArrayRef<TypeVariableType *> getTypeVariables() const {
451456
return {getTrailingObjects<TypeVariableType*>(), NumTypeVariables};

lib/Sema/ConstraintSystem.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1430,6 +1430,9 @@ void ConstraintSystem::resolveOverload(ConstraintLocator *locator,
14301430
= getTypeOfMemberReference(choice.getBaseType(), choice.getDecl(),
14311431
isTypeReference, isDynamicResult,
14321432
locator, base, nullptr);
1433+
1434+
if (choice.isFlattened())
1435+
refType = refType->castTo<AnyFunctionType>()->getUncurriedFunction();
14331436
} else {
14341437
std::tie(openedFullType, refType)
14351438
= getTypeOfReference(choice.getDecl(), isTypeReference,

0 commit comments

Comments
 (0)