Skip to content

[AST/Sema] Introduce a new type that has associated location in source #77797

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Dec 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/swift/AST/TypeNodes.def
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,7 @@ UNCHECKED_TYPE(ErrorUnion, Type)
ALWAYS_CANONICAL_TYPE(Integer, Type)
ABSTRACT_SUGARED_TYPE(Sugar, Type)
SUGARED_TYPE(TypeAlias, SugarType)
SUGARED_TYPE(Locatable, SugarType)
ABSTRACT_SUGARED_TYPE(SyntaxSugar, SugarType)
ABSTRACT_SUGARED_TYPE(UnarySyntaxSugar, SyntaxSugarType)
SUGARED_TYPE(ArraySlice, UnarySyntaxSugarType)
Expand Down
13 changes: 13 additions & 0 deletions include/swift/AST/TypeTransform.h
Original file line number Diff line number Diff line change
Expand Up @@ -533,6 +533,19 @@ case TypeKind::Id:
newUnderlyingTy);
}

case TypeKind::Locatable: {
auto locatable = cast<LocatableType>(base);
Type oldUnderlyingTy = Type(locatable->getSinglyDesugaredType());
Type newUnderlyingTy = doIt(oldUnderlyingTy, pos);
if (!newUnderlyingTy)
return Type();

if (oldUnderlyingTy.getPointer() == newUnderlyingTy.getPointer())
return t;

return LocatableType::get(locatable->getLoc(), newUnderlyingTy);
}

case TypeKind::ErrorUnion: {
auto errorUnion = cast<ErrorUnionType>(base);
bool anyChanged = false;
Expand Down
23 changes: 23 additions & 0 deletions include/swift/AST/Types.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ enum class ReferenceCounting : uint8_t;
enum class ResilienceExpansion : unsigned;
class SILModule;
class SILType;
class SourceLoc;
class TypeAliasDecl;
class TypeDecl;
class NominalTypeDecl;
Expand Down Expand Up @@ -2322,6 +2323,28 @@ class TypeAliasType final
}
};

/// A type has been introduced at some fixed location in the AST.
class LocatableType final : public SugarType, public llvm::FoldingSetNode {
SourceLoc Loc;

LocatableType(SourceLoc loc, Type underlying,
RecursiveTypeProperties properties);

public:
SourceLoc getLoc() const { return Loc; }

static LocatableType *get(SourceLoc loc, Type underlying);

void Profile(llvm::FoldingSetNodeID &id) const;

static void Profile(llvm::FoldingSetNodeID &id, SourceLoc loc,
Type underlying);

static bool classof(const TypeBase *T) {
return T->getKind() == TypeKind::Locatable;
}
};

