Skip to content

[Serialization] Drop extensions whose base type can't be deserialized. #11323

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions include/swift/Serialization/ModuleFormat.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ const uint16_t VERSION_MAJOR = 0;
/// in source control, you should also update the comment to briefly
/// describe what change you made. The content of this comment isn't important;
/// it just ensures a conflict if two people change the module format.
const uint16_t VERSION_MINOR = 354; // Last change: special destructor names
const uint16_t VERSION_MINOR = 355; // Last change: extension dependencies

using DeclID = PointerEmbeddedInt<unsigned, 31>;
using DeclIDField = BCFixed<31>;
Expand Down Expand Up @@ -1038,7 +1038,8 @@ namespace decls_block {
BCFixed<1>, // implicit flag
GenericEnvironmentIDField, // generic environment
BCVBR<4>, // # of protocol conformances
BCArray<TypeIDField> // inherited types
BCVBR<4>, // number of inherited types
BCArray<TypeIDField> // inherited types, followed by TypeID dependencies
// Trailed by the generic parameter lists, members record, and then
// conformance info (if any).
>;
Expand Down
22 changes: 17 additions & 5 deletions lib/Serialization/Deserialization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,8 @@ const char OverrideError::ID = '\0';
void OverrideError::anchor() {}
const char TypeError::ID = '\0';
void TypeError::anchor() {}
const char ExtensionError::ID = '\0';
void ExtensionError::anchor() {}

LLVM_NODISCARD
static std::unique_ptr<llvm::ErrorInfoBase> takeErrorInfo(llvm::Error error) {
Expand Down Expand Up @@ -3472,14 +3474,24 @@ ModuleFile::getDeclChecked(DeclID DID, Optional<DeclContext *> ForcedContext) {
DeclContextID contextID;
bool isImplicit;
GenericEnvironmentID genericEnvID;
unsigned numConformances;
ArrayRef<uint64_t> rawInheritedIDs;
unsigned numConformances, numInherited;
ArrayRef<uint64_t> inheritedAndDependencyIDs;

decls_block::ExtensionLayout::readRecord(scratch, baseID, contextID,
isImplicit, genericEnvID,
numConformances, rawInheritedIDs);
numConformances, numInherited,
inheritedAndDependencyIDs);

auto DC = getDeclContext(contextID);

for (TypeID dependencyID : inheritedAndDependencyIDs.slice(numInherited)) {
auto dependency = getTypeChecked(dependencyID);
if (!dependency) {
return llvm::make_error<ExtensionError>(
takeErrorInfo(dependency.takeError()));
}
}

if (declOrOffset.isComplete())
return declOrOffset;

Expand All @@ -3505,8 +3517,8 @@ ModuleFile::getDeclChecked(DeclID DID, Optional<DeclContext *> ForcedContext) {
if (isImplicit)
extension->setImplicit();

auto inheritedTypes = ctx.Allocate<TypeLoc>(rawInheritedIDs.size());
for_each(inheritedTypes, rawInheritedIDs,
auto inheritedTypes = ctx.Allocate<TypeLoc>(numInherited);
for_each(inheritedTypes, inheritedAndDependencyIDs.slice(0, numInherited),
[this](TypeLoc &tl, uint64_t rawID) {
tl = TypeLoc::withoutLoc(getType(rawID));
});
Expand Down
24 changes: 24 additions & 0 deletions lib/Serialization/DeserializationErrors.h
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,30 @@ class TypeError : public llvm::ErrorInfo<TypeError, DeclDeserializationError> {
}
};

class ExtensionError : public llvm::ErrorInfo<ExtensionError> {
friend ErrorInfo;
static const char ID;
void anchor() override;

std::unique_ptr<ErrorInfoBase> underlyingReason;

public:
explicit ExtensionError(std::unique_ptr<ErrorInfoBase> reason)
: underlyingReason(std::move(reason)) {}

void log(raw_ostream &OS) const override {
OS << "could not deserialize extension";
if (underlyingReason) {
OS << ": ";
underlyingReason->log(OS);
}
}

std::error_code convertToErrorCode() const override {
return llvm::inconvertibleErrorCode();
}
};

class PrettyStackTraceModuleFile : public llvm::PrettyStackTraceEntry {
const char *Action;
const ModuleFile &MF;
Expand Down
32 changes: 26 additions & 6 deletions lib/Serialization/ModuleFile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1564,15 +1564,27 @@ void ModuleFile::loadExtensions(NominalTypeDecl *nominal) {
if (nominal->getParent()->isModuleScopeContext()) {
Identifier moduleName = nominal->getParentModule()->getName();
for (auto item : *iter) {
if (item.first == moduleName.str())
(void)getDecl(item.second);
if (item.first != moduleName.str())
continue;
Expected<Decl *> declOrError = getDeclChecked(item.second);
if (!declOrError) {
if (!getContext().LangOpts.EnableDeserializationRecovery)
fatal(declOrError.takeError());
llvm::consumeError(declOrError.takeError());
}
}
} else {
std::string mangledName =
Mangle::ASTMangler().mangleNominalType(nominal);
for (auto item : *iter) {
if (item.first == mangledName)
(void)getDecl(item.second);
if (item.first != mangledName)
continue;
Expected<Decl *> declOrError = getDeclChecked(item.second);
if (!declOrError) {
if (!getContext().LangOpts.EnableDeserializationRecovery)
fatal(declOrError.takeError());
llvm::consumeError(declOrError.takeError());
}
}
}
}
Expand Down Expand Up @@ -1751,8 +1763,16 @@ void ModuleFile::getTopLevelDecls(SmallVectorImpl<Decl *> &results) {

if (ExtensionDecls) {
for (auto entry : ExtensionDecls->data()) {
for (auto item : entry)
results.push_back(getDecl(item.second));
for (auto item : entry) {
Expected<Decl *> declOrError = getDeclChecked(item.second);
if (!declOrError) {
if (!getContext().LangOpts.EnableDeserializationRecovery)
fatal(declOrError.takeError());
llvm::consumeError(declOrError.takeError());
continue;
}
results.push_back(declOrError.get());
}
}
}
}
Expand Down
15 changes: 12 additions & 3 deletions lib/Serialization/Serialization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2482,9 +2482,17 @@ void Serializer::writeDecl(const Decl *D) {
ConformanceLookupKind::All,
nullptr, /*sorted=*/true);

SmallVector<TypeID, 8> inheritedTypes;
SmallVector<TypeID, 8> inheritedAndDependencyTypes;
for (auto inherited : extension->getInherited())
inheritedTypes.push_back(addTypeRef(inherited.getType()));
inheritedAndDependencyTypes.push_back(addTypeRef(inherited.getType()));
size_t numInherited = inheritedAndDependencyTypes.size();

// FIXME: Figure out what to do with requirements and such, which the
// extension also depends on. Right now just do what is safe to drop, which
// is the base declaration.
auto dependencies = collectDependenciesFromType(baseTy);
for (auto dependencyTy : dependencies)
inheritedAndDependencyTypes.push_back(addTypeRef(dependencyTy));

unsigned abbrCode = DeclTypeAbbrCodes[ExtensionLayout::Code];
ExtensionLayout::emitRecord(Out, ScratchRecord, abbrCode,
Expand All @@ -2494,7 +2502,8 @@ void Serializer::writeDecl(const Decl *D) {
addGenericEnvironmentRef(
extension->getGenericEnvironment()),
conformances.size(),
inheritedTypes);
numInherited,
inheritedAndDependencyTypes);

bool isClassExtension = false;
if (baseNominal) {
Expand Down
39 changes: 39 additions & 0 deletions test/Serialization/Recovery/typedefs.swift
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,17 @@ let _ = unwrapped // okay
_ = usesWrapped(nil) // expected-error {{use of unresolved identifier 'usesWrapped'}}
_ = usesUnwrapped(nil) // expected-error {{nil is not compatible with expected argument type 'Int32'}}

func testExtensions(wrapped: WrappedInt, unwrapped: UnwrappedInt) {
wrapped.wrappedMethod() // expected-error {{value of type 'WrappedInt' (aka 'Int32') has no member 'wrappedMethod'}}
unwrapped.unwrappedMethod() // expected-error {{value of type 'UnwrappedInt' has no member 'unwrappedMethod'}}

***wrapped // This one works because of the UnwrappedInt extension.
***unwrapped // expected-error {{cannot convert value of type 'UnwrappedInt' to expected argument type 'Int32'}}

let _: WrappedProto = wrapped // expected-error {{value of type 'WrappedInt' (aka 'Int32') does not conform to specified type 'WrappedProto'}}
let _: UnwrappedProto = unwrapped // expected-error {{value of type 'UnwrappedInt' does not conform to specified type 'UnwrappedProto'}}
}

public class UserDynamicSub: UserDynamic {
override init() {}
}
Expand All @@ -72,6 +83,31 @@ public class UserSub : User {} // expected-error {{cannot inherit from class 'Us

import Typedefs

prefix operator ***

// CHECK-LABEL: extension WrappedInt : WrappedProto {
// CHECK-NEXT: func wrappedMethod()
// CHECK-NEXT: prefix static func ***(x: WrappedInt)
// CHECK-NEXT: }
// CHECK-RECOVERY-NEGATIVE-NOT: extension WrappedInt
extension WrappedInt: WrappedProto {
public func wrappedMethod() {}
public static prefix func ***(x: WrappedInt) {}
}
// CHECK-LABEL: extension Int32 : UnwrappedProto {
// CHECK-NEXT: func unwrappedMethod()
// CHECK-NEXT: prefix static func ***(x: UnwrappedInt)
// CHECK-NEXT: }
// CHECK-RECOVERY-LABEL: extension Int32 : UnwrappedProto {
// CHECK-RECOVERY-NEXT: func unwrappedMethod()
// CHECK-RECOVERY-NEXT: prefix static func ***(x: Int32)
// CHECK-RECOVERY-NEXT: }
// CHECK-RECOVERY-NEGATIVE-NOT: extension UnwrappedInt
extension UnwrappedInt: UnwrappedProto {
public func unwrappedMethod() {}
public static prefix func ***(x: UnwrappedInt) {}
}

// CHECK-LABEL: class User {
// CHECK-RECOVERY-LABEL: class User {
open class User {
Expand Down Expand Up @@ -305,4 +341,7 @@ public func returnsWrapped() -> WrappedInt { fatalError() }
// CHECK-RECOVERY-NEGATIVE-NOT: func returnsWrappedGeneric<
public func returnsWrappedGeneric<T>(_: T.Type) -> WrappedInt { fatalError() }

public protocol WrappedProto {}
public protocol UnwrappedProto {}

#endif // TEST