Skip to content

Commit 8cfdb99

Browse files
authored
Merge pull request #41436 from xedin/allow-specialization-from-default-expr
[TypeChecker] Allow inference from default expressions in certain scenarios (under a flag)
2 parents 48de949 + eaa737c commit 8cfdb99

25 files changed

+994
-53
lines changed

include/swift/AST/Decl.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5504,6 +5504,7 @@ enum class ParamSpecifier : uint8_t {
55045504
class ParamDecl : public VarDecl {
55055505
friend class DefaultArgumentInitContextRequest;
55065506
friend class DefaultArgumentExprRequest;
5507+
friend class DefaultArgumentTypeRequest;
55075508

55085509
enum class ArgumentNameFlags : uint8_t {
55095510
/// Whether or not this parameter is destructed.
@@ -5524,6 +5525,9 @@ class ParamDecl : public VarDecl {
55245525
struct alignas(1 << StoredDefaultArgumentAlignInBits) StoredDefaultArgument {
55255526
PointerUnion<Expr *, VarDecl *> DefaultArg;
55265527

5528+
/// The type of the default argument expression.
5529+
Type ExprType;
5530+
55275531
/// Stores the context for the default argument as well as a bit to
55285532
/// indicate whether the default expression has been type-checked.
55295533
llvm::PointerIntPair<Initializer *, 1, bool> InitContextAndIsTypeChecked;
@@ -5641,6 +5645,10 @@ class ParamDecl : public VarDecl {
56415645
return nullptr;
56425646
}
56435647

5648+
/// Retrieve the type of the default expression (if any) associated with
5649+
/// this parameter declaration.
5650+
Type getTypeOfDefaultExpr() const;
5651+
56445652
VarDecl *getStoredProperty() const {
56455653
if (auto stored = DefaultValueAndFlags.getPointer())
56465654
return stored->DefaultArg.dyn_cast<VarDecl *>();
@@ -5655,6 +5663,10 @@ class ParamDecl : public VarDecl {
56555663
/// parameter's fully type-checked default argument.
56565664
void setDefaultExpr(Expr *E, bool isTypeChecked);
56575665

5666+
/// Sets a type of default expression associated with this parameter.
5667+
/// This should only be called by deserialization.
5668+
void setDefaultExprType(Type type);
5669+
56585670
void setStoredProperty(VarDecl *var);
56595671

56605672
/// Retrieve the initializer context for the parameter's default argument.

include/swift/AST/DiagnosticsSema.def

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ NOTE(extended_type_declared_here,none,
3838
"extended type declared here", ())
3939
NOTE(opaque_return_type_declared_here,none,
4040
"opaque return type declared here", ())
41+
NOTE(default_value_declared_here,none,
42+
"default value declared here", ())
4143

4244
//------------------------------------------------------------------------------
4345
// MARK: Constraint solver diagnostics
@@ -6240,5 +6242,28 @@ ERROR(type_sequence_on_non_generic_param, none,
62406242
"'@_typeSequence' must appear on a generic parameter",
62416243
())
62426244

6245+
//------------------------------------------------------------------------------
6246+
// MARK: Type inference from default expressions
6247+
//------------------------------------------------------------------------------
6248+
6249+
ERROR(cannot_default_generic_parameter_inferrable_from_another_parameter, none,
6250+
"cannot use default expression for inference of %0 because it "
6251+
"is inferrable from parameters %1",
6252+
(Type, StringRef))
6253+
6254+
ERROR(cannot_default_generic_parameter_inferrable_through_same_type, none,
6255+
"cannot use default expression for inference of %0 because it "
6256+
"is inferrable through same-type requirement: '%1'",
6257+
(Type, StringRef))
6258+
6259+
ERROR(cannot_default_generic_parameter_invalid_requirement, none,
6260+
"cannot use default expression for inference of %0 because "
6261+
"requirement '%1' refers to other generic parameters",
6262+
(Type, StringRef))
6263+
6264+
ERROR(cannot_convert_default_value_type_to_argument_type, none,
6265+
"cannot convert default value of type %0 to expected argument type %1 for parameter #%2",
6266+
(Type, Type, unsigned))
6267+
62436268
#define UNDEFINE_DIAGNOSTIC_MACROS
62446269
#include "DefineDiagnosticMacros.h"

include/swift/AST/TypeCheckRequests.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ class AccessorDecl;
4141
enum class AccessorKind;
4242
class ContextualPattern;
4343
class DefaultArgumentExpr;
44+
class DefaultArgumentType;
4445
class ClosureExpr;
4546
class GenericParamList;
4647
class PrecedenceGroupDecl;
@@ -2665,6 +2666,26 @@ class DefaultArgumentExprRequest
26652666
void cacheResult(Expr *expr) const;
26662667
};
26672668

2669+
/// Computes the type of the default expression for a given parameter.
2670+
class DefaultArgumentTypeRequest
2671+
: public SimpleRequest<DefaultArgumentTypeRequest, Type(ParamDecl *),
2672+
RequestFlags::SeparatelyCached> {
2673+
public:
2674+
using SimpleRequest::SimpleRequest;
2675+
2676+
private:
2677+
friend SimpleRequest;
2678+
2679+
// Evaluation.
2680+
Type evaluate(Evaluator &evaluator, ParamDecl *param) const;
2681+
2682+
public:
2683+
// Separate caching.
2684+
bool isCached() const { return true; }
2685+
Optional<Type> getCachedResult() const;
2686+
void cacheResult(Type type) const;
2687+
};
2688+
26682689
/// Computes the fully type-checked caller-side default argument within the
26692690
/// context of the call site that it will be inserted into.
26702691
class CallerSideDefaultArgExprRequest

include/swift/AST/TypeCheckerTypeIDZone.def

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ SWIFT_REQUEST(TypeChecker, CustomAttrTypeRequest,
5959
SeparatelyCached, NoLocationInfo)
6060
SWIFT_REQUEST(TypeChecker, DefaultArgumentExprRequest,
6161
Expr *(ParamDecl *), SeparatelyCached, NoLocationInfo)
62+
SWIFT_REQUEST(TypeChecker, DefaultArgumentTypeRequest,
63+
Type(ParamDecl *), SeparatelyCached, NoLocationInfo)
6264
SWIFT_REQUEST(TypeChecker, DefaultArgumentInitContextRequest,
6365
Initializer *(ParamDecl *), SeparatelyCached, NoLocationInfo)
6466
SWIFT_REQUEST(TypeChecker, DefaultDefinitionTypeRequest,

include/swift/Basic/LangOptions.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -718,6 +718,10 @@ namespace swift {
718718
/// closures.
719719
bool EnableMultiStatementClosureInference = false;
720720

721+
/// Enable experimental support for generic parameter inference in
722+
/// parameter positions from associated default expressions.
723+
bool EnableTypeInferenceFromDefaultArguments = false;
724+
721725
/// See \ref FrontendOptions.PrintFullConvention
722726
bool PrintFullConvention = false;
723727
};

include/swift/Option/FrontendOptions.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -861,6 +861,10 @@ def experimental_multi_statement_closures :
861861
Flag<["-"], "experimental-multi-statement-closures">,
862862
HelpText<"Enable experimental support for type inference in multi-statement closures">;
863863

864+
def experimental_type_inference_from_defaults :
865+
Flag<["-"], "enable-experimental-type-inference-from-defaults">,
866+
HelpText<"Enable experimental support for generic parameter inference from default values">;
867+
864868
def prebuilt_module_cache_path :
865869
Separate<["-"], "prebuilt-module-cache-path">,
866870
HelpText<"Directory of prebuilt modules for loading module interfaces">;

include/swift/Sema/CSFix.h

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -382,6 +382,10 @@ enum class FixKind : uint8_t {
382382

383383
/// Produce an error for not getting a compile-time constant
384384
NotCompileTimeConst,
385+
386+
/// Ignore a type mismatch while trying to infer generic parameter type
387+
/// from default expression.
388+
IgnoreDefaultExprTypeMismatch,
385389
};
386390

387391
class ConstraintFix {
@@ -2896,6 +2900,29 @@ class AllowSwiftToCPointerConversion final : public ConstraintFix {
28962900
ConstraintLocator *locator);
28972901
};
28982902

2903+
class IgnoreDefaultExprTypeMismatch : public AllowArgumentMismatch {
2904+
protected:
2905+
IgnoreDefaultExprTypeMismatch(ConstraintSystem &cs, Type argType,
2906+
Type paramType, ConstraintLocator *locator)
2907+
: AllowArgumentMismatch(cs, FixKind::IgnoreDefaultExprTypeMismatch,
2908+
argType, paramType, locator) {}
2909+
2910+
public:
2911+
std::string getName() const override {
2912+
return "allow default expression conversion mismatch";
2913+
}
2914+
2915+
bool diagnose(const Solution &solution, bool asNote = false) const override;
2916+
2917+
static IgnoreDefaultExprTypeMismatch *create(ConstraintSystem &cs,
2918+
Type argType, Type paramType,
2919+
ConstraintLocator *locator);
2920+
2921+
static bool classof(const ConstraintFix *fix) {
2922+
return fix->getKind() == FixKind::IgnoreDefaultExprTypeMismatch;
2923+
}
2924+
};
2925+
28992926
} // end namespace constraints
29002927
} // end namespace swift
29012928

include/swift/Sema/ConstraintSystem.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,8 @@ Optional<constraints::SolutionApplicationTarget>
8080
typeCheckExpression(constraints::SolutionApplicationTarget &target,
8181
OptionSet<TypeCheckExprFlags> options);
8282

83+
Type typeCheckParameterDefault(Expr *&, DeclContext *, Type, bool);
84+
8385
} // end namespace TypeChecker
8486

8587
} // end namespace swift
@@ -3045,6 +3047,10 @@ class ConstraintSystem {
30453047
swift::TypeChecker::typeCheckExpression(
30463048
SolutionApplicationTarget &target, OptionSet<TypeCheckExprFlags> options);
30473049

3050+
friend Type swift::TypeChecker::typeCheckParameterDefault(Expr *&,
3051+
DeclContext *, Type,
3052+
bool);
3053+
30483054
/// Emit the fixes computed as part of the solution, returning true if we were
30493055
/// able to emit an error message, or false if none of the fixits worked out.
30503056
bool applySolutionFixes(const Solution &solution);
@@ -4167,6 +4173,13 @@ class ConstraintSystem {
41674173
OpenedTypeMap &replacements,
41684174
ConstraintLocatorBuilder locator);
41694175

4176+
/// Open a generic parameter into a type variable and record
4177+
/// it in \c replacements.
4178+
TypeVariableType *openGenericParameter(DeclContext *outerDC,
4179+
GenericTypeParamType *parameter,
4180+
OpenedTypeMap &replacements,
4181+
ConstraintLocatorBuilder locator);
4182+
41704183
/// Given generic signature open its generic requirements,
41714184
/// using substitution function, and record them in the
41724185
/// constraint system for further processing.
@@ -4176,6 +4189,14 @@ class ConstraintSystem {
41764189
ConstraintLocatorBuilder locator,
41774190
llvm::function_ref<Type(Type)> subst);
41784191

4192+
// Record the given requirement in the constraint system.
4193+
void openGenericRequirement(DeclContext *outerDC,
4194+
unsigned index,
4195+
const Requirement &requirement,
4196+
bool skipProtocolSelfConstraint,
4197+
ConstraintLocatorBuilder locator,
4198+
llvm::function_ref<Type(Type)> subst);
4199+
41794200
/// Record the set of opened types for the given locator.
41804201
void recordOpenedTypes(
41814202
ConstraintLocatorBuilder locator,

lib/AST/Decl.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6992,6 +6992,18 @@ Expr *ParamDecl::getTypeCheckedDefaultExpr() const {
69926992
return new (ctx) ErrorExpr(getSourceRange(), ErrorType::get(ctx));
69936993
}
69946994

6995+
Type ParamDecl::getTypeOfDefaultExpr() const {
6996+
auto &ctx = getASTContext();
6997+
6998+
if (Type type = evaluateOrDefault(
6999+
ctx.evaluator,
7000+
DefaultArgumentTypeRequest{const_cast<ParamDecl *>(this)}, nullptr)) {
7001+
return type;
7002+
}
7003+
7004+
return Type();
7005+
}
7006+
69957007
void ParamDecl::setDefaultExpr(Expr *E, bool isTypeChecked) {
69967008
if (!DefaultValueAndFlags.getPointer()) {
69977009
if (!E) return;
@@ -7009,9 +7021,20 @@ void ParamDecl::setDefaultExpr(Expr *E, bool isTypeChecked) {
70097021
"Can't overwrite type-checked default with un-type-checked default");
70107022
}
70117023
defaultInfo->DefaultArg = E;
7024+
defaultInfo->ExprType = E->getType();
70127025
defaultInfo->InitContextAndIsTypeChecked.setInt(isTypeChecked);
70137026
}
70147027

7028+
void ParamDecl::setDefaultExprType(Type type) {
7029+
if (!DefaultValueAndFlags.getPointer()) {
7030+
DefaultValueAndFlags.setPointer(
7031+
getASTContext().Allocate<StoredDefaultArgument>());
7032+
}
7033+
7034+
auto *defaultInfo = DefaultValueAndFlags.getPointer();
7035+
defaultInfo->ExprType = type;
7036+
}
7037+
70157038
void ParamDecl::setStoredProperty(VarDecl *var) {
70167039
if (!DefaultValueAndFlags.getPointer()) {
70177040
if (!var) return;

lib/AST/TypeCheckRequests.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1206,6 +1206,27 @@ void DefaultArgumentExprRequest::cacheResult(Expr *expr) const {
12061206
param->setDefaultExpr(expr, /*isTypeChecked*/ true);
12071207
}
12081208

1209+
//----------------------------------------------------------------------------//
1210+
// DefaultArgumentTypeRequest computation.
1211+
//----------------------------------------------------------------------------//
1212+
1213+
Optional<Type> DefaultArgumentTypeRequest::getCachedResult() const {
1214+
auto *param = std::get<0>(getStorage());
1215+
auto *defaultInfo = param->DefaultValueAndFlags.getPointer();
1216+
if (!defaultInfo)
1217+
return None;
1218+
1219+
if (!defaultInfo->InitContextAndIsTypeChecked.getInt())
1220+
return None;
1221+
1222+
return defaultInfo->ExprType;
1223+
}
1224+
1225+
void DefaultArgumentTypeRequest::cacheResult(Type type) const {
1226+
auto *param = std::get<0>(getStorage());
1227+
param->setDefaultExprType(type);
1228+
}
1229+
12091230
//----------------------------------------------------------------------------//
12101231
// CallerSideDefaultArgExprRequest computation.
12111232
//----------------------------------------------------------------------------//

lib/Frontend/CompilerInvocation.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1051,6 +1051,9 @@ static bool ParseTypeCheckerArgs(TypeCheckerOptions &Opts, ArgList &Args,
10511051
Opts.EnableMultiStatementClosureInference |=
10521052
Args.hasArg(OPT_experimental_multi_statement_closures);
10531053

1054+
Opts.EnableTypeInferenceFromDefaultArguments |=
1055+
Args.hasArg(OPT_experimental_type_inference_from_defaults);
1056+
10541057
Opts.PrintFullConvention |=
10551058
Args.hasArg(OPT_experimental_print_full_convention);
10561059

lib/SIL/IR/TypeLowering.cpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2456,8 +2456,16 @@ static CanAnyFunctionType getDefaultArgGeneratorInterfaceType(
24562456
TypeConverter &TC,
24572457
SILDeclRef c) {
24582458
auto *vd = c.getDecl();
2459-
auto resultTy = getParameterAt(vd,
2460-
c.defaultArgIndex)->getInterfaceType();
2459+
auto *pd = getParameterAt(vd, c.defaultArgIndex);
2460+
2461+
Type resultTy;
2462+
2463+
if (auto type = pd->getTypeOfDefaultExpr()) {
2464+
resultTy = type->mapTypeOutOfContext();
2465+
} else {
2466+
resultTy = pd->getInterfaceType();
2467+
}
2468+
24612469
assert(resultTy && "Didn't find default argument?");
24622470

24632471
// The result type might be written in terms of type parameters

lib/Sema/CSDiagnostics.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,11 @@ ValueDecl *RequirementFailure::getDeclRef() const {
257257
return cast<ValueDecl>(getDC()->getAsDecl());
258258
}
259259

260+
if (contextualPurpose == CTP_DefaultParameter ||
261+
contextualPurpose == CTP_AutoclosureDefaultParameter) {
262+
return cast<ValueDecl>(getDC()->getParent()->getAsDecl());
263+
}
264+
260265
return getAffectedDeclFromType(contextualTy);
261266
}
262267

@@ -8087,3 +8092,27 @@ bool SwiftToCPointerConversionInInvalidContext::diagnoseAsError() {
80878092
paramType, callee->getDescriptiveKind(), callee->getName());
80888093
return true;
80898094
}
8095+
8096+
bool DefaultExprTypeMismatch::diagnoseAsError() {
8097+
auto *locator = getLocator();
8098+
8099+
unsigned paramIdx =
8100+
locator->castLastElementTo<LocatorPathElt::ApplyArgToParam>()
8101+
.getParamIdx();
8102+
8103+
emitDiagnostic(diag::cannot_convert_default_value_type_to_argument_type,
8104+
getFromType(), getToType(), paramIdx);
8105+
8106+
auto overload = getCalleeOverloadChoiceIfAvailable(locator);
8107+
assert(overload);
8108+
8109+
auto *PD = getParameterList(overload->choice.getDecl())->get(paramIdx);
8110+
8111+
auto note = emitDiagnosticAt(PD->getLoc(), diag::default_value_declared_here);
8112+
8113+
if (auto *defaultExpr = PD->getTypeCheckedDefaultExpr()) {
8114+
note.highlight(defaultExpr->getSourceRange());
8115+
}
8116+
8117+
return true;
8118+
}

lib/Sema/CSDiagnostics.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2665,6 +2665,28 @@ class SwiftToCPointerConversionInInvalidContext final
26652665
bool diagnoseAsError() override;
26662666
};
26672667

2668+
/// Diagnose situations where the type of default expression doesn't
2669+
/// match expected type of the argument i.e. generic parameter type
2670+
/// was inferred from result:
2671+
///
2672+
/// \code
2673+
/// func test<T>(_: T = 42) -> T { ... }
2674+
///
2675+
/// let _: String = test() // conflict between `String` and `Int`.
2676+
/// \endcode
2677+
class DefaultExprTypeMismatch final : public ContextualFailure {
2678+
public:
2679+
DefaultExprTypeMismatch(const Solution &solution, Type argType,
2680+
Type paramType, ConstraintLocator *locator)
2681+
: ContextualFailure(solution, argType, paramType, locator) {}
2682+
2683+
SourceLoc getLoc() const override {
2684+
return constraints::getLoc(getLocator()->getAnchor());
2685+
}
2686+
2687+
bool diagnoseAsError() override;
2688+
};
2689+
26682690
} // end namespace constraints
26692691
} // end namespace swift
26702692

0 commit comments

Comments
 (0)