Skip to content

[6.1][Package CMO] Add deserialization checks to ensure correct memory layout #78700

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions include/swift/AST/DeclContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -806,6 +806,22 @@ class IterableDeclContext {
/// while skipping the body of this context.
unsigned HasDerivativeDeclarations : 1;

/// Members of a decl are deserialized lazily. This is set when
/// deserialization of all members is done, regardless of errors.
unsigned DeserializedMembers : 1;

/// Because `EnableDeserializationRecovery` is on by default,
/// deserialization errors are either recovered from later or silently
/// dropped. This flag is set whenever deserializing a member fails,
/// regardless of the `EnableDeserializationRecovery` value, and is used
/// to prevent a decl containing such members from being accessed
/// non-resiliently.
unsigned HasDeserializeMemberError : 1;

/// Used to track whether members of this decl and their respective
/// members were checked for deserialization errors recursively.
unsigned CheckedForDeserializeMemberError : 1;

template<class A, class B, class C>
friend struct ::llvm::CastInfo;

Expand All @@ -816,6 +832,9 @@ class IterableDeclContext {
/// Retrieve the \c ASTContext in which this iterable context occurs.
ASTContext &getASTContext() const;

/// Records whether this context (and, recursively, its members) has
/// already been checked for member deserialization errors.
void setCheckedForDeserializeMemberError(bool checked) {
  CheckedForDeserializeMemberError = checked;
}
/// Returns true once this context has been marked as checked for member
/// deserialization errors.
bool checkedForDeserializeMemberError() const {
  return CheckedForDeserializeMemberError != 0;
}

public:
IterableDeclContext(IterableDeclContextKind kind)
: LastDeclAndKind(nullptr, kind) {
Expand All @@ -824,6 +843,9 @@ class IterableDeclContext {
HasDerivativeDeclarations = 0;
HasNestedClassDeclarations = 0;
InFreestandingMacroArgument = 0;
DeserializedMembers = 0;
HasDeserializeMemberError = 0;
CheckedForDeserializeMemberError = 0;
}

/// Determine the kind of iterable context we have.
Expand All @@ -833,6 +855,18 @@ class IterableDeclContext {

bool hasUnparsedMembers() const;

/// Records whether deserialization of all members of this context has
/// finished (set regardless of whether errors occurred).
void setDeserializedMembers(bool deserialized) {
  DeserializedMembers = deserialized;
}
/// Returns true if the members of this context have been marked as
/// deserialized.
bool didDeserializeMembers() const {
  return DeserializedMembers != 0;
}

/// Records whether deserializing a member of this context failed.
void setHasDeserializeMemberError(bool hasError) {
  HasDeserializeMemberError = hasError;
}
/// Returns true if a member deserialization failure has been recorded on
/// this context.
bool hasDeserializeMemberError() const {
  return HasDeserializeMemberError != 0;
}

/// Recursively checks whether the members of this decl, and their respective
/// members, were deserialized correctly, and emits a diagnostic on error.
/// Requires that the accessing module and this decl's module be in the same
/// package, and that this decl's module have package optimization enabled.
void checkDeserializeMemberErrorInPackage(ModuleDecl *accessingModule);

/// Returns the current value of the \c HasOperatorDeclarations flag.
bool maybeHasOperatorDeclarations() const {
return HasOperatorDeclarations;
}
Expand Down
5 changes: 5 additions & 0 deletions include/swift/AST/DiagnosticsSema.def
Original file line number Diff line number Diff line change
Expand Up @@ -4722,6 +4722,11 @@ NOTE(ambiguous_because_of_trailing_closure,none,
"avoid using a trailing closure}0 to call %1",
(bool, const ValueDecl *))

// In-package resilience bypassing
// Arguments, in order: %0 is the member's name; %1 is a flag that selects
// the "missing member" wording (and omits %0) when true; %2 is the name of
// the containing decl; %3 is the module that declares it; %4 is the
// accessing module.
ERROR(cannot_bypass_resilience_due_to_missing_member,none,
"cannot bypass resilience due to member deserialization failure while attempting to access %select{member %0|missing member}1 of %2 in module %3 from module %4",
(Identifier, bool, Identifier, Identifier, Identifier))

// Cannot capture inout-ness of a parameter
// Partial application of foreign functions not supported
ERROR(partial_application_of_function_invalid,none,
Expand Down
7 changes: 7 additions & 0 deletions include/swift/Basic/LangOptions.h
Original file line number Diff line number Diff line change
Expand Up @@ -610,6 +610,13 @@ namespace swift {
/// from source.
bool AllowNonResilientAccess = false;

/// When Package CMO is enabled, deserialization checks are done to
/// ensure that the members of a decl are correctly deserialized to maintain
/// proper layout. This ensures that bypassing resilience is safe. Accessing
/// an incorrectly laid-out decl directly can lead to runtime crashes. This flag
/// should only be used temporarily during migration to enable Package CMO.
bool SkipDeserializationChecksForPackageCMO = false;

/// Enables dumping type witness systems from associated type inference.
bool DumpTypeWitnessSystems = false;

Expand Down
4 changes: 4 additions & 0 deletions include/swift/Option/Options.td
Original file line number Diff line number Diff line change
Expand Up @@ -1074,6 +1074,10 @@ def Oplayground : Flag<["-"], "Oplayground">, Group<O_Group>,
Flags<[HelpHidden, FrontendOption, ModuleInterfaceOption]>,
HelpText<"Compile with optimizations appropriate for a playground">;

// Frontend-only flag that turns off the member deserialization checks
// performed for Package CMO; see the HelpText for its intended use.
def ExperimentalSkipDeserializationChecksForPackageCMO : Flag<["-"], "experimental-skip-deserialization-checks-for-package-cmo">,
Flags<[FrontendOption]>,
HelpText<"Skip deserialization checks for package-cmo; use only for experimental purposes">;

// Frontend-only flag with the older "experimental-" spelling; the HelpText
// names its replacement.
def ExperimentalPackageCMO : Flag<["-"], "experimental-package-cmo">,
Flags<[FrontendOption]>,
HelpText<"Deprecated; use -package-cmo instead">;
Expand Down
51 changes: 44 additions & 7 deletions lib/AST/Decl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4560,14 +4560,51 @@ bool ValueDecl::hasOpenAccess(const DeclContext *useDC) const {
}

// Decides whether a use in `accessingModule` may bypass resilience when
// accessing this decl.
// NOTE(review): this span reads like a rendered diff without +/- markers.
// The first three-line comment and the single-expression `return` below are
// presumably the pre-change lines, and everything from the `inSamePackage`
// guard onward the replacement; as written, the code after that `return` is
// unreachable. Confirm against the real file before relying on it.
bool ValueDecl::bypassResilienceInPackage(ModuleDecl *accessingModule) const {
// If the defining module is built with package-cmo, bypass
// resilient access from the use site that belongs to a module
// in the same package.
// To allow bypassing resilience when accessing this decl from another
// module, it should be in the same package as this decl's module.
auto declModule = getModuleContext();
return declModule->inSamePackage(accessingModule) &&
declModule->isResilient() &&
declModule->allowNonResilientAccess() &&
declModule->serializePackageEnabled();
if (!declModule->inSamePackage(accessingModule))
return false;
// Package optimization allows bypassing resilience, but it assumes the
// memory layout of the decl being accessed is correct. When this assumption
// fails due to a deserialization error of its members, the use site incorrectly
// accesses the layout of the decl with a wrong field offset, resulting in UB
// or a crash.
// The deserialization error is currently not caught at compile time due to
// LangOpts.EnableDeserializationRecovery being enabled by default (to allow
// for recovery of some of the deserialization errors at a later time). In case
// of member deserialization, however, it's not necessarily recovered later on
// and is silently dropped, causing UB or a crash at runtime.
// The following tracks errors in member deserialization by recursively loading
// members of a type (if not done already) and checking whether the type's
// members, and their respective types (recursively), encountered deserialization
// failures.
// If any such type is found, it fails and emits a diagnostic at compile time.
// Simply disallowing resilience bypassing for this decl here is insufficient
// because it would cause a type requirement mismatch later during SIL
// deserialization; e.g. generated SIL in the imported module might contain
// an instruction that allows a direct access, while the caller in client module
// might require a non-direct access due to a deserialization error.
if (declModule->isResilient() &&
declModule->allowNonResilientAccess() &&
declModule->serializePackageEnabled()) {
// Fail and diagnose if there is a member deserialization error,
// with an option to skip for temporary/migration purposes.
if (!getASTContext().LangOpts.SkipDeserializationChecksForPackageCMO) {
// Since we're checking for deserialization errors, make sure the
// accessing module is different from this decl's module.
if (accessingModule &&
accessingModule != declModule) {
if (auto IDC = dyn_cast<IterableDeclContext>(this)) {
// Recursively check if members and their members have failing
// deserialization, and emit a diagnostic.
IDC->checkDeserializeMemberErrorInPackage(accessingModule);
}
}
}
// The checks above only diagnose; resilience is still bypassed here.
return true;
}
return false;
}

/// Given the formal access level for using \p VD, compute the scope where
Expand Down
140 changes: 140 additions & 0 deletions lib/AST/DeclContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "swift/AST/SourceFile.h"
#include "swift/AST/Types.h"
#include "swift/AST/TypeCheckRequests.h"
#include "swift/AST/DiagnosticsSema.h"
#include "swift/Basic/Assertions.h"
#include "swift/Basic/SourceManager.h"
#include "swift/Basic/Statistic.h"
Expand Down Expand Up @@ -1173,6 +1174,145 @@ void IterableDeclContext::loadAllMembers() const {
--s->getFrontendCounters().NumUnloadedLazyIterableDeclContexts;
}

// Checks whether the members of this decl, and the nominal types of its
// stored (non-static) variables, recursively, were deserialized correctly.
// Emits a diagnostic when this decl itself recorded a member deserialization
// failure; otherwise propagates a failure found in a member's type onto this
// decl. Does nothing unless `accessingModule` is in the same package as this
// decl's module and that module is resilient with package serialization
// enabled.
void IterableDeclContext::checkDeserializeMemberErrorInPackage(ModuleDecl *accessingModule) {
  auto *declModule = getDecl()->getModuleContext();

  // Only check if the accessing module is in the same package as this
  // decl's module, which has package optimization enabled.
  if (!declModule->inSamePackage(accessingModule) ||
      !declModule->isResilient() ||
      !declModule->serializePackageEnabled())
    return;

  // Bail if already checked for an error.
  if (checkedForDeserializeMemberError())
    return;

  // If members were not deserialized, force load here.
  if (!didDeserializeMembers()) {
    // This needs to be set to force load all members if not done already.
    setHasLazyMembers(true);
    // Calling getMembers actually loads the members. The result is not
    // needed here, so discard it explicitly; a named local would be unused
    // whenever the assertions below are compiled out.
    (void)getMembers();
    assert(!hasLazyMembers());
    assert(didDeserializeMembers());
  }

  // Members could have been deserialized from other flows. Check
  // for an error here. First mark this decl 'checked' to prevent
  // infinite recursion in case of self-referencing members.
  setCheckedForDeserializeMemberError(true);

  // If members are already loaded above or by other flows,
  // calling getMembers here should be inexpensive.
  auto memberList = getMembers();

  // This decl contains a member deserialization error; emit a diag.
  if (hasDeserializeMemberError()) {
    auto containerID = Identifier();
    if (auto *container = dyn_cast<NominalTypeDecl>(getDecl()))
      containerID = container->getBaseIdentifier();

    auto foundMissing = false;
    for (auto *member : memberList) {
      // Only storage vars can affect memory layout so
      // look up pattern binding decl or var decl.
      if (auto *PBD = dyn_cast<PatternBindingDecl>(member)) {
        // If this pattern binding decl is empty, we have
        // a missing member.
        if (PBD->getNumPatternEntries() == 0)
          foundMissing = true;
      }

      // A named MissingMemberDecl following an empty pattern binding lets
      // the diagnostic point at, and name, the missing member.
      if (auto *missingMember = dyn_cast<MissingMemberDecl>(member)) {
        if (!missingMember->getName().getBaseName().isSpecial() &&
            foundMissing) {
          foundMissing = false;
          auto missingMemberID = missingMember->getName().getBaseIdentifier();
          getASTContext().Diags.diagnose(
              member->getLoc(),
              diag::cannot_bypass_resilience_due_to_missing_member,
              missingMemberID,
              missingMemberID.empty(),
              containerID,
              declModule->getBaseIdentifier(),
              accessingModule->getBaseIdentifier());
          continue;
        }
      }

      // If not handled above, emit a diag here.
      // NOTE(review): this sits inside the loop, so it fires on every
      // iteration for which foundMissing is still set (including the one
      // that set it) - confirm that repeated diagnostics are intended.
      if (foundMissing) {
        getASTContext().Diags.diagnose(
            getDecl()->getLoc(),
            diag::cannot_bypass_resilience_due_to_missing_member,
            Identifier(),
            true,
            containerID,
            declModule->getBaseIdentifier(),
            accessingModule->getBaseIdentifier());
      }
    }
    return;
  }

  // This decl does not contain a member deserialization error itself.
  // Check the types of its stored members recursively to see if a member
  // deserialization failed further down.

  // Recursively checks the nominal type declaration behind `type`, if any,
  // treating this decl's module as the accessing module. Records an error
  // on this decl and returns true if that declaration has one.
  auto checkNominalOf = [&](Type type) -> bool {
    auto *nominal = type->getNominalOrBoundGenericNominal();
    if (!nominal)
      return false;
    auto *nominalIDC = dyn_cast<IterableDeclContext>(nominal);
    if (!nominalIDC)
      return false;
    nominalIDC->checkDeserializeMemberErrorInPackage(declModule);
    if (!nominalIDC->hasDeserializeMemberError())
      return false;
    setHasDeserializeMemberError(true);
    return true;
  };

  for (auto *member : memberList) {
    // Only storage vars can affect memory layout so
    // look up pattern binding decl or var decl.
    auto *PBD = dyn_cast<PatternBindingDecl>(member);
    if (!PBD)
      continue;

    for (auto i : range(PBD->getNumPatternEntries())) {
      PBD->getPattern(i)->forEachVariable([&](const VarDecl *VD) {
        // Bail if the var is static or has no storage.
        if (VD->isStatic() || !VD->hasStorageOrWrapsStorage())
          return;

        // Unwrap in case this var is optional.
        auto varType = VD->getInterfaceType()->getCanonicalType();
        if (auto unwrapped = varType->getOptionalObjectType())
          varType = unwrapped->getCanonicalType();

        if (auto *boundGeneric = varType->getAs<BoundGenericType>()) {
          // Handle BoundGenericType, e.g. [Foo: Bar], Dictionary<Foo, Bar>.
          // Check its argument types, i.e. Foo, Bar; stop at the first error.
          for (auto arg : boundGeneric->getGenericArgs()) {
            if (checkNominalOf(arg))
              break;
          }
        } else if (auto *tupleType = varType->getAs<TupleType>()) {
          // Handle TupleType, e.g. (Foo, Bar); stop at the first error.
          for (const auto &element : tupleType->getElements()) {
            if (checkNominalOf(element.getType()))
              break;
          }
        } else {
          // Plain nominal type (no-op if the type has no nominal decl).
          checkNominalOf(varType);
        }
      });
    }
  }
}

bool IterableDeclContext::wasDeserialized() const {
const DeclContext *DC = getAsGenericContext();
if (auto F = dyn_cast<FileUnit>(DC->getModuleScopeContext())) {
Expand Down
3 changes: 1 addition & 2 deletions lib/AST/TypeSubstitution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -997,8 +997,7 @@ ReplaceOpaqueTypesWithUnderlyingTypes::shouldPerformSubstitution(
// resilient expansion if the context's and the opaque type's module are in
// the same package.
if (contextExpansion == ResilienceExpansion::Maximal &&
module->isResilient() && module->serializePackageEnabled() &&
module->inSamePackage(contextModule))
namingDecl->bypassResilienceInPackage(contextModule))
return OpaqueSubstitutionKind::SubstituteSamePackageMaximalResilience;

// Allow general replacement from non resilient modules. Otherwise, disallow.
Expand Down
7 changes: 7 additions & 0 deletions lib/ClangImporter/ImportDecl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9672,17 +9672,22 @@ ClangImporter::Implementation::loadAllMembers(Decl *D, uint64_t extra) {
// Check whether we're importing an Objective-C container of some sort.
auto objcContainer =
dyn_cast_or_null<clang::ObjCContainerDecl>(D->getClangDecl());
auto *IDC = dyn_cast<IterableDeclContext>(D);

// If not, we're importing globals-as-members into an extension.
if (objcContainer) {
loadAllMembersOfSuperclassIfNeeded(dyn_cast<ClassDecl>(D));
loadAllMembersOfObjcContainer(D, objcContainer);
if (IDC) // Set member deserialization status
IDC->setDeserializedMembers(true);
return;
}

if (isa_and_nonnull<clang::RecordDecl>(D->getClangDecl())) {
loadAllMembersOfRecordDecl(cast<NominalTypeDecl>(D),
cast<clang::RecordDecl>(D->getClangDecl()));
if (IDC) // Set member deserialization status
IDC->setDeserializedMembers(true);
return;
}

Expand All @@ -9693,6 +9698,8 @@ ClangImporter::Implementation::loadAllMembers(Decl *D, uint64_t extra) {
}

loadAllMembersIntoExtension(D, extra);
if (IDC) // Set member deserialization status
IDC->setDeserializedMembers(true);
}

void ClangImporter::Implementation::loadAllMembersIntoExtension(
Expand Down
1 change: 1 addition & 0 deletions lib/Frontend/CompilerInvocation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1356,6 +1356,7 @@ static bool ParseLangArgs(LangOptions &Opts, ArgList &Args,
}
}

Opts.SkipDeserializationChecksForPackageCMO = Args.hasArg(OPT_ExperimentalSkipDeserializationChecksForPackageCMO);
Opts.AllowNonResilientAccess =
Args.hasArg(OPT_experimental_allow_non_resilient_access) ||
Args.hasArg(OPT_allow_non_resilient_access) ||
Expand Down
9 changes: 2 additions & 7 deletions lib/SIL/IR/TypeLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2443,13 +2443,8 @@ namespace {
// The same should happen if the type was resilient and serialized in
// another module in the same package with package-cmo enabled, which
// treats those modules to be in the same resilience domain.
auto declModule = D->getModuleContext();
bool sameModule = (declModule == &TC.M);
bool serializedPackage = declModule != &TC.M &&
declModule->inSamePackage(&TC.M) &&
declModule->isResilient() &&
declModule->serializePackageEnabled();
auto inSameResilienceDomain = sameModule || serializedPackage;
auto inSameResilienceDomain = D->getModuleContext() == &TC.M ||
D->bypassResilienceInPackage(&TC.M);
if (inSameResilienceDomain)
properties.addSubobject(RecursiveProperties::forResilient());

Expand Down
Loading