/// The various spellings of ownership modifier that can be used in source.
enum class ParamSpecifier : uint8_t {
/// No explicit ownership specifier was provided. The parameter will use the
Expand Down
2 changes: 2 additions & 0 deletions include/swift/Sema/ConstraintSystem.h
Original file line number Diff line number Diff line change
Expand Up @@ -3516,6 +3516,8 @@ class ConstraintSystem {
return !solverState || solverState->recordFixes;
}

bool inSalvageMode() const { return solverState && solverState->recordFixes; }

ArrayRef<ConstraintFix *> getFixes() const { return Fixes.getArrayRef(); }

bool shouldSuppressDiagnostics() const {
Expand Down
41 changes: 41 additions & 0 deletions lib/AST/ASTContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -530,6 +530,7 @@ struct ASTContext::Implementation {

llvm::DenseMap<Type, ErrorType *> ErrorTypesWithOriginal;
llvm::FoldingSet<TypeAliasType> TypeAliasTypes;
llvm::FoldingSet<LocatableType> LocatableTypes;
llvm::FoldingSet<TupleType> TupleTypes;
llvm::FoldingSet<PackType> PackTypes;
llvm::FoldingSet<PackExpansionType> PackExpansionTypes;
Expand Down Expand Up @@ -3270,6 +3271,7 @@ void ASTContext::Implementation::Arena::dump(llvm::raw_ostream &os) const {

SIZE_AND_BYTES(ErrorTypesWithOriginal);
SIZE(TypeAliasTypes);
SIZE(LocatableTypes);
SIZE(TupleTypes);
SIZE(PackTypes);
SIZE(PackExpansionTypes);
Expand Down Expand Up @@ -3452,6 +3454,45 @@ void TypeAliasType::Profile(
id.AddPointer(underlying.getPointer());
}

LocatableType::LocatableType(SourceLoc loc, Type underlying,
RecursiveTypeProperties properties)
: SugarType(TypeKind::Locatable, underlying, properties), Loc(loc) {
ASSERT(loc.isValid());
}

LocatableType *LocatableType::get(SourceLoc loc, Type underlying) {
auto properties = underlying->getRecursiveProperties();

// Figure out which arena this type will go into.
auto &ctx = underlying->getASTContext();
auto arena = getArena(properties);

// Profile the type.
llvm::FoldingSetNodeID id;
LocatableType::Profile(id, loc, underlying);

// Did we already record this type?
void *insertPos;
auto &types = ctx.getImpl().getArena(arena).LocatableTypes;
if (auto result = types.FindNodeOrInsertPos(id, insertPos))
return result;

// Build a new type.
auto result = new (ctx, arena) LocatableType(loc, underlying, properties);
types.InsertNode(result, insertPos);
return result;
}

void LocatableType::Profile(llvm::FoldingSetNodeID &id) const {
Profile(id, Loc, Type(getSinglyDesugaredType()));
}

void LocatableType::Profile(llvm::FoldingSetNodeID &id, SourceLoc loc,
Type underlying) {
id.AddPointer(loc.getOpaquePointerValue());
id.AddPointer(underlying.getPointer());
}

// Simple accessors.
Type ErrorType::get(const ASTContext &C) { return C.TheErrorType; }

Expand Down
12 changes: 12 additions & 0 deletions lib/AST/ASTDumper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4098,6 +4098,18 @@ namespace {
printFoot();
}

void visitLocatableType(LocatableType *T, StringRef label) {
printCommon("locatable_type", label);
printFieldQuotedRaw(
[&](raw_ostream &OS) {
auto &C = T->getASTContext();
T->getLoc().print(OS, C.SourceMgr);
},
"loc");
printRec(T->getSinglyDesugaredType(), "underlying");
printFoot();
}

void visitPackType(PackType *T, StringRef label) {
printCommon("pack_type", label);

Expand Down
1 change: 1 addition & 0 deletions lib/AST/ASTMangler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1296,6 +1296,7 @@ void ASTMangler::appendType(Type type, GenericSignature sig,
case TypeKind::BuiltinUnsafeValueBuffer:
return appendOperator("BB");
case TypeKind::BuiltinUnboundGeneric:
case TypeKind::Locatable:
llvm_unreachable("not a real type");
case TypeKind::BuiltinFixedArray: {
auto bfa = cast<BuiltinFixedArrayType>(tybase);
Expand Down
4 changes: 4 additions & 0 deletions lib/AST/ASTPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6009,6 +6009,10 @@ class TypePrinter : public TypeVisitor<TypePrinter> {
}
}

void visitLocatableType(LocatableType *T) {
visit(T->getSinglyDesugaredType());
}

void visitPackType(PackType *T) {
if (Options.PrintExplicitPackTypes || Options.PrintTypesForDebugging)
Printer << "Pack{";
Expand Down
2 changes: 2 additions & 0 deletions lib/AST/Type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1921,6 +1921,8 @@ Type SugarType::getSinglyDesugaredTypeSlow() {
#include "swift/AST/TypeNodes.def"
case TypeKind::TypeAlias:
llvm_unreachable("bound type alias types always have an underlying type");
case TypeKind::Locatable:
llvm_unreachable("locatable types always have an underlying type");
case TypeKind::ArraySlice:
case TypeKind::VariadicSequence:
implDecl = Context->getArrayDecl();
Expand Down
1 change: 1 addition & 0 deletions lib/AST/TypeWalker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ class Traversal : public TypeVisitor<Traversal, bool>
return false;

}
bool visitLocatableType(LocatableType *ty) { return false; }
bool visitSILTokenType(SILTokenType *ty) { return false; }

bool visitPackType(PackType *ty) {
Expand Down
6 changes: 6 additions & 0 deletions lib/IRGen/IRGenDebugInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2213,6 +2213,12 @@ createSpecializedStructOrClassType(NominalOrBoundGenericNominalType *Type,
L.File, 0, Scope);
}

case TypeKind::Locatable: {
auto *Sugar = cast<LocatableType>(BaseTy);
auto *CanTy = Sugar->getSinglyDesugaredType();
return getOrCreateDesugaredType(CanTy, DbgTy);
}

// SyntaxSugarType derivations.
case TypeKind::Dictionary:
case TypeKind::ArraySlice:
Expand Down
7 changes: 7 additions & 0 deletions lib/Sema/CSDiagnostics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -611,6 +611,13 @@ void RequirementFailure::maybeEmitRequirementNote(const Decl *anchor, Type lhs,
req.getFirstType(), lhs, req.getSecondType(), rhs);
}

SourceLoc MissingConformanceFailure::getLoc() const {
if (auto *locatable = dyn_cast<LocatableType>(LHS.getPointer())) {
return locatable->getLoc();
}
return RequirementFailure::getLoc();
}

bool MissingConformanceFailure::diagnoseAsError() {
auto anchor = getAnchor();
auto nonConformingType = getLHS();
Expand Down
6 changes: 6 additions & 0 deletions lib/Sema/CSDiagnostics.h
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,10 @@ class RequirementFailure : public FailureDiagnostic {
Apply = dyn_cast<ApplyExpr>(parentExpr);
}

virtual SourceLoc getLoc() const override {
return FailureDiagnostic::getLoc();
}

unsigned getRequirementIndex() const {
auto reqElt =
getLocator()->castLastElementTo<LocatorPathElt::AnyRequirement>();
Expand Down Expand Up @@ -341,6 +345,8 @@ class MissingConformanceFailure final : public RequirementFailure {
#endif
}

virtual SourceLoc getLoc() const override;

bool diagnoseAsError() override;

protected:
Expand Down
21 changes: 21 additions & 0 deletions lib/Sema/CSGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2687,6 +2687,21 @@ namespace {
varType = TypeChecker::getOptionalType(var->getLoc(), varType);
}

auto makeTypeLocatableIfPossible = [&var](Type type) -> Type {
if (auto loc = var->getLoc()) {
return LocatableType::get(loc, type);
}
return type;
};

auto useLocatableTypes = [&]() -> bool {
if (!CS.inSalvageMode())
return false;

return var->isImplicit() &&
var->getNameStr().starts_with("$__builder");
};

// When we are supposed to bind pattern variables, create a fresh
// type variable and a one-way constraint to assign it to either the
// deduced type or the externally-imposed type.
Expand All @@ -2711,6 +2726,9 @@ namespace {

CS.addConstraint(ConstraintKind::OneWayEqual, oneWayVarType,
varType, locator);

if (useLocatableTypes())
oneWayVarType = makeTypeLocatableIfPossible(oneWayVarType);
}

// Ascribe a type to the declaration so it's always available to
Expand Down Expand Up @@ -2772,6 +2790,9 @@ namespace {
CS.getConstraintLocator(locator));
}

if (useLocatableTypes())
declTy = makeTypeLocatableIfPossible(declTy);

CS.setType(var, declTy);
}

Expand Down
39 changes: 24 additions & 15 deletions lib/Sema/CSSimplify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8427,11 +8427,10 @@ ConstraintSystem::SolutionKind ConstraintSystem::simplifyConformsToConstraint(

// We sometimes get a pack expansion type here.
if (auto *expansionType = type->getAs<PackExpansionType>()) {
// FIXME: Locator
addConstraint(ConstraintKind::ConformsTo,
expansionType->getPatternType(),
protocol->getDeclaredInterfaceType(),
locator);
addConstraint(
ConstraintKind::ConformsTo, expansionType->getPatternType(),
protocol->getDeclaredInterfaceType(),
locator.withPathElement(LocatorPathElt::PackExpansionPattern()));

return SolutionKind::Solved;
}
Expand Down Expand Up @@ -8625,17 +8624,27 @@ ConstraintSystem::SolutionKind ConstraintSystem::simplifyConformsToConstraint(
return recordFix(fix) ? SolutionKind::Error : SolutionKind::Solved;
}

// If we have something like ... -> type req # -> pack element #, we're
// solving a requirement of the form T : P where T is a type parameter pack
if (path.back().is<LocatorPathElt::PackElement>())
path.pop_back();
// Conditional conformance requirements could produce chains of
// `path element -> pack expansion pattern -> pack element`.
while (!path.empty()) {
// If we have something like ... -> type req # -> pack element #, we're
// solving a requirement of the form T : P where T is a type parameter pack
if (path.back().is<LocatorPathElt::PackElement>()) {
path.pop_back();
continue;
}

// This is similar to `PackElement` but locator points to the requirement
// associted with pack expansion pattern (i.e. `repeat each T: P`) where
// the path is something like:
// `... -> type req # -> pack expansion pattern`.
if (path.back().is<LocatorPathElt::PackExpansionPattern>())
path.pop_back();
// This is similar to `PackElement` but locator points to the requirement
// associated with pack expansion pattern (i.e. `repeat each T: P`) where
// the path is something like:
// `... -> type req # -> pack expansion pattern`.
if (path.back().is<LocatorPathElt::PackExpansionPattern>()) {
path.pop_back();
continue;
}

break;
}

if (auto req = path.back().getAs<LocatorPathElt::AnyRequirement>()) {
// If this is a requirement associated with `Self` which is bound
Expand Down
4 changes: 4 additions & 0 deletions lib/Serialization/Serialization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5443,6 +5443,10 @@ class Serializer::TypeSerializer : public TypeVisitor<TypeSerializer> {
llvm_unreachable("error union types do not persist in the AST");
}

void visitLocatableType(const LocatableType *) {
llvm_unreachable("locatable types do not persist in the AST");
}

void visitBuiltinTypeImpl(Type ty) {
using namespace decls_block;
TypeAliasDecl *typeAlias =
Expand Down
30 changes: 26 additions & 4 deletions test/Constraints/result_builder_diags.swift
Original file line number Diff line number Diff line change
Expand Up @@ -629,15 +629,37 @@ _ = wrapperifyInfer(true) { x in // Ok

struct DoesNotConform {}

struct List<C> {
typealias T = C

init(@TupleBuilder _: () -> C) {}
}

extension List: P where C: P {}

struct MyView {
@TupleBuilder var value: some P { // expected-error {{return type of property 'value' requires that 'DoesNotConform' conform to 'P'}}
struct Conforms : P {
typealias T = Void
}

@TupleBuilder var value: some P {
// expected-note@-1 {{opaque return type declared here}}
DoesNotConform() // expected-error {{return type of property 'value' requires that 'DoesNotConform' conform to 'P'}}
}

@TupleBuilder var nestedFailures: some P {
// expected-note@-1 {{opaque return type declared here}}
DoesNotConform()
List {
List {
DoesNotConform()
// expected-error@-1 {{return type of property 'nestedFailures' requires that 'DoesNotConform' conform to 'P'}}
}
}
}

@TupleBuilder func test() -> some P { // expected-error {{return type of instance method 'test()' requires that 'DoesNotConform' conform to 'P'}}
@TupleBuilder func test() -> some P {
// expected-note@-1 {{opaque return type declared here}}
DoesNotConform()
DoesNotConform() // expected-error {{return type of instance method 'test()' requires that 'DoesNotConform' conform to 'P'}}
}

@TupleBuilder var emptySwitch: some P {
Expand Down
Loading