Skip to content

Commit 9589546

Browse files
committed
Package optimization allows bypassing resilience, but that assumes the memory layout of the
decl being accessed is correct. When this assumption fails due to a deserialization error of its members, the use site accesses the layout with a wrong field offset, resulting in UB or a crash. The deserialization error is currently not caught at compile time due to LangOpts.EnableDeserializationRecovery being enabled by default to allow for recovery of some of the deserialization errors at a later time. In case of member deserialization, however, it's not necessarily recovered later on. This PR tracks whether member deserialization had an error, and uses that info recursively in resilience bypassing check and fails with a diagnostic. Resolves rdar://132411524
1 parent f7d1c59 commit 9589546

File tree

10 files changed

+323
-18
lines changed

10 files changed

+323
-18
lines changed

include/swift/AST/Decl.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2828,6 +2828,7 @@ class ValueDecl : public Decl {
28282828

28292829
friend class Decl;
28302830
SourceLoc getLocFromSource() const { return NameLoc; }
2831+
28312832
protected:
28322833
ValueDecl(DeclKind K,
28332834
llvm::PointerUnion<DeclContext *, ASTContext *> context,

include/swift/AST/DeclContext.h

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -806,6 +806,18 @@ class IterableDeclContext {
806806
/// while skipping the body of this context.
807807
unsigned HasDerivativeDeclarations : 1;
808808

809+
/// Members of a decl are deserialized lazily. This is set when
810+
/// deserialization of all members is done, regardless of errors.
811+
unsigned DeserializedMembers : 1;
812+
813+
/// Deserialization errors are attempted to be recovered later or
814+
/// silently dropped due to `EnableDeserializationRecovery` being
815+
/// on by default. The following flag is set when deserializing
816+
/// members fails regardless of the `EnableDeserializationRecovery`
817+
/// value and is used to prevent decl containing such members from
818+
/// being accessed non-resiliently.
819+
unsigned HasDeserializeMemberError : 1;
820+
809821
template<class A, class B, class C>
810822
friend struct ::llvm::CastInfo;
811823

@@ -824,6 +836,8 @@ class IterableDeclContext {
824836
HasDerivativeDeclarations = 0;
825837
HasNestedClassDeclarations = 0;
826838
InFreestandingMacroArgument = 0;
839+
DeserializedMembers = 0;
840+
HasDeserializeMemberError = 0;
827841
}
828842

829843
/// Determine the kind of iterable context we have.
@@ -833,6 +847,16 @@ class IterableDeclContext {
833847

834848
bool hasUnparsedMembers() const;
835849

850+
void setDeserializedMembers(bool deserialized) { DeserializedMembers = deserialized; }
851+
bool didDeserializeMembers() const { return DeserializedMembers; }
852+
853+
void setHasDeserializeMemberError(bool hasError) { HasDeserializeMemberError = hasError; }
854+
bool hasDeserializeMemberError() const { return HasDeserializeMemberError; }
855+
856+
/// This recursively checks whether types of this decl's members
857+
/// were deserialized correctly.
858+
void checkDeserializeMemberErrorRecursively();
859+
836860
bool maybeHasOperatorDeclarations() const {
837861
return HasOperatorDeclarations;
838862
}

include/swift/AST/DiagnosticsSema.def

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4726,6 +4726,11 @@ NOTE(ambiguous_because_of_trailing_closure,none,
47264726
"avoid using a trailing closure}0 to call %1",
47274727
(bool, const ValueDecl *))
47284728

4729+
// In-package resilience bypassing
4730+
ERROR(cannot_bypass_resilience_due_to_deserialization_error,none,
4731+
"cannot bypass resilience when accessing %0 because some of its members failed to deserialize",
4732+
(Identifier))
4733+
47294734
// Cannot capture inout-ness of a parameter
47304735
// Partial application of foreign functions not supported
47314736
ERROR(partial_application_of_function_invalid,none,

lib/AST/Decl.cpp

Lines changed: 68 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4572,15 +4572,76 @@ bool ValueDecl::hasOpenAccess(const DeclContext *useDC) const {
45724572
return access == AccessLevel::Open;
45734573
}
45744574

4575+
/// Used to track whether accessing this decl from another module is allowed resilience bypassing.
4576+
static llvm::DenseMap<Identifier, Identifier> DeclResilienceBypassMap;
4577+
45754578
bool ValueDecl::bypassResilienceInPackage(ModuleDecl *accessingModule) const {
4576-
// If the defining module is built with package-cmo, bypass
4577-
// resilient access from the use site that belongs to a module
4578-
// in the same package.
4579+
// To allow bypassing resilience when accessing this decl from a
4580+
// use site, the module of the use site should be in the same package
4581+
// as the module of this decl.
45794582
auto declModule = getModuleContext();
4580-
return declModule->inSamePackage(accessingModule) &&
4581-
declModule->isResilient() &&
4582-
declModule->allowNonResilientAccess() &&
4583-
declModule->serializePackageEnabled();
4583+
if (!declModule->inSamePackage(accessingModule))
4584+
return false;
4585+
4586+
// First look up the cached value
4587+
if (accessingModule &&
4588+
accessingModule != declModule &&
4589+
!getBaseName().isSpecial()) {
4590+
auto found = DeclResilienceBypassMap.find(getBaseIdentifier());
4591+
if (found != DeclResilienceBypassMap.end()) {
4592+
if (found->getSecond() == accessingModule->getBaseIdentifier())
4593+
return true;
4594+
}
4595+
}
4596+
4597+
// Package optimization allows bypassing resilience, but it assumes the
4598+
// memory layout of the decl being accessed is correct. When this assumption
4599+
// fails due to a deserialization error of its members, the use site accesses
4600+
// the layout of the decl with a wrong field offset, resulting in UB or a crash.
4601+
// The deserialization error is currently not caught at compile time due to
4602+
// LangOpts.EnableDeserializationRecovery being enabled by default (to allow
4603+
// for recovery of some of the deserialization errors at a later time). In case
4604+
// of member deserialization, however, it's not necessarily recovered later on
4605+
// and is silently dropped.
4606+
// The following tracks errors in member deserialization by recursively loading
4607+
// members of a type (if not done) and checking whether the type's members, and
4608+
// their respective types (recursively), encountered deserialization failures.
4609+
// If any such type is found, it fails and emits a diagnostic at compile time.
4610+
// Simply disallowing resilience bypassing here and continuing is insufficient
4611+
// because it would later (during SIL deserialiaztion) require skipping instructions
4612+
// that were valid in the imported module but are no longer valid at the client
4613+
// module due to type requirement mismatch; addressing this would involve exhaustive
4614+
// instruction checks that can be complex and error-prone.
4615+
if (declModule->isResilient() &&
4616+
declModule->allowNonResilientAccess() &&
4617+
declModule->serializePackageEnabled()) {
4618+
// If this decl is accessed from another module, check if deserializing
4619+
// members of this decl had an error; if it errored, the decl now has
4620+
// an incorrect memory layout, and accessing it directly would most likely
4621+
// use a wrong field offset, resulting in UB or crash. To prevent this,
4622+
// fail and emit diagnostic here.
4623+
if (accessingModule &&
4624+
accessingModule != declModule &&
4625+
!getBaseName().isSpecial()) {
4626+
if (auto IDC = dyn_cast<IterableDeclContext>(this)) {
4627+
// Recursively check if members and their members have failing
4628+
// deserialization.
4629+
IDC->checkDeserializeMemberErrorRecursively();
4630+
// If member deserialization had an error, fail and emit diag here.
4631+
if (IDC->hasDeserializeMemberError()) {
4632+
getASTContext().Diags.diagnose(getLoc(),
4633+
diag::cannot_bypass_resilience_due_to_deserialization_error,
4634+
getBaseIdentifier());
4635+
return false;
4636+
}
4637+
}
4638+
if (!DeclResilienceBypassMap.insert({getBaseIdentifier(), accessingModule->getBaseIdentifier()}).second) {
4639+
return false;
4640+
}
4641+
}
4642+
return true;
4643+
}
4644+
return false;
45844645
}
45854646

45864647
/// Given the formal access level for using \p VD, compute the scope where

lib/AST/DeclContext.cpp

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1173,6 +1173,50 @@ void IterableDeclContext::loadAllMembers() const {
11731173
--s->getFrontendCounters().NumUnloadedLazyIterableDeclContexts;
11741174
}
11751175

1176+
// This checks whether types of members and their respective members
1177+
// (recursively) were deserialized correctly.
1178+
void IterableDeclContext::checkDeserializeMemberErrorRecursively() {
1179+
if (!didDeserializeMembers()) {
1180+
// This needs to be set to force load all members if not done already.
1181+
setHasLazyMembers(true);
1182+
// Call getMembers to actually load them all.
1183+
auto members = getMembers();
1184+
assert(!hasLazyMembers());
1185+
assert(didDeserializeMembers());
1186+
}
1187+
// Members could have been deserialized from other calls, so we
1188+
// still need to check for an error here even if they might have
1189+
// already been deserialized.
1190+
if (!hasDeserializeMemberError()) {
1191+
// If members are already loaded, getMembers call should be inexpensive.
1192+
for (auto member: getMembers()) {
1193+
if (auto *PBD = dyn_cast<PatternBindingDecl>(member)) {
1194+
for (auto i : range(PBD->getNumPatternEntries())) {
1195+
auto pattern = PBD->getPattern(i);
1196+
pattern->forEachVariable([&](const VarDecl *VD) {
1197+
// In case of Optional, looking up the unwrapped type.
1198+
if (auto actualType =
1199+
VD->getInterfaceType()->getCanonicalType().getOptionalObjectType()) {
1200+
if (auto fieldNominal = actualType->getNominalOrBoundGenericNominal()) {
1201+
if (auto fieldIDC = dyn_cast<IterableDeclContext>(fieldNominal)) {
1202+
// Recursively check on the type of a member.
1203+
fieldIDC->checkDeserializeMemberErrorRecursively();
1204+
// If the member had deserialization failure, mark the
1205+
// same for its containing type.
1206+
if (fieldIDC->hasDeserializeMemberError()) {
1207+
setHasDeserializeMemberError(true);
1208+
return;
1209+
}
1210+
}
1211+
}
1212+
}
1213+
});
1214+
}
1215+
}
1216+
}
1217+
}
1218+
}
1219+
11761220
bool IterableDeclContext::wasDeserialized() const {
11771221
const DeclContext *DC = getAsGenericContext();
11781222
if (auto F = dyn_cast<FileUnit>(DC->getModuleScopeContext())) {

lib/AST/TypeSubstitution.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -997,8 +997,7 @@ ReplaceOpaqueTypesWithUnderlyingTypes::shouldPerformSubstitution(
997997
// resilient expansion if the context's and the opaque type's module are in
998998
// the same package.
999999
if (contextExpansion == ResilienceExpansion::Maximal &&
1000-
module->isResilient() && module->serializePackageEnabled() &&
1001-
module->inSamePackage(contextModule))
1000+
namingDecl->bypassResilienceInPackage(contextModule))
10021001
return OpaqueSubstitutionKind::SubstituteSamePackageMaximalResilience;
10031002

10041003
// Allow general replacement from non resilient modules. Otherwise, disallow.

lib/ClangImporter/ImportDecl.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9713,17 +9713,22 @@ ClangImporter::Implementation::loadAllMembers(Decl *D, uint64_t extra) {
97139713
// Check whether we're importing an Objective-C container of some sort.
97149714
auto objcContainer =
97159715
dyn_cast_or_null<clang::ObjCContainerDecl>(D->getClangDecl());
9716+
auto *IDC = dyn_cast<IterableDeclContext>(D);
97169717

97179718
// If not, we're importing globals-as-members into an extension.
97189719
if (objcContainer) {
97199720
loadAllMembersOfSuperclassIfNeeded(dyn_cast<ClassDecl>(D));
97209721
loadAllMembersOfObjcContainer(D, objcContainer);
9722+
if (IDC) // Set member deserialization status
9723+
IDC->setDeserializedMembers(true);
97219724
return;
97229725
}
97239726

97249727
if (isa_and_nonnull<clang::RecordDecl>(D->getClangDecl())) {
97259728
loadAllMembersOfRecordDecl(cast<NominalTypeDecl>(D),
97269729
cast<clang::RecordDecl>(D->getClangDecl()));
9730+
if (IDC) // Set member deserialization status
9731+
IDC->setDeserializedMembers(true);
97279732
return;
97289733
}
97299734

@@ -9734,6 +9739,8 @@ ClangImporter::Implementation::loadAllMembers(Decl *D, uint64_t extra) {
97349739
}
97359740

97369741
loadAllMembersIntoExtension(D, extra);
9742+
if (IDC) // Set member deserialization status
9743+
IDC->setDeserializedMembers(true);
97379744
}
97389745

97399746
void ClangImporter::Implementation::loadAllMembersIntoExtension(

lib/SIL/IR/TypeLowering.cpp

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2442,13 +2442,8 @@ namespace {
24422442
// The same should happen if the type was resilient and serialized in
24432443
// another module in the same package with package-cmo enabled, which
24442444
// treats those modules to be in the same resilience domain.
2445-
auto declModule = D->getModuleContext();
2446-
bool sameModule = (declModule == &TC.M);
2447-
bool serializedPackage = declModule != &TC.M &&
2448-
declModule->inSamePackage(&TC.M) &&
2449-
declModule->isResilient() &&
2450-
declModule->serializePackageEnabled();
2451-
auto inSameResilienceDomain = sameModule || serializedPackage;
2445+
auto inSameResilienceDomain = D->getModuleContext() == &TC.M ||
2446+
D->bypassResilienceInPackage(&TC.M);
24522447
if (inSameResilienceDomain)
24532448
properties.addSubobject(RecursiveProperties::forResilient());
24542449

lib/Serialization/Deserialization.cpp

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8314,9 +8314,12 @@ void ModuleFile::loadAllMembers(Decl *container, uint64_t contextData) {
83148314
ArrayRef<uint64_t> rawMemberIDs;
83158315
decls_block::MembersLayout::readRecord(memberIDBuffer, rawMemberIDs);
83168316

8317-
if (rawMemberIDs.empty())
8317+
if (rawMemberIDs.empty()) {
8318+
// No members; set the state of member deserialization to done.
8319+
if (!IDC->didDeserializeMembers())
8320+
IDC->setDeserializedMembers(true);
83188321
return;
8319-
8322+
}
83208323
SmallVector<Decl *, 16> members;
83218324
members.reserve(rawMemberIDs.size());
83228325
for (DeclID rawID : rawMemberIDs) {
@@ -8332,8 +8335,16 @@ void ModuleFile::loadAllMembers(Decl *container, uint64_t contextData) {
83328335
getContext(), container, next.takeError());
83338336
if (suppliedMissingMember)
83348337
members.push_back(suppliedMissingMember);
8338+
8339+
// Not all members can be discovered as missing
8340+
// members as checked above, so set the error bit
8341+
// here.
8342+
IDC->setHasDeserializeMemberError(true);
83358343
}
83368344
}
8345+
// Set the status of member deserialization to Done.
8346+
if (!IDC->didDeserializeMembers())
8347+
IDC->setDeserializedMembers(true);
83378348

83388349
for (auto member : members)
83398350
IDC->addMember(member);

0 commit comments

Comments
 (0)