Skip to content

Commit f6709cd

Browse files
committed
LoadableByAddress: Updating global variables' types
Because of pointer authentication the type of global variables needs to be updated. rdar://93688980
1 parent b8c61e1 commit f6709cd

File tree

2 files changed

+133
-2
lines changed

2 files changed

+133
-2
lines changed

include/swift/SIL/SILGlobalVariable.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,10 @@ class SILGlobalVariable
125125
return getLoweredTypeInContext(context).castTo<SILFunctionType>();
126126
}
127127

128+
void unsafeSetLoweredType(SILType newType) { LoweredType = newType; }
129+
void unsafeAppend(SILInstruction *i) { StaticInitializerBlock.push_back(i); }
130+
void unsafeRemove(SILInstruction *i, SILModule &mod) { StaticInitializerBlock.erase(i, mod); }
131+
128132
StringRef getName() const { return Name; }
129133

130134
void setDeclaration(bool isD) { IsDeclaration = isD; }

lib/IRGen/LoadableByAddress.cpp

Lines changed: 129 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1711,6 +1711,8 @@ class LoadableByAddress : public SILModuleTransform {
17111711
bool recreateDifferentiabilityWitnessFunction(
17121712
SILInstruction &I, SmallVectorImpl<SILInstruction *> &Delete);
17131713

1714+
bool shouldTransformGlobal(SILGlobalVariable *global);
1715+
17141716
private:
17151717
llvm::SetVector<SILFunction *> modFuncs;
17161718
llvm::SetVector<SingleValueInstruction *> conversionInstrs;
@@ -2907,6 +2909,24 @@ void LoadableByAddress::updateLoweredTypes(SILFunction *F) {
29072909
F->rewriteLoweredTypeUnsafe(newFuncTy);
29082910
}
29092911

2912+
bool LoadableByAddress::shouldTransformGlobal(SILGlobalVariable *global) {
2913+
SILInstruction *init = global->getStaticInitializerValue();
2914+
if (!init)
2915+
return false;
2916+
auto silTy = global->getLoweredType();
2917+
if (!isa<SILFunctionType>(silTy.getASTType()))
2918+
return false;
2919+
2920+
auto *decl = global->getDecl();
2921+
IRGenModule *currIRMod = getIRGenModule()->IRGen.getGenModule(
2922+
decl ? decl->getDeclContext() : nullptr);
2923+
auto silFnTy = global->getLoweredFunctionType();
2924+
GenericEnvironment *genEnv = getSubstGenericEnvironment(silFnTy);
2925+
if (MapperCache.shouldTransformFunctionType(genEnv, silFnTy, *currIRMod))
2926+
return true;
2927+
return false;
2928+
}
2929+
29102930
/// The entry point to this function transformation.
29112931
void LoadableByAddress::run() {
29122932
// Set the SIL state before the PassManager has a chance to run
@@ -2922,10 +2942,23 @@ void LoadableByAddress::run() {
29222942

29232943
// Scan the module for all references of the modified functions:
29242944
llvm::SetVector<FunctionRefBaseInst *> funcRefs;
2945+
llvm::SetVector<SILInstruction *> globalRefs;
29252946
for (SILFunction &CurrF : *getModule()) {
29262947
for (SILBasicBlock &BB : CurrF) {
29272948
for (SILInstruction &I : BB) {
2928-
if (auto *FRI = dyn_cast<FunctionRefBaseInst>(&I)) {
2949+
if (auto *allocGlobal = dyn_cast<AllocGlobalInst>(&I)) {
2950+
auto *global = allocGlobal->getReferencedGlobal();
2951+
if (shouldTransformGlobal(global))
2952+
globalRefs.insert(allocGlobal);
2953+
} else if (auto *globalAddr = dyn_cast<GlobalAddrInst>(&I)) {
2954+
auto *global = globalAddr->getReferencedGlobal();
2955+
if (shouldTransformGlobal(global))
2956+
globalRefs.insert(globalAddr);
2957+
} else if (auto *globalVal = dyn_cast<GlobalValueInst>(&I)) {
2958+
auto *global = globalVal->getReferencedGlobal();
2959+
if (shouldTransformGlobal(global))
2960+
globalRefs.insert(globalVal);
2961+
} else if (auto *FRI = dyn_cast<FunctionRefBaseInst>(&I)) {
29292962
SILFunction *RefF = FRI->getInitiallyReferencedFunction();
29302963
if (modFuncs.count(RefF) != 0) {
29312964
// Go over the uses and add them to lists to modify
@@ -2954,7 +2987,7 @@ void LoadableByAddress::run() {
29542987
case SILInstructionKind::LinearFunctionExtractInst:
29552988
case SILInstructionKind::DifferentiableFunctionExtractInst: {
29562989
conversionInstrs.insert(
2957-
cast<SingleValueInstruction>(currInstr));
2990+
cast<SingleValueInstruction>(currInstr));
29582991
break;
29592992
}
29602993
case SILInstructionKind::BuiltinInst: {
@@ -3032,6 +3065,100 @@ void LoadableByAddress::run() {
30323065
updateLoweredTypes(F);
30333066
}
30343067

3068+
SmallVector<SILGlobalVariable *, 16> deadGlobals;
3069+
for (SILGlobalVariable &global : getModule()->getSILGlobals()) {
3070+
SILInstruction *init = global.getStaticInitializerValue();
3071+
if (!init)
3072+
continue;
3073+
auto silTy = global.getLoweredType();
3074+
if (!isa<SILFunctionType>(silTy.getASTType()))
3075+
continue;
3076+
auto *decl = global.getDecl();
3077+
IRGenModule *currIRMod = getIRGenModule()->IRGen.getGenModule(
3078+
decl ? decl->getDeclContext() : nullptr);
3079+
auto silFnTy = global.getLoweredFunctionType();
3080+
GenericEnvironment *genEnv = getSubstGenericEnvironment(silFnTy);
3081+
3082+
// Update the global's type.
3083+
if (MapperCache.shouldTransformFunctionType(genEnv, silFnTy, *currIRMod)) {
3084+
auto newSILFnType =
3085+
MapperCache.getNewSILFunctionType(genEnv, silFnTy, *currIRMod);
3086+
global.unsafeSetLoweredType(
3087+
SILType::getPrimitiveObjectType(newSILFnType));
3088+
3089+
// Rewrite the init basic block...
3090+
SmallVector<SILInstruction *, 8> initBlockInsts;
3091+
for (auto it = global.begin(), end = global.end(); it != end; ++it) {
3092+
initBlockInsts.push_back(const_cast<SILInstruction *>(&*it));
3093+
}
3094+
for (auto *oldInst : initBlockInsts) {
3095+
if (auto *f = dyn_cast<FunctionRefInst>(oldInst)) {
3096+
SILBuilder builder(&global);
3097+
auto *newInst = builder.createFunctionRef(
3098+
f->getLoc(), f->getInitiallyReferencedFunction(), f->getKind());
3099+
f->replaceAllUsesWith(newInst);
3100+
global.unsafeRemove(f, *getModule());
3101+
} else if (auto *cvt = dyn_cast<ConvertFunctionInst>(oldInst)) {
3102+
SILType currSILType = cvt->getType();
3103+
auto currSILFunctionType = currSILType.castTo<SILFunctionType>();
3104+
GenericEnvironment *genEnv =
3105+
getSubstGenericEnvironment(currSILFunctionType);
3106+
auto newType = SILType::getPrimitiveObjectType(
3107+
MapperCache.getNewSILFunctionType(genEnv, silFnTy, *currIRMod));
3108+
3109+
SILBuilder builder(&global);
3110+
auto *newInst = builder.createConvertFunction(
3111+
cvt->getLoc(), cvt->getOperand(), newType,
3112+
cvt->withoutActuallyEscaping());
3113+
cvt->replaceAllUsesWith(newInst);
3114+
global.unsafeRemove(cvt, *getModule());
3115+
} else if (auto *thinToThick =
3116+
dyn_cast<ThinToThickFunctionInst>(oldInst)) {
3117+
SILType currSILType = thinToThick->getType();
3118+
auto currSILFunctionType = currSILType.castTo<SILFunctionType>();
3119+
GenericEnvironment *genEnv =
3120+
getSubstGenericEnvironment(currSILFunctionType);
3121+
auto newType = SILType::getPrimitiveObjectType(
3122+
MapperCache.getNewSILFunctionType(genEnv, silFnTy, *currIRMod));
3123+
SILBuilder builder(&global);
3124+
auto *newInstr = builder.createThinToThickFunction(
3125+
thinToThick->getLoc(), thinToThick->getOperand(), newType);
3126+
thinToThick->replaceAllUsesWith(newInstr);
3127+
global.unsafeRemove(thinToThick, *getModule());
3128+
} else {
3129+
auto *sv = cast<SingleValueInstruction>(oldInst);
3130+
auto *newInst = cast<SingleValueInstruction>(oldInst->clone());
3131+
global.unsafeAppend(newInst);
3132+
sv->replaceAllUsesWith(newInst);
3133+
global.unsafeRemove(oldInst, *getModule());
3134+
}
3135+
}
3136+
}
3137+
}
3138+
3139+
// Rewrite global variable users.
3140+
for (auto *inst : globalRefs) {
3141+
if (auto *allocGlobal = dyn_cast<AllocGlobalInst>(inst)) {
3142+
// alloc_global produces no results.
3143+
SILBuilderWithScope builder(inst);
3144+
builder.createAllocGlobal(allocGlobal->getLoc(),
3145+
allocGlobal->getReferencedGlobal());
3146+
allocGlobal->eraseFromParent();
3147+
} else if (auto *globalAddr = dyn_cast<GlobalAddrInst>(inst)) {
3148+
SILBuilderWithScope builder(inst);
3149+
auto *newInst = builder.createGlobalAddr(
3150+
globalAddr->getLoc(), globalAddr->getReferencedGlobal());
3151+
globalAddr->replaceAllUsesWith(newInst);
3152+
globalAddr->eraseFromParent();
3153+
} else if (auto *globalVal = dyn_cast<GlobalValueInst>(inst)) {
3154+
SILBuilderWithScope builder(inst);
3155+
auto *newInst = builder.createGlobalValue(
3156+
globalVal->getLoc(), globalVal->getReferencedGlobal());
3157+
globalVal->replaceAllUsesWith(newInst);
3158+
globalVal->eraseFromParent();
3159+
}
3160+
}
3161+
30353162
// Update all references:
30363163
// Note: We don't need to update the witness tables and vtables
30373164
// They just contain a pointer to the function

0 commit comments

Comments
 (0)