Skip to content

Commit 2a5a1e6

Browse files
committed
Package optimization allows bypassing resilience, but that assumes the memory layout of the
decl being accessed is correct. When this assumption fails due to a deserialization error of its members, the use site accesses the layout with a wrong field offset, resulting in UB or a crash. The deserialization error is currently not caught at compile time due to LangOpts.EnableDeserializationRecovery being enabled by default to allow for recovery of some of the deserialization errors at a later time. In case of member deserialization, however, it's not necessarily recovered later on. This PR tracks whether member deserialization had an error, and uses that info recursively in resilience bypassing check and fails with a diagnostic. Resolves rdar://132411524
1 parent f7d1c59 commit 2a5a1e6

File tree

10 files changed

+317
-20
lines changed

10 files changed

+317
-20
lines changed

include/swift/AST/DeclContext.h

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -806,6 +806,18 @@ class IterableDeclContext {
806806
/// while skipping the body of this context.
807807
unsigned HasDerivativeDeclarations : 1;
808808

809+
/// Members of a decl are deserialized lazily. This is set when
810+
/// deserialization of all members is done, regardless of errors.
811+
unsigned DeserializedMembers : 1;
812+
813+
/// Deserialization errors are attempted to be recovered later or
814+
/// silently dropped due to `EnableDeserializationRecovery` being
815+
/// on by default. The following flag is set when deserializing
816+
/// members fails regardless of the `EnableDeserializationRecovery`
817+
/// value and is used to prevent decl containing such members from
818+
/// being accessed non-resiliently.
819+
unsigned HasDeserializeMemberError : 1;
820+
809821
template<class A, class B, class C>
810822
friend struct ::llvm::CastInfo;
811823

@@ -824,6 +836,8 @@ class IterableDeclContext {
824836
HasDerivativeDeclarations = 0;
825837
HasNestedClassDeclarations = 0;
826838
InFreestandingMacroArgument = 0;
839+
DeserializedMembers = 0;
840+
HasDeserializeMemberError = 0;
827841
}
828842

829843
/// Determine the kind of iterable context we have.
@@ -833,6 +847,16 @@ class IterableDeclContext {
833847

834848
bool hasUnparsedMembers() const;
835849

850+
void setDeserializedMembers(bool deserialized) { DeserializedMembers = deserialized; }
851+
bool didDeserializeMembers() const { return DeserializedMembers; }
852+
853+
void setHasDeserializeMemberError(bool hasError) { HasDeserializeMemberError = hasError; }
854+
bool hasDeserializeMemberError() const { return HasDeserializeMemberError; }
855+
856+
/// This recursively checks whether types of this decl's members
857+
/// were deserialized correctly.
858+
void checkDeserializeMemberErrorRecursively();
859+
836860
bool maybeHasOperatorDeclarations() const {
837861
return HasOperatorDeclarations;
838862
}

include/swift/AST/DiagnosticsSema.def

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4726,6 +4726,11 @@ NOTE(ambiguous_because_of_trailing_closure,none,
47264726
"avoid using a trailing closure}0 to call %1",
47274727
(bool, const ValueDecl *))
47284728

4729+
// In-package resilience bypassing
4730+
ERROR(cannot_bypass_resilience_due_to_deserialization_error,none,
4731+
"cannot bypass resilience when accessing %0 because some of its members failed to deserialize",
4732+
(Identifier))
4733+
47294734
// Cannot capture inout-ness of a parameter
47304735
// Partial application of foreign functions not supported
47314736
ERROR(partial_application_of_function_invalid,none,

lib/AST/Decl.cpp

Lines changed: 54 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4572,15 +4572,62 @@ bool ValueDecl::hasOpenAccess(const DeclContext *useDC) const {
45724572
return access == AccessLevel::Open;
45734573
}
45744574

4575+
static llvm::DenseMap<Identifier, std::pair<Identifier, bool>> DeclBypassMap;
4576+
45754577
bool ValueDecl::bypassResilienceInPackage(ModuleDecl *accessingModule) const {
4576-
// If the defining module is built with package-cmo, bypass
4577-
// resilient access from the use site that belongs to a module
4578-
// in the same package.
4578+
// To allow bypassing resilience when accessing this decl from a
4579+
// use site, the module of the use site should be in the same package
4580+
// as the module of this decl.
45794581
auto declModule = getModuleContext();
4580-
return declModule->inSamePackage(accessingModule) &&
4581-
declModule->isResilient() &&
4582-
declModule->allowNonResilientAccess() &&
4583-
declModule->serializePackageEnabled();
4582+
if (!declModule->inSamePackage(accessingModule))
4583+
return false;
4584+
4585+
// If in the same package, package optimization should be enabled
4586+
// for the module containing this decl.
4587+
if (declModule->isResilient() &&
4588+
declModule->allowNonResilientAccess() &&
4589+
declModule->serializePackageEnabled()) {
4590+
4591+
// If this decl is accessed from another module, check if deserializing
4592+
// members of this decl had an error; disallow bypassing in case of an error
4593+
// since accessing memory layout directly will cause UB or crash in such case.
4594+
if (accessingModule &&
4595+
accessingModule != declModule &&
4596+
!getBaseName().isSpecial()) {
4597+
4598+
// First look up cached value
4599+
auto found = DeclBypassMap.find(getBaseIdentifier());
4600+
if (found != DeclBypassMap.end()) {
4601+
if (found->getSecond().first == accessingModule->getBaseIdentifier())
4602+
return found->getSecond().second;
4603+
}
4604+
4605+
// If this decl contains members, check if the members were
4606+
// deserialized correctly.
4607+
if (auto IDC = dyn_cast<IterableDeclContext>(this)) {
4608+
// Members are deserialized lazily, so if un-deserialized,
4609+
// force load them all here.
4610+
IDC->checkDeserializeMemberErrorRecursively();
4611+
4612+
// If member deserialization had an error, fail here.
4613+
if (IDC->hasDeserializeMemberError()) {
4614+
if (!DeclBypassMap.insert({getBaseIdentifier(), {accessingModule->getBaseIdentifier(), false}}).second) {
4615+
return false;
4616+
}
4617+
getASTContext().Diags.diagnose(getLoc(),
4618+
diag::cannot_bypass_resilience_due_to_deserialization_error,
4619+
getBaseIdentifier());
4620+
return false;
4621+
}
4622+
}
4623+
4624+
if (!DeclBypassMap.insert({getBaseIdentifier(), {accessingModule->getBaseIdentifier(), true}}).second) {
4625+
return false;
4626+
}
4627+
}
4628+
return true;
4629+
}
4630+
return false;
45844631
}
45854632

45864633
/// Given the formal access level for using \p VD, compute the scope where

lib/AST/DeclContext.cpp

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1173,6 +1173,51 @@ void IterableDeclContext::loadAllMembers() const {
11731173
--s->getFrontendCounters().NumUnloadedLazyIterableDeclContexts;
11741174
}
11751175

1176+
// This recursively checks whether types of members were deserialized
1177+
// correctly.
1178+
void IterableDeclContext::checkDeserializeMemberErrorRecursively() {
1179+
if (!didDeserializeMembers()) {
1180+
// This needs to be set to force load all members.
1181+
setHasLazyMembers(true);
1182+
// Call getMembers to actually load them all.
1183+
auto members = getMembers();
1184+
assert(!hasLazyMembers());
1185+
assert(didDeserializeMembers());
1186+
}
1187+
1188+
// Members could have been deserialized from other calls, so we
1189+
// still need to check for an error here even if they are already
1190+
// deserialized.
1191+
if (!hasDeserializeMemberError()) {
1192+
// If members are already loaded, getMembers call should be inexpensive.
1193+
for (auto member: getMembers()) {
1194+
if (auto *PBD = dyn_cast<PatternBindingDecl>(member)) {
1195+
for (auto i : range(PBD->getNumPatternEntries())) {
1196+
auto pattern = PBD->getPattern(i);
1197+
pattern->forEachVariable([&](const VarDecl *VD) {
1198+
if (auto actualType =
1199+
VD->getInterfaceType()->getCanonicalType().getOptionalObjectType()) {
1200+
if (auto fieldNominal = actualType->getNominalOrBoundGenericNominal()) {
1201+
if (auto fieldIDC = dyn_cast<IterableDeclContext>(fieldNominal)) {
1202+
// Recursively check on the type of this decl's member.
1203+
fieldIDC->checkDeserializeMemberErrorRecursively();
1204+
1205+
if (fieldIDC->hasDeserializeMemberError()) {
1206+
// If deserializing a member had an error, set
1207+
// the error bit for this decl as well.
1208+
setHasDeserializeMemberError(true);
1209+
return;
1210+
}
1211+
}
1212+
}
1213+
}
1214+
});
1215+
}
1216+
}
1217+
}
1218+
}
1219+
}
1220+
11761221
bool IterableDeclContext::wasDeserialized() const {
11771222
const DeclContext *DC = getAsGenericContext();
11781223
if (auto F = dyn_cast<FileUnit>(DC->getModuleScopeContext())) {

lib/AST/TypeSubstitution.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -997,8 +997,7 @@ ReplaceOpaqueTypesWithUnderlyingTypes::shouldPerformSubstitution(
997997
// resilient expansion if the context's and the opaque type's module are in
998998
// the same package.
999999
if (contextExpansion == ResilienceExpansion::Maximal &&
1000-
module->isResilient() && module->serializePackageEnabled() &&
1001-
module->inSamePackage(contextModule))
1000+
namingDecl->bypassResilienceInPackage(contextModule))
10021001
return OpaqueSubstitutionKind::SubstituteSamePackageMaximalResilience;
10031002

10041003
// Allow general replacement from non resilient modules. Otherwise, disallow.

lib/ClangImporter/ImportDecl.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9713,17 +9713,22 @@ ClangImporter::Implementation::loadAllMembers(Decl *D, uint64_t extra) {
97139713
// Check whether we're importing an Objective-C container of some sort.
97149714
auto objcContainer =
97159715
dyn_cast_or_null<clang::ObjCContainerDecl>(D->getClangDecl());
9716+
auto *IDC = dyn_cast<IterableDeclContext>(D);
97169717

97179718
// If not, we're importing globals-as-members into an extension.
97189719
if (objcContainer) {
97199720
loadAllMembersOfSuperclassIfNeeded(dyn_cast<ClassDecl>(D));
97209721
loadAllMembersOfObjcContainer(D, objcContainer);
9722+
if (IDC) // Set member deserialization status
9723+
IDC->setDeserializedMembers(true);
97219724
return;
97229725
}
97239726

97249727
if (isa_and_nonnull<clang::RecordDecl>(D->getClangDecl())) {
97259728
loadAllMembersOfRecordDecl(cast<NominalTypeDecl>(D),
97269729
cast<clang::RecordDecl>(D->getClangDecl()));
9730+
if (IDC) // Set member deserialization status
9731+
IDC->setDeserializedMembers(true);
97279732
return;
97289733
}
97299734

@@ -9734,6 +9739,8 @@ ClangImporter::Implementation::loadAllMembers(Decl *D, uint64_t extra) {
97349739
}
97359740

97369741
loadAllMembersIntoExtension(D, extra);
9742+
if (IDC) // Set member deserialization status
9743+
IDC->setDeserializedMembers(true);
97379744
}
97389745

97399746
void ClangImporter::Implementation::loadAllMembersIntoExtension(

lib/Frontend/CompilerInvocation.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1379,7 +1379,7 @@ static bool ParseLangArgs(LangOptions &Opts, ArgList &Args,
13791379
Opts.AllowNonResilientAccess = false;
13801380
}
13811381
}
1382-
1382+
13831383
// HACK: The driver currently erroneously passes all flags to module interface
13841384
// verification jobs. -experimental-skip-non-exportable-decls is not
13851385
// appropriate for verification tasks and should be ignored, though.

lib/SIL/IR/TypeLowering.cpp

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2442,13 +2442,8 @@ namespace {
24422442
// The same should happen if the type was resilient and serialized in
24432443
// another module in the same package with package-cmo enabled, which
24442444
// treats those modules to be in the same resilience domain.
2445-
auto declModule = D->getModuleContext();
2446-
bool sameModule = (declModule == &TC.M);
2447-
bool serializedPackage = declModule != &TC.M &&
2448-
declModule->inSamePackage(&TC.M) &&
2449-
declModule->isResilient() &&
2450-
declModule->serializePackageEnabled();
2451-
auto inSameResilienceDomain = sameModule || serializedPackage;
2445+
auto inSameResilienceDomain = D->getModuleContext() == &TC.M ||
2446+
D->bypassResilienceInPackage(&TC.M);
24522447
if (inSameResilienceDomain)
24532448
properties.addSubobject(RecursiveProperties::forResilient());
24542449

lib/Serialization/Deserialization.cpp

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8283,6 +8283,7 @@ ModuleFile::handleErrorAndSupplyMissingMember(ASTContext &context,
82838283
return handleErrorAndSupplyMissingProtoMember(context, std::move(error),
82848284
containingProto);
82858285
}
8286+
82868287
return handleErrorAndSupplyMissingMiscMember(std::move(error));
82878288
}
82888289

@@ -8300,6 +8301,7 @@ void ModuleFile::loadAllMembers(Decl *container, uint64_t contextData) {
83008301
if (diagnoseAndConsumeFatalIfNotSuccess(
83018302
DeclTypeCursor.JumpToBit(contextData)))
83028303
return;
8304+
83038305
llvm::BitstreamEntry entry = fatalIfUnexpected(DeclTypeCursor.advance());
83048306
if (entry.Kind != llvm::BitstreamEntry::Record)
83058307
return diagnoseAndConsumeFatal();
@@ -8314,9 +8316,12 @@ void ModuleFile::loadAllMembers(Decl *container, uint64_t contextData) {
83148316
ArrayRef<uint64_t> rawMemberIDs;
83158317
decls_block::MembersLayout::readRecord(memberIDBuffer, rawMemberIDs);
83168318

8317-
if (rawMemberIDs.empty())
8319+
if (rawMemberIDs.empty()) {
8320+
// No members; set the state of member deserialization to done.
8321+
if (!IDC->didDeserializeMembers())
8322+
IDC->setDeserializedMembers(true);
83188323
return;
8319-
8324+
}
83208325
SmallVector<Decl *, 16> members;
83218326
members.reserve(rawMemberIDs.size());
83228327
for (DeclID rawID : rawMemberIDs) {
@@ -8325,15 +8330,25 @@ void ModuleFile::loadAllMembers(Decl *container, uint64_t contextData) {
83258330
assert(next.get() && "unchecked error deserializing next member");
83268331
members.push_back(next.get());
83278332
} else {
8333+
83288334
if (!getContext().LangOpts.EnableDeserializationRecovery)
83298335
fatal(next.takeError());
83308336

83318337
Decl *suppliedMissingMember = handleErrorAndSupplyMissingMember(
83328338
getContext(), container, next.takeError());
8333-
if (suppliedMissingMember)
8339+
if (suppliedMissingMember) {
83348340
members.push_back(suppliedMissingMember);
8341+
}
8342+
8343+
// Not all members can be discovered as missing
8344+
// members as checked above, so set the error bit
8345+
// here.
8346+
IDC->setHasDeserializeMemberError(true);
83358347
}
83368348
}
8349+
// Set the status of member deserialization to Done.
8350+
if (!IDC->didDeserializeMembers())
8351+
IDC->setDeserializedMembers(true);
83378352

83388353
for (auto member : members)
83398354
IDC->addMember(member);

0 commit comments

Comments
 (0)