Skip to content

Sema: Apply @escaping to typealiases with underlying function type #4347

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
151 changes: 105 additions & 46 deletions lib/Sema/TypeCheckType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1402,6 +1402,38 @@ static bool diagnoseAvailability(Type ty, IdentTypeRepr *IdType, SourceLoc Loc,
return false;
}

/// Whether the given DC is a noescape-by-default context, i.e. not a property
/// setter
static bool isDefaultNoEscapeContext(const DeclContext *DC) {
auto funcDecl = dyn_cast<FuncDecl>(DC);
return !funcDecl || !funcDecl->isSetter();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hm. What about index parameters of subscripts?

}

// Hack to apply context-specific @escaping to an AST function type.
static Type adjustFunctionExtInfo(DeclContext *DC,
Type ty,
TypeResolutionOptions options) {
// Remember whether this is a function parameter.
bool isFunctionParam =
options.contains(TR_FunctionInput) ||
options.contains(TR_ImmediateFunctionInput);

bool defaultNoEscape = isFunctionParam && isDefaultNoEscapeContext(DC);

// Desugar here
auto *funcTy = ty->castTo<FunctionType>();
auto extInfo = funcTy->getExtInfo();
if (defaultNoEscape && !extInfo.isNoEscape()) {
extInfo = extInfo.withNoEscape();

// We lost the sugar to flip the isNoEscape bit
return FunctionType::get(funcTy->getInput(), funcTy->getResult(), extInfo);
}

// Note: original sugared type
return ty;
}

