Skip to content

Sema: Fix ASTPrinter and opened element environments to handle nested pack expansion types [5.9] #67310

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions include/swift/AST/Type.h
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,14 @@ class Type {
llvm::function_ref<llvm::Optional<Type>(TypeBase *, TypePosition)> fn)
const;

/// Transform free pack element references, that is, those not captured by a
/// pack expansion.
///
/// This is the 'map' counterpart to TypeBase::getTypeParameterPacks().
Type transformTypeParameterPacks(
llvm::function_ref<llvm::Optional<Type>(SubstitutableType *)> fn)
const;

/// Look through the given type and its children and apply fn to them.
void visit(llvm::function_ref<void (Type)> fn) const {
findIf([&fn](Type t) -> bool {
Expand Down
6 changes: 3 additions & 3 deletions lib/AST/ASTContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5856,13 +5856,13 @@ ASTContext::getOpenedElementSignature(CanGenericSignature baseGenericSig,
}

auto eraseParameterPackRec = [&](Type type) -> Type {
return type.transformRec([&](Type t) -> llvm::Optional<Type> {
if (auto *paramType = t->getAs<GenericTypeParamType>()) {
return type.transformTypeParameterPacks([&](SubstitutableType *t) -> llvm::Optional<Type> {
if (auto *paramType = dyn_cast<GenericTypeParamType>(t)) {
if (packElementParams.find(paramType) != packElementParams.end()) {
return Type(packElementParams[paramType]);
}

return t;
return Type(t);
}
return llvm::None;
});
Expand Down
23 changes: 17 additions & 6 deletions lib/AST/ASTPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -804,7 +804,7 @@ class PrintAST : public ASTVisitor<PrintAST> {

void printType(Type T) { printTypeWithOptions(T, Options); }

void printTransformedTypeWithOptions(Type T, PrintOptions options) {
Type getTransformedType(Type T) {
if (CurrentType && Current && CurrentType->mayHaveMembers()) {
auto *M = Current->getDeclContext()->getParentModule();
SubstitutionMap subMap;
Expand All @@ -825,9 +825,16 @@ class PrintAST : public ASTVisitor<PrintAST> {
}

T = T.subst(subMap, SubstFlags::DesugarMemberTypes);
}

return T;
}

void printTransformedTypeWithOptions(Type T, PrintOptions options) {
T = getTransformedType(T);

if (CurrentType && Current && CurrentType->mayHaveMembers())
options.TransformContext = TypeTransformContext(CurrentType);
}

printTypeWithOptions(T, options);
}
Expand Down Expand Up @@ -1833,6 +1840,11 @@ void PrintAST::printSingleDepthOfGenericSignature(
}

void PrintAST::printRequirement(const Requirement &req) {
SmallVector<Type, 2> rootParameterPacks;
getTransformedType(req.getFirstType())
->getTypeParameterPacks(rootParameterPacks);
bool isPackRequirement = !rootParameterPacks.empty();

switch (req.getKind()) {
case RequirementKind::SameShape:
Printer << "(repeat (";
Expand All @@ -1842,22 +1854,21 @@ void PrintAST::printRequirement(const Requirement &req) {
Printer << ")) : Any";
return;
case RequirementKind::Layout:
if (req.getFirstType()->hasParameterPack())
if (isPackRequirement)
Printer << "repeat ";
printTransformedType(req.getFirstType());
Printer << " : ";
req.getLayoutConstraint()->print(Printer, Options);
return;
case RequirementKind::Conformance:
case RequirementKind::Superclass:
if (req.getFirstType()->hasParameterPack())
if (isPackRequirement)
Printer << "repeat ";
printTransformedType(req.getFirstType());
Printer << " : ";
break;
case RequirementKind::SameType:
if (req.getFirstType()->hasParameterPack() ||
req.getSecondType()->hasParameterPack())
if (isPackRequirement)
Printer << "repeat ";
printTransformedType(req.getFirstType());
Printer << " == ";
Expand Down
58 changes: 58 additions & 0 deletions lib/AST/ParameterPack.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,64 @@

using namespace swift;

/// FV(PackExpansionType(Pattern, Count), N) = FV(Pattern, N+1)
/// FV(PackElementType(Param, M), N) = FV(Param, 0) if M >= N, {} otherwise
/// FV(Param, N) = {Param}
static Type transformTypeParameterPacksRec(
Type t, llvm::function_ref<llvm::Optional<Type>(SubstitutableType *)> fn,
unsigned expansionLevel) {
return t.transformWithPosition(
TypePosition::Invariant,
[&](TypeBase *t, TypePosition p) -> llvm::Optional<Type> {

// If we're already inside N levels of PackExpansionType, and we're
// walking into another PackExpansionType, a type parameter pack
// reference now needs level (N+1) to be free.
if (auto *expansionType = dyn_cast<PackExpansionType>(t)) {
auto countType = expansionType->getCountType();
auto patternType = expansionType->getPatternType();
auto newPatternType = transformTypeParameterPacksRec(
patternType, fn, expansionLevel + 1);
if (patternType.getPointer() != newPatternType.getPointer())
return Type(PackExpansionType::get(patternType, countType));

return Type(expansionType);
}

// A PackElementType with level N reaches past N levels of
// nested PackExpansionType. So a type parameter pack reference
// therein is free if N is greater than or equal to our current
// expansion level.
if (auto *eltType = dyn_cast<PackElementType>(t)) {
if (eltType->getLevel() >= expansionLevel) {
return transformTypeParameterPacksRec(eltType->getPackType(), fn,
/*expansionLevel=*/0);
}

return Type(eltType);
}

// A bare type parameter pack is like a PackElementType with level 0.
if (auto *paramType = dyn_cast<SubstitutableType>(t)) {
if (expansionLevel == 0 &&
(isa<PackArchetypeType>(paramType) ||
(isa<GenericTypeParamType>(paramType) &&
cast<GenericTypeParamType>(paramType)->isParameterPack()))) {
return fn(paramType);
}

return Type(paramType);
}

return llvm::None;
});
}

Type Type::transformTypeParameterPacks(
llvm::function_ref<llvm::Optional<Type>(SubstitutableType *)> fn) const {
return transformTypeParameterPacksRec(*this, fn, /*expansionLevel=*/0);
}

namespace {

/// Collects all unique pack type parameters referenced from the pattern type,
Expand Down
7 changes: 5 additions & 2 deletions lib/Sema/ConstraintSystem.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1009,8 +1009,11 @@ Type ConstraintSystem::openType(Type type, OpenedTypeMap &replacements,
// that gets introduced by the interface type, see
// \c openUnboundGenericType for more details.
if (auto *packTy = type->getAs<PackType>()) {
if (auto expansion = packTy->unwrapSingletonPackExpansion())
type = expansion->getPatternType();
if (auto expansionTy = packTy->unwrapSingletonPackExpansion()) {
auto patternTy = expansionTy->getPatternType();
if (patternTy->isTypeParameter())
return openType(patternTy, replacements, locator);
}
}

if (auto *expansion = type->getAs<PackExpansionType>()) {
Expand Down
9 changes: 9 additions & 0 deletions test/ModuleInterface/pack_expansion_type.swift
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
// RUN: %target-swift-emit-module-interface(%t/PackExpansionType.swiftinterface) %s -module-name PackExpansionType -disable-availability-checking
// RUN: %FileCheck %s < %t/PackExpansionType.swiftinterface

/// Requirements

// CHECK: #if compiler(>=5.3) && $ParameterPacks
// CHECK-NEXT: public func variadicFunction<each T, each U>(t: repeat each T, u: repeat each U) -> (repeat (each T, each U)) where (repeat (each T, each U)) : Any
public func variadicFunction<each T, each U>(t: repeat each T, u: repeat each U) -> (repeat (each T, each U)) {}
Expand All @@ -24,6 +26,13 @@ public struct VariadicType<each T> {
// CHECK: }
// CHECK-NEXT: #endif

// The second requirement should not be prefixed with 'repeat'
// CHECK: public struct SameTypeReq<T, each U> where T : Swift.Sequence, T.Element == PackExpansionType.VariadicType<repeat each U> {
public struct SameTypeReq<T: Sequence, each U> where T.Element == VariadicType<repeat each U> {}
// CHECK: }

/// Pack expansion types

// CHECK: public func returnsVariadicType() -> PackExpansionType.VariadicType<>
public func returnsVariadicType() -> VariadicType< > {}

Expand Down
17 changes: 17 additions & 0 deletions validation-test/compiler_crashers_2_fixed/rdar112108253.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// RUN: %target-swift-frontend -emit-ir %s -disable-availability-checking

public protocol P {
associatedtype A
}

public struct G<each T> {}

public protocol Q {
init()
}

public struct S<T: P, each U: P>: Q where repeat T.A == G<repeat each U.A> {
public init() {}
public init(predicate: repeat each U) {}
}