Skip to content

Commit 727fb8c

Browse files
authored
Merge pull request #78258 from swiftlang/elsh/disallow-bypass-deser-check
Package CMO: add deserialization checks to ensure correct memory layout
2 parents ee6652d + c03abed commit 727fb8c

File tree

12 files changed

+431
-18
lines changed

12 files changed

+431
-18
lines changed

include/swift/AST/DeclContext.h

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -811,6 +811,22 @@ class IterableDeclContext {
811811
/// while skipping the body of this context.
812812
unsigned HasDerivativeDeclarations : 1;
813813

814+
/// Members of a decl are deserialized lazily. This is set when
815+
/// deserialization of all members is done, regardless of errors.
816+
unsigned DeserializedMembers : 1;
817+
818+
/// Deserialization errors are attempted to be recovered later or
819+
/// silently dropped due to `EnableDeserializationRecovery` being
820+
/// on by default. The following flag is set when deserializing
821+
/// members fails regardless of the `EnableDeserializationRecovery`
822+
/// value and is used to prevent decl containing such members from
823+
/// being accessed non-resiliently.
824+
unsigned HasDeserializeMemberError : 1;
825+
826+
/// Used to track whether members of this decl and their respective
827+
/// members were checked for deserialization errors recursively.
828+
unsigned CheckedForDeserializeMemberError : 1;
829+
814830
template<class A, class B, class C>
815831
friend struct ::llvm::CastInfo;
816832

@@ -821,6 +837,9 @@ class IterableDeclContext {
821837
/// Retrieve the \c ASTContext in which this iterable context occurs.
822838
ASTContext &getASTContext() const;
823839

840+
void setCheckedForDeserializeMemberError(bool checked) { CheckedForDeserializeMemberError = checked; }
841+
bool checkedForDeserializeMemberError() const { return CheckedForDeserializeMemberError; }
842+
824843
public:
825844
IterableDeclContext(IterableDeclContextKind kind)
826845
: LastDeclAndKind(nullptr, kind) {
@@ -829,6 +848,9 @@ class IterableDeclContext {
829848
HasDerivativeDeclarations = 0;
830849
HasNestedClassDeclarations = 0;
831850
InFreestandingMacroArgument = 0;
851+
DeserializedMembers = 0;
852+
HasDeserializeMemberError = 0;
853+
CheckedForDeserializeMemberError = 0;
832854
}
833855

834856
/// Determine the kind of iterable context we have.
@@ -838,6 +860,18 @@ class IterableDeclContext {
838860

839861
bool hasUnparsedMembers() const;
840862

863+
void setDeserializedMembers(bool deserialized) { DeserializedMembers = deserialized; }
864+
bool didDeserializeMembers() const { return DeserializedMembers; }
865+
866+
void setHasDeserializeMemberError(bool hasError) { HasDeserializeMemberError = hasError; }
867+
bool hasDeserializeMemberError() const { return HasDeserializeMemberError; }
868+
869+
/// This recursively checks whether members of this decl and their respective
870+
/// members were deserialized correctly and emits a diagnostic in case of an error.
871+
/// Requires accessing module and this decl's module are in the same package,
872+
/// and this decl's module has package optimization enabled.
873+
void checkDeserializeMemberErrorInPackage(ModuleDecl *accessingModule);
874+
841875
bool maybeHasOperatorDeclarations() const {
842876
return HasOperatorDeclarations;
843877
}

include/swift/AST/DiagnosticsSema.def

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4734,6 +4734,11 @@ NOTE(ambiguous_because_of_trailing_closure,none,
47344734
"avoid using a trailing closure}0 to call %1",
47354735
(bool, const ValueDecl *))
47364736

4737+
// In-package resilience bypassing
4738+
ERROR(cannot_bypass_resilience_due_to_missing_member,none,
4739+
"cannot bypass resilience due to member deserialization failure while attempting to access %select{member %0|missing member}1 of %2 in module %3 from module %4",
4740+
(Identifier, bool, Identifier, Identifier, Identifier))
4741+
47374742
// Cannot capture inout-ness of a parameter
47384743
// Partial application of foreign functions not supported
47394744
ERROR(partial_application_of_function_invalid,none,

include/swift/Basic/LangOptions.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -610,6 +610,13 @@ namespace swift {
610610
/// from source.
611611
bool AllowNonResilientAccess = false;
612612

613+
/// When Package CMO is enabled, deserialization checks are done to
614+
/// ensure that the members of a decl are correctly deserialized to maintain
615+
/// proper layout. This ensures that bypassing resilience is safe. Accessing
616+
/// an incorrectly laid-out decl directly can lead to runtime crashes. This flag
617+
/// should only be used temporarily during migration to enable Package CMO.
618+
bool SkipDeserializationChecksForPackageCMO = false;
619+
613620
/// Enables dumping type witness systems from associated type inference.
614621
bool DumpTypeWitnessSystems = false;
615622

include/swift/Option/Options.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1081,6 +1081,10 @@ def Oplayground : Flag<["-"], "Oplayground">, Group<O_Group>,
10811081
Flags<[HelpHidden, FrontendOption, ModuleInterfaceOption]>,
10821082
HelpText<"Compile with optimizations appropriate for a playground">;
10831083

1084+
def ExperimentalSkipDeserializationChecksForPackageCMO : Flag<["-"], "experimental-skip-deserialization-checks-for-package-cmo">,
1085+
Flags<[FrontendOption]>,
1086+
HelpText<"Skip deserialization checks for package-cmo; use only for experimental purposes">;
1087+
10841088
def ExperimentalPackageCMO : Flag<["-"], "experimental-package-cmo">,
10851089
Flags<[FrontendOption]>,
10861090
HelpText<"Deprecated; use -package-cmo instead">;

lib/AST/Decl.cpp

Lines changed: 44 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4714,14 +4714,51 @@ bool ValueDecl::hasOpenAccess(const DeclContext *useDC) const {
47144714
}
47154715

47164716
bool ValueDecl::bypassResilienceInPackage(ModuleDecl *accessingModule) const {
4717-
// If the defining module is built with package-cmo, bypass
4718-
// resilient access from the use site that belongs to a module
4719-
// in the same package.
4717+
// To allow bypassing resilience when accessing this decl from another
4718+
// module, it should be in the same package as this decl's module.
47204719
auto declModule = getModuleContext();
4721-
return declModule->inSamePackage(accessingModule) &&
4722-
declModule->isResilient() &&
4723-
declModule->allowNonResilientAccess() &&
4724-
declModule->serializePackageEnabled();
4720+
if (!declModule->inSamePackage(accessingModule))
4721+
return false;
4722+
// Package optimization allows bypassing resilience, but it assumes the
4723+
// memory layout of the decl being accessed is correct. When this assumption
4724+
// fails due to a deserialization error of its members, the use site incorrectly
4725+
// accesses the layout of the decl with a wrong field offset, resulting in UB
4726+
// or a crash.
4727+
// The deserialization error is currently not caught at compile time due to
4728+
// LangOpts.EnableDeserializationRecovery being enabled by default (to allow
4729+
// for recovery of some of the deserialization errors at a later time). In case
4730+
// of member deserialization, however, it's not necessarily recovered later on
4731+
// and is silently dropped, causing UB or a crash at runtime.
4732+
// The following tracks errors in member deserialization by recursively loading
4733+
// members of a type (if not done already) and checking whether the type's
4734+
// members, and their respective types (recursively), encountered deserialization
4735+
// failures.
4736+
// If any such type is found, it fails and emits a diagnostic at compile time.
4737+
// Simply disallowing resilience bypassing for this decl here is insufficient
4738+
// because it would cause a type requirement mismatch later during SIL
4739+
// deserialiaztion; e.g. generated SIL in the imported module might contain
4740+
// an instruction that allows a direct access, while the caller in client module
4741+
// might require a non-direct access due to a deserialization error.
4742+
if (declModule->isResilient() &&
4743+
declModule->allowNonResilientAccess() &&
4744+
declModule->serializePackageEnabled()) {
4745+
// Fail and diagnose if there is a member deserialization error,
4746+
// with an option to skip for temporary/migration purposes.
4747+
if (!getASTContext().LangOpts.SkipDeserializationChecksForPackageCMO) {
4748+
// Since we're checking for deserialization errors, make sure the
4749+
// accessing module is different from this decl's module.
4750+
if (accessingModule &&
4751+
accessingModule != declModule) {
4752+
if (auto IDC = dyn_cast<IterableDeclContext>(this)) {
4753+
// Recursively check if members and their members have failing
4754+
// deserialization, and emit a diagnostic.
4755+
IDC->checkDeserializeMemberErrorInPackage(accessingModule);
4756+
}
4757+
}
4758+
}
4759+
return true;
4760+
}
4761+
return false;
47254762
}
47264763

47274764
/// Given the formal access level for using \p VD, compute the scope where

lib/AST/DeclContext.cpp

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#include "swift/AST/SourceFile.h"
2727
#include "swift/AST/TypeCheckRequests.h"
2828
#include "swift/AST/Types.h"
29+
#include "swift/AST/DiagnosticsSema.h"
2930
#include "swift/Basic/Assertions.h"
3031
#include "swift/Basic/SourceManager.h"
3132
#include "swift/Basic/Statistic.h"
@@ -1174,6 +1175,145 @@ void IterableDeclContext::loadAllMembers() const {
11741175
--s->getFrontendCounters().NumUnloadedLazyIterableDeclContexts;
11751176
}
11761177

1178+
// Checks whether members of this decl and their respective members
1179+
// (recursively) were deserialized correctly and emits a diagnostic
1180+
// if deserialization failed. Requires accessing module and this decl's
1181+
// module are in the same package, and this decl's module has package
1182+
// optimization enabled.
1183+
void IterableDeclContext::checkDeserializeMemberErrorInPackage(ModuleDecl *accessingModule) {
1184+
// Only check if accessing module is in the same package as this
1185+
// decl's module, which has package optimization enabled.
1186+
if (!getDecl()->getModuleContext()->inSamePackage(accessingModule) ||
1187+
!getDecl()->getModuleContext()->isResilient() ||
1188+
!getDecl()->getModuleContext()->serializePackageEnabled())
1189+
return;
1190+
// Bail if already checked for an error.
1191+
if (checkedForDeserializeMemberError())
1192+
return;
1193+
// If members were not deserialized, force load here.
1194+
if (!didDeserializeMembers()) {
1195+
// This needs to be set to force load all members if not done already.
1196+
setHasLazyMembers(true);
1197+
// Calling getMembers actually loads the members.
1198+
auto members = getMembers();
1199+
assert(!hasLazyMembers());
1200+
assert(didDeserializeMembers());
1201+
}
1202+
// Members could have been deserialized from other flows. Check
1203+
// for an error here. First mark this decl 'checked' to prevent
1204+
// infinite recursion in case of self-referencing members.
1205+
setCheckedForDeserializeMemberError(true);
1206+
1207+
// If members are already loaded above or by other flows,
1208+
// calling getMembers here should be inexpensive.
1209+
auto memberList = getMembers();
1210+
1211+
// This decl contains a member deserialization error; emit a diag.
1212+
if (hasDeserializeMemberError()) {
1213+
auto containerID = Identifier();
1214+
if (auto container = dyn_cast<NominalTypeDecl>(getDecl())) {
1215+
containerID = container->getBaseIdentifier();
1216+
}
1217+
1218+
auto foundMissing = false;
1219+
for (auto member: memberList) {
1220+
// Only storage vars can affect memory layout so
1221+
// look up pattern binding decl or var decl.
1222+
if (auto *PBD = dyn_cast<PatternBindingDecl>(member)) {
1223+
// If this pattern binding decl is empty, we have
1224+
// a missing member.
1225+
if (PBD->getNumPatternEntries() == 0)
1226+
foundMissing = true;
1227+
}
1228+
// Check if a member can be cast to MissingMemberDecl.
1229+
if (auto missingMember = dyn_cast<MissingMemberDecl>(member)) {
1230+
if (!missingMember->getName().getBaseName().isSpecial() &&
1231+
foundMissing) {
1232+
foundMissing = false;
1233+
auto missingMemberID = missingMember->getName().getBaseIdentifier();
1234+
getASTContext().Diags.diagnose(member->getLoc(),
1235+
diag::cannot_bypass_resilience_due_to_missing_member,
1236+
missingMemberID,
1237+
missingMemberID.empty(),
1238+
containerID,
1239+
getDecl()->getModuleContext()->getBaseIdentifier(),
1240+
accessingModule->getBaseIdentifier());
1241+
continue;
1242+
}
1243+
}
1244+
// If not handled above, emit a diag here.
1245+
if (foundMissing) {
1246+
getASTContext().Diags.diagnose(getDecl()->getLoc(),
1247+
diag::cannot_bypass_resilience_due_to_missing_member,
1248+
Identifier(),
1249+
true,
1250+
containerID,
1251+
getDecl()->getModuleContext()->getBaseIdentifier(),
1252+
accessingModule->getBaseIdentifier());
1253+
}
1254+
}
1255+
} else {
1256+
// This decl does not contain a member deserialization error.
1257+
// Check for members of this decl's members recursively to
1258+
// see if a member deserialization failed.
1259+
for (auto member: memberList) {
1260+
// Only storage vars can affect memory layout so
1261+
// look up pattern binding decl or var decl.
1262+
if (auto *PBD = dyn_cast<PatternBindingDecl>(member)) {
1263+
for (auto i : range(PBD->getNumPatternEntries())) {
1264+
auto pattern = PBD->getPattern(i);
1265+
pattern->forEachVariable([&](const VarDecl *VD) {
1266+
// Bail if the var is static or has no storage
1267+
if (VD->isStatic() ||
1268+
!VD->hasStorageOrWrapsStorage())
1269+
return;
1270+
// Unwrap in case this var is optional.
1271+
auto varType = VD->getInterfaceType()->getCanonicalType();
1272+
if (auto unwrapped = varType->getCanonicalType()->getOptionalObjectType()) {
1273+
varType = unwrapped->getCanonicalType();
1274+
}
1275+
// Handle BoundGenericType, e.g. [Foo: Bar], Dictionary<Foo, Bar>.
1276+
// Check for its arguments types, i.e. Foo, Bar.
1277+
if (auto boundGeneric = varType->getAs<BoundGenericType>()) {
1278+
for (auto arg : boundGeneric->getGenericArgs()) {
1279+
if (auto argNominal = arg->getNominalOrBoundGenericNominal()) {
1280+
if (auto argIDC = dyn_cast<IterableDeclContext>(argNominal)) {
1281+
argIDC->checkDeserializeMemberErrorInPackage(getDecl()->getModuleContext());
1282+
if (argIDC->hasDeserializeMemberError()) {
1283+
setHasDeserializeMemberError(true);
1284+
break;
1285+
}
1286+
}
1287+
}
1288+
}
1289+
} else if (auto tupleType = varType->getAs<TupleType>()) {
1290+
// Handle TupleType, e.g. (Foo, Var).
1291+
for (auto element : tupleType->getElements()) {
1292+
if (auto elementNominal = element.getType()->getNominalOrBoundGenericNominal()) {
1293+
if (auto elementIDC = dyn_cast<IterableDeclContext>(elementNominal)) {
1294+
elementIDC->checkDeserializeMemberErrorInPackage(getDecl()->getModuleContext());
1295+
if (elementIDC->hasDeserializeMemberError()) {
1296+
setHasDeserializeMemberError(true);
1297+
break;
1298+
}
1299+
}
1300+
}
1301+
}
1302+
} else if (auto varNominal = varType->getNominalOrBoundGenericNominal()) {
1303+
if (auto varIDC = dyn_cast<IterableDeclContext>(varNominal)) {
1304+
varIDC->checkDeserializeMemberErrorInPackage(getDecl()->getModuleContext());
1305+
if (varIDC->hasDeserializeMemberError()) {
1306+
setHasDeserializeMemberError(true);
1307+
}
1308+
}
1309+
}
1310+
});
1311+
}
1312+
}
1313+
}
1314+
}
1315+
}
1316+
11771317
bool IterableDeclContext::wasDeserialized() const {
11781318
const DeclContext *DC = getAsGenericContext();
11791319
if (auto F = dyn_cast<FileUnit>(DC->getModuleScopeContext())) {

lib/AST/TypeSubstitution.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -997,8 +997,7 @@ ReplaceOpaqueTypesWithUnderlyingTypes::shouldPerformSubstitution(
997997
// resilient expansion if the context's and the opaque type's module are in
998998
// the same package.
999999
if (contextExpansion == ResilienceExpansion::Maximal &&
1000-
module->isResilient() && module->serializePackageEnabled() &&
1001-
module->inSamePackage(contextModule))
1000+
namingDecl->bypassResilienceInPackage(contextModule))
10021001
return OpaqueSubstitutionKind::SubstituteSamePackageMaximalResilience;
10031002

10041003
// Allow general replacement from non resilient modules. Otherwise, disallow.

lib/ClangImporter/ImportDecl.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9890,17 +9890,22 @@ ClangImporter::Implementation::loadAllMembers(Decl *D, uint64_t extra) {
98909890
// Check whether we're importing an Objective-C container of some sort.
98919891
auto objcContainer =
98929892
dyn_cast_or_null<clang::ObjCContainerDecl>(D->getClangDecl());
9893+
auto *IDC = dyn_cast<IterableDeclContext>(D);
98939894

98949895
// If not, we're importing globals-as-members into an extension.
98959896
if (objcContainer) {
98969897
loadAllMembersOfSuperclassIfNeeded(dyn_cast<ClassDecl>(D));
98979898
loadAllMembersOfObjcContainer(D, objcContainer);
9899+
if (IDC) // Set member deserialization status
9900+
IDC->setDeserializedMembers(true);
98989901
return;
98999902
}
99009903

99019904
if (isa_and_nonnull<clang::RecordDecl>(D->getClangDecl())) {
99029905
loadAllMembersOfRecordDecl(cast<NominalTypeDecl>(D),
99039906
cast<clang::RecordDecl>(D->getClangDecl()));
9907+
if (IDC) // Set member deserialization status
9908+
IDC->setDeserializedMembers(true);
99049909
return;
99059910
}
99069911

@@ -9911,6 +9916,8 @@ ClangImporter::Implementation::loadAllMembers(Decl *D, uint64_t extra) {
99119916
}
99129917

99139918
loadAllMembersIntoExtension(D, extra);
9919+
if (IDC) // Set member deserialization status
9920+
IDC->setDeserializedMembers(true);
99149921
}
99159922

99169923
void ClangImporter::Implementation::loadAllMembersIntoExtension(

lib/Frontend/CompilerInvocation.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1354,6 +1354,7 @@ static bool ParseLangArgs(LangOptions &Opts, ArgList &Args,
13541354
}
13551355
}
13561356

1357+
Opts.SkipDeserializationChecksForPackageCMO = Args.hasArg(OPT_ExperimentalSkipDeserializationChecksForPackageCMO);
13571358
Opts.AllowNonResilientAccess =
13581359
Args.hasArg(OPT_experimental_allow_non_resilient_access) ||
13591360
Args.hasArg(OPT_allow_non_resilient_access) ||

lib/SIL/IR/TypeLowering.cpp

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2454,13 +2454,8 @@ namespace {
24542454
// The same should happen if the type was resilient and serialized in
24552455
// another module in the same package with package-cmo enabled, which
24562456
// treats those modules to be in the same resilience domain.
2457-
auto declModule = D->getModuleContext();
2458-
bool sameModule = (declModule == &TC.M);
2459-
bool serializedPackage = declModule != &TC.M &&
2460-
declModule->inSamePackage(&TC.M) &&
2461-
declModule->isResilient() &&
2462-
declModule->serializePackageEnabled();
2463-
auto inSameResilienceDomain = sameModule || serializedPackage;
2457+
auto inSameResilienceDomain = D->getModuleContext() == &TC.M ||
2458+
D->bypassResilienceInPackage(&TC.M);
24642459
if (inSameResilienceDomain)
24652460
properties.addSubobject(RecursiveProperties::forResilient());
24662461

0 commit comments

Comments
 (0)