/// \brief Returns a valid type or ErrorType in case of an error.
Type TypeChecker::resolveIdentifierType(
DeclContext *DC,
Expand Down Expand Up @@ -1432,6 +1464,11 @@ Type TypeChecker::resolveIdentifierType(
return ErrorType::get(Context);
}

// Hack to apply context-specific @escaping to a typealias with an underlying
// function type.
if (result->is<FunctionType>())
result = adjustFunctionExtInfo(DC, result, options);

// We allow a type to conform to a protocol that is less available than
// the type itself. This enables a type to retroactively model or directly
// conform to a protocol only available on newer OSes and yet still be used on
Expand Down Expand Up @@ -1637,29 +1674,20 @@ Type TypeChecker::resolveType(TypeRepr *TyR, DeclContext *DC,
return result;
}

/// Whether the given DC is a noescape-by-default context, i.e. not a property
/// setter
static bool isDefaultNoEscapeContext(const DeclContext *DC) {
auto funcDecl = dyn_cast<FuncDecl>(DC);
return !funcDecl || !funcDecl->isSetter();
}

Type TypeResolver::resolveType(TypeRepr *repr, TypeResolutionOptions options) {
assert(repr && "Cannot validate null TypeReprs!");

// If we know the type representation is invalid, just return an
// error type.
if (repr->isInvalid()) return ErrorType::get(TC.Context);

// Remember whether this is a function parameter.
bool isFunctionParam =
options.contains(TR_FunctionInput) ||
options.contains(TR_ImmediateFunctionInput);

// Strip the "is function input" bits unless this is a type that knows about
// them.
if (!isa<InOutTypeRepr>(repr) && !isa<TupleTypeRepr>(repr) &&
!isa<AttributedTypeRepr>(repr)) {
if (!isa<InOutTypeRepr>(repr) &&
!isa<TupleTypeRepr>(repr) &&
!isa<AttributedTypeRepr>(repr) &&
!isa<FunctionTypeRepr>(repr) &&
!isa<IdentTypeRepr>(repr)) {
options -= TR_ImmediateFunctionInput;
options -= TR_FunctionInput;
}
Expand All @@ -1686,11 +1714,10 @@ Type TypeResolver::resolveType(TypeRepr *repr, TypeResolutionOptions options) {
case TypeReprKind::Function:
if (!(options & TR_SILType)) {
// Default non-escaping for closure parameters
auto info = AnyFunctionType::ExtInfo().withNoEscape(
isFunctionParam &&
isDefaultNoEscapeContext(DC));
return resolveASTFunctionType(cast<FunctionTypeRepr>(repr), options,
info);
auto result = resolveASTFunctionType(cast<FunctionTypeRepr>(repr), options);
if (result->is<FunctionType>())
return adjustFunctionExtInfo(DC, result, options);
return result;
}
return resolveSILFunctionType(cast<FunctionTypeRepr>(repr), options);

Expand Down Expand Up @@ -1761,8 +1788,6 @@ Type TypeResolver::resolveAttributedType(TypeAttributes &attrs,
bool isFunctionParam =
options.contains(TR_FunctionInput) ||
options.contains(TR_ImmediateFunctionInput);
options -= TR_ImmediateFunctionInput;
options -= TR_FunctionInput;

// The type we're working with, in case we want to build it differently
// based on the attributes we see.
Expand All @@ -1786,7 +1811,11 @@ Type TypeResolver::resolveAttributedType(TypeAttributes &attrs,
if (base) {
Optional<MetatypeRepresentation> storedRepr;
// The instance type is not a SIL type.
auto instanceOptions = options - TR_SILType;
auto instanceOptions = options;
instanceOptions -= TR_SILType;
instanceOptions -= TR_ImmediateFunctionInput;
instanceOptions -= TR_FunctionInput;

auto instanceTy = resolveType(base, instanceOptions);
if (!instanceTy || instanceTy->is<ErrorType>())
return instanceTy;
Expand Down Expand Up @@ -1915,10 +1944,6 @@ Type TypeResolver::resolveAttributedType(TypeAttributes &attrs,

ty = resolveSILFunctionType(fnRepr, options, extInfo, calleeConvention);
if (!ty || ty->is<ErrorType>()) return ty;

for (auto i : FunctionAttrs)
attrs.clearAttribute(i);
attrs.convention = None;
} else if (hasFunctionAttr && fnRepr) {

FunctionType::Representation rep = FunctionType::Representation::Swift;
Expand Down Expand Up @@ -1962,29 +1987,54 @@ Type TypeResolver::resolveAttributedType(TypeAttributes &attrs,
.fixItReplace(resultRange, "Never");
}

bool defaultNoEscape = false;
if (isFunctionParam && !attrs.has(TAK_escaping)) {
defaultNoEscape = isDefaultNoEscapeContext(DC);
}

if (isFunctionParam && attrs.has(TAK_noescape) &&
isDefaultNoEscapeContext(DC)) {
if (attrs.has(TAK_noescape)) {
// FIXME: diagnostic to tell user this is redundant and drop it
}

// Resolve the function type directly with these attributes.
FunctionType::ExtInfo extInfo(rep,
attrs.has(TAK_autoclosure),
defaultNoEscape | attrs.has(TAK_noescape),
attrs.has(TAK_noescape),
fnRepr->throws());

ty = resolveASTFunctionType(fnRepr, options, extInfo);
if (!ty || ty->is<ErrorType>()) return ty;
}

for (auto i : FunctionAttrs)
attrs.clearAttribute(i);
attrs.convention = None;
} else if (hasFunctionAttr) {
auto instanceOptions = options;
instanceOptions -= TR_ImmediateFunctionInput;
instanceOptions -= TR_FunctionInput;

// If we didn't build the type differently above, we might have
// a typealias pointing at a function type with the @escaping
// attribute. Resolve the type as if it were in non-parameter
// context, and then set isNoEscape if @escaping is not present.
if (!ty) ty = resolveType(repr, instanceOptions);
if (!ty || ty->is<ErrorType>()) return ty;

// Handle @escaping
if (hasFunctionAttr && ty->is<FunctionType>()) {
if (attrs.has(TAK_escaping)) {
// The attribute is meaningless except on parameter types.
if (!isFunctionParam) {
auto &SM = TC.Context.SourceMgr;
auto loc = attrs.getLoc(TAK_escaping);
auto attrRange = SourceRange(
loc.getAdvancedLoc(-1),
Lexer::getLocForEndOfToken(SM, loc));

TC.diagnose(loc, diag::escaping_function_type)
.fixItRemove(attrRange);
}

attrs.clearAttribute(TAK_escaping);
} else {
// No attribute; set the isNoEscape bit if we're in parameter context.
ty = adjustFunctionExtInfo(DC, ty, options);
}
}

if (hasFunctionAttr && !fnRepr) {
// @autoclosure usually auto-implies @noescape, don't complain about both
// of them.
if (attrs.has(TAK_autoclosure))
Expand All @@ -1997,11 +2047,11 @@ Type TypeResolver::resolveAttributedType(TypeAttributes &attrs,
attrs.clearAttribute(i);
}
}
}

// If we didn't build the type differently above, build it normally now.
if (!ty) ty = resolveType(repr, options);
if (!ty || ty->is<ErrorType>()) return ty;
} else if (hasFunctionAttr && fnRepr) {
for (auto i : FunctionAttrs)
attrs.clearAttribute(i);
attrs.convention = None;
}

// In SIL, handle @opened (n), which creates an existential archetype.
if (attrs.has(TAK_opened)) {
Expand Down Expand Up @@ -2056,6 +2106,9 @@ Type TypeResolver::resolveAttributedType(TypeAttributes &attrs,
Type TypeResolver::resolveASTFunctionType(FunctionTypeRepr *repr,
TypeResolutionOptions options,
FunctionType::ExtInfo extInfo) {
options -= TR_ImmediateFunctionInput;
options -= TR_FunctionInput;

Type inputTy = resolveType(repr->getArgsTypeRepr(),
options | TR_ImmediateFunctionInput);
if (!inputTy || inputTy->is<ErrorType>()) return inputTy;
Expand Down Expand Up @@ -2120,6 +2173,9 @@ Type TypeResolver::resolveSILFunctionType(FunctionTypeRepr *repr,
TypeResolutionOptions options,
SILFunctionType::ExtInfo extInfo,
ParameterConvention callee) {
options -= TR_ImmediateFunctionInput;
options -= TR_FunctionInput;

bool hasError = false;

SmallVector<SILParameterInfo, 4> params;
Expand Down Expand Up @@ -2448,9 +2504,12 @@ Type TypeResolver::resolveTupleType(TupleTypeRepr *repr,

// If this is the top level of a function input list, peel off the
// ImmediateFunctionInput marker and install a FunctionInput one instead.
auto elementOptions = withoutContext(options);
if (options & TR_ImmediateFunctionInput)
elementOptions |= TR_FunctionInput;
auto elementOptions = options;
if (!repr->isParenType()) {
elementOptions = withoutContext(elementOptions);
if (options & TR_ImmediateFunctionInput)
elementOptions |= TR_FunctionInput;
}

for (auto tyR : repr->getElements()) {
NamedTypeRepr *namedTyR = dyn_cast<NamedTypeRepr>(tyR);
Expand Down
10 changes: 5 additions & 5 deletions lib/Serialization/Serialization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3097,9 +3097,9 @@ void Serializer::writeType(Type ty) {
}

case TypeKind::Optional: {
auto sliceTy = cast<OptionalType>(ty.getPointer());
auto optionalTy = cast<OptionalType>(ty.getPointer());

Type base = sliceTy->getBaseType();
Type base = optionalTy->getBaseType();

unsigned abbrCode = DeclTypeAbbrCodes[OptionalTypeLayout::Code];
OptionalTypeLayout::emitRecord(Out, ScratchRecord, abbrCode,
Expand All @@ -3108,13 +3108,13 @@ void Serializer::writeType(Type ty) {
}

case TypeKind::ImplicitlyUnwrappedOptional: {
auto sliceTy = cast<ImplicitlyUnwrappedOptionalType>(ty.getPointer());
auto optionalTy = cast<ImplicitlyUnwrappedOptionalType>(ty.getPointer());

Type base = sliceTy->getBaseType();
Type base = optionalTy->getBaseType();

unsigned abbrCode = DeclTypeAbbrCodes[ImplicitlyUnwrappedOptionalTypeLayout::Code];
ImplicitlyUnwrappedOptionalTypeLayout::emitRecord(Out, ScratchRecord, abbrCode,
addTypeRef(base));
addTypeRef(base));
break;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -514,8 +514,8 @@ internal func _product<C1 : Collection, C2 : Collection>(
wrapValueIntoEquatable: @escaping (
MinimalEquatableValue) -> CollectionWithEquatableElement.Iterator.Element,

extractValueFromEquatable:
((CollectionWithEquatableElement.Iterator.Element) -> MinimalEquatableValue),
extractValueFromEquatable: @escaping (
CollectionWithEquatableElement.Iterator.Element) -> MinimalEquatableValue,

resiliencyChecks: CollectionMisuseResiliencyChecks = .all,
outOfBoundsIndexOffset: Int = 1,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,11 @@ extension TestSuite {

makeCollectionOfEquatable: @escaping ([CollectionWithEquatableElement.Iterator.Element]) -> CollectionWithEquatableElement,
wrapValueIntoEquatable: @escaping (MinimalEquatableValue) -> CollectionWithEquatableElement.Iterator.Element,
extractValueFromEquatable: ((CollectionWithEquatableElement.Iterator.Element) -> MinimalEquatableValue),
extractValueFromEquatable: @escaping ((CollectionWithEquatableElement.Iterator.Element) -> MinimalEquatableValue),

makeCollectionOfComparable: @escaping ([CollectionWithComparableElement.Iterator.Element]) -> CollectionWithComparableElement,
wrapValueIntoComparable: @escaping (MinimalComparableValue) -> CollectionWithComparableElement.Iterator.Element,
extractValueFromComparable: ((CollectionWithComparableElement.Iterator.Element) -> MinimalComparableValue),
extractValueFromComparable: @escaping ((CollectionWithComparableElement.Iterator.Element) -> MinimalComparableValue),

resiliencyChecks: CollectionMisuseResiliencyChecks = .all,
outOfBoundsIndexOffset: Int = 1,
Expand Down Expand Up @@ -505,8 +505,8 @@ self.test("\(testNamePrefix).sorted/DispatchesThrough_withUnsafeMutableBufferPoi

func checkSort_${'Predicate' if predicate else 'WhereElementIsComparable'}(
sequence: [Int],
equalImpl: ((Int, Int) -> Bool),
lessImpl: ((Int, Int) -> Bool),
equalImpl: @escaping ((Int, Int) -> Bool),
lessImpl: @escaping ((Int, Int) -> Bool),
verifyOrder: Bool
) {
% if predicate:
Expand Down Expand Up @@ -686,11 +686,11 @@ self.test("\(testNamePrefix).partition/InvalidOrderings") {

makeCollectionOfEquatable: @escaping ([CollectionWithEquatableElement.Iterator.Element]) -> CollectionWithEquatableElement,
wrapValueIntoEquatable: @escaping (MinimalEquatableValue) -> CollectionWithEquatableElement.Iterator.Element,
extractValueFromEquatable: ((CollectionWithEquatableElement.Iterator.Element) -> MinimalEquatableValue),
extractValueFromEquatable: @escaping ((CollectionWithEquatableElement.Iterator.Element) -> MinimalEquatableValue),

makeCollectionOfComparable: @escaping ([CollectionWithComparableElement.Iterator.Element]) -> CollectionWithComparableElement,
wrapValueIntoComparable: @escaping (MinimalComparableValue) -> CollectionWithComparableElement.Iterator.Element,
extractValueFromComparable: ((CollectionWithComparableElement.Iterator.Element) -> MinimalComparableValue),
extractValueFromComparable: @escaping ((CollectionWithComparableElement.Iterator.Element) -> MinimalComparableValue),

resiliencyChecks: CollectionMisuseResiliencyChecks = .all,
outOfBoundsIndexOffset: Int = 1,
Expand Down Expand Up @@ -842,11 +842,11 @@ self.test("\(testNamePrefix).partition/DispatchesThrough_withUnsafeMutableBuffer

makeCollectionOfEquatable: @escaping ([CollectionWithEquatableElement.Iterator.Element]) -> CollectionWithEquatableElement,
wrapValueIntoEquatable: @escaping (MinimalEquatableValue) -> CollectionWithEquatableElement.Iterator.Element,
extractValueFromEquatable: ((CollectionWithEquatableElement.Iterator.Element) -> MinimalEquatableValue),
extractValueFromEquatable: @escaping ((CollectionWithEquatableElement.Iterator.Element) -> MinimalEquatableValue),

makeCollectionOfComparable: @escaping ([CollectionWithComparableElement.Iterator.Element]) -> CollectionWithComparableElement,
wrapValueIntoComparable: @escaping (MinimalComparableValue) -> CollectionWithComparableElement.Iterator.Element,
extractValueFromComparable: ((CollectionWithComparableElement.Iterator.Element) -> MinimalComparableValue),
extractValueFromComparable: @escaping ((CollectionWithComparableElement.Iterator.Element) -> MinimalComparableValue),

resiliencyChecks: CollectionMisuseResiliencyChecks = .all,
outOfBoundsIndexOffset: Int = 1,
Expand Down Expand Up @@ -928,8 +928,8 @@ self.test("\(testNamePrefix).partition/DispatchesThrough_withUnsafeMutableBuffer

func checkSortInPlace_${'Predicate' if predicate else 'WhereElementIsComparable'}(
sequence: [Int],
equalImpl: ((Int, Int) -> Bool),
lessImpl: ((Int, Int) -> Bool),
equalImpl: @escaping ((Int, Int) -> Bool),
lessImpl: @escaping ((Int, Int) -> Bool),
verifyOrder: Bool
) {
% if predicate:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -456,7 +456,7 @@ extension TestSuite {

makeCollectionOfEquatable: @escaping ([CollectionWithEquatableElement.Iterator.Element]) -> CollectionWithEquatableElement,
wrapValueIntoEquatable: @escaping (MinimalEquatableValue) -> CollectionWithEquatableElement.Iterator.Element,
extractValueFromEquatable: ((CollectionWithEquatableElement.Iterator.Element) -> MinimalEquatableValue),
extractValueFromEquatable: @escaping ((CollectionWithEquatableElement.Iterator.Element) -> MinimalEquatableValue),

resiliencyChecks: CollectionMisuseResiliencyChecks = .all,
outOfBoundsIndexOffset: Int = 1,
Expand Down Expand Up @@ -1171,7 +1171,7 @@ self.test("\(testNamePrefix).OperatorPlus") {

makeCollectionOfEquatable: @escaping ([CollectionWithEquatableElement.Iterator.Element]) -> CollectionWithEquatableElement,
wrapValueIntoEquatable: @escaping (MinimalEquatableValue) -> CollectionWithEquatableElement.Iterator.Element,
extractValueFromEquatable: ((CollectionWithEquatableElement.Iterator.Element) -> MinimalEquatableValue),
extractValueFromEquatable: @escaping ((CollectionWithEquatableElement.Iterator.Element) -> MinimalEquatableValue),

resiliencyChecks: CollectionMisuseResiliencyChecks = .all,
outOfBoundsIndexOffset: Int = 1
Expand Down Expand Up @@ -1303,7 +1303,7 @@ self.test("\(testNamePrefix).removeLast(n: Int)/whereIndexIsBidirectional/remove

makeCollectionOfEquatable: @escaping ([CollectionWithEquatableElement.Iterator.Element]) -> CollectionWithEquatableElement,
wrapValueIntoEquatable: @escaping (MinimalEquatableValue) -> CollectionWithEquatableElement.Iterator.Element,
extractValueFromEquatable: ((CollectionWithEquatableElement.Iterator.Element) -> MinimalEquatableValue),
extractValueFromEquatable: @escaping ((CollectionWithEquatableElement.Iterator.Element) -> MinimalEquatableValue),

resiliencyChecks: CollectionMisuseResiliencyChecks = .all,
outOfBoundsIndexOffset: Int = 1
Expand Down
Loading