@@ -1711,6 +1711,8 @@ class LoadableByAddress : public SILModuleTransform {
1711
1711
bool recreateDifferentiabilityWitnessFunction (
1712
1712
SILInstruction &I, SmallVectorImpl<SILInstruction *> &Delete);
1713
1713
1714
+ bool shouldTransformGlobal (SILGlobalVariable *global);
1715
+
1714
1716
private:
1715
1717
llvm::SetVector<SILFunction *> modFuncs;
1716
1718
llvm::SetVector<SingleValueInstruction *> conversionInstrs;
@@ -2907,6 +2909,24 @@ void LoadableByAddress::updateLoweredTypes(SILFunction *F) {
2907
2909
F->rewriteLoweredTypeUnsafe (newFuncTy);
2908
2910
}
2909
2911
2912
+ bool LoadableByAddress::shouldTransformGlobal (SILGlobalVariable *global) {
2913
+ SILInstruction *init = global->getStaticInitializerValue ();
2914
+ if (!init)
2915
+ return false ;
2916
+ auto silTy = global->getLoweredType ();
2917
+ if (!isa<SILFunctionType>(silTy.getASTType ()))
2918
+ return false ;
2919
+
2920
+ auto *decl = global->getDecl ();
2921
+ IRGenModule *currIRMod = getIRGenModule ()->IRGen .getGenModule (
2922
+ decl ? decl->getDeclContext () : nullptr );
2923
+ auto silFnTy = global->getLoweredFunctionType ();
2924
+ GenericEnvironment *genEnv = getSubstGenericEnvironment (silFnTy);
2925
+ if (MapperCache.shouldTransformFunctionType (genEnv, silFnTy, *currIRMod))
2926
+ return true ;
2927
+ return false ;
2928
+ }
2929
+
2910
2930
// / The entry point to this function transformation.
2911
2931
void LoadableByAddress::run () {
2912
2932
// Set the SIL state before the PassManager has a chance to run
@@ -2922,10 +2942,23 @@ void LoadableByAddress::run() {
2922
2942
2923
2943
// Scan the module for all references of the modified functions:
2924
2944
llvm::SetVector<FunctionRefBaseInst *> funcRefs;
2945
+ llvm::SetVector<SILInstruction *> globalRefs;
2925
2946
for (SILFunction &CurrF : *getModule ()) {
2926
2947
for (SILBasicBlock &BB : CurrF) {
2927
2948
for (SILInstruction &I : BB) {
2928
- if (auto *FRI = dyn_cast<FunctionRefBaseInst>(&I)) {
2949
+ if (auto *allocGlobal = dyn_cast<AllocGlobalInst>(&I)) {
2950
+ auto *global = allocGlobal->getReferencedGlobal ();
2951
+ if (shouldTransformGlobal (global))
2952
+ globalRefs.insert (allocGlobal);
2953
+ } else if (auto *globalAddr = dyn_cast<GlobalAddrInst>(&I)) {
2954
+ auto *global = globalAddr->getReferencedGlobal ();
2955
+ if (shouldTransformGlobal (global))
2956
+ globalRefs.insert (globalAddr);
2957
+ } else if (auto *globalVal = dyn_cast<GlobalValueInst>(&I)) {
2958
+ auto *global = globalVal->getReferencedGlobal ();
2959
+ if (shouldTransformGlobal (global))
2960
+ globalRefs.insert (globalVal);
2961
+ } else if (auto *FRI = dyn_cast<FunctionRefBaseInst>(&I)) {
2929
2962
SILFunction *RefF = FRI->getInitiallyReferencedFunction ();
2930
2963
if (modFuncs.count (RefF) != 0 ) {
2931
2964
// Go over the uses and add them to lists to modify
@@ -2954,7 +2987,7 @@ void LoadableByAddress::run() {
2954
2987
case SILInstructionKind::LinearFunctionExtractInst:
2955
2988
case SILInstructionKind::DifferentiableFunctionExtractInst: {
2956
2989
conversionInstrs.insert (
2957
- cast<SingleValueInstruction>(currInstr));
2990
+ cast<SingleValueInstruction>(currInstr));
2958
2991
break ;
2959
2992
}
2960
2993
case SILInstructionKind::BuiltinInst: {
@@ -3032,6 +3065,100 @@ void LoadableByAddress::run() {
3032
3065
updateLoweredTypes (F);
3033
3066
}
3034
3067
3068
+ SmallVector<SILGlobalVariable *, 16 > deadGlobals;
3069
+ for (SILGlobalVariable &global : getModule ()->getSILGlobals ()) {
3070
+ SILInstruction *init = global.getStaticInitializerValue ();
3071
+ if (!init)
3072
+ continue ;
3073
+ auto silTy = global.getLoweredType ();
3074
+ if (!isa<SILFunctionType>(silTy.getASTType ()))
3075
+ continue ;
3076
+ auto *decl = global.getDecl ();
3077
+ IRGenModule *currIRMod = getIRGenModule ()->IRGen .getGenModule (
3078
+ decl ? decl->getDeclContext () : nullptr );
3079
+ auto silFnTy = global.getLoweredFunctionType ();
3080
+ GenericEnvironment *genEnv = getSubstGenericEnvironment (silFnTy);
3081
+
3082
+ // Update the global's type.
3083
+ if (MapperCache.shouldTransformFunctionType (genEnv, silFnTy, *currIRMod)) {
3084
+ auto newSILFnType =
3085
+ MapperCache.getNewSILFunctionType (genEnv, silFnTy, *currIRMod);
3086
+ global.unsafeSetLoweredType (
3087
+ SILType::getPrimitiveObjectType (newSILFnType));
3088
+
3089
+ // Rewrite the init basic block...
3090
+ SmallVector<SILInstruction *, 8 > initBlockInsts;
3091
+ for (auto it = global.begin (), end = global.end (); it != end; ++it) {
3092
+ initBlockInsts.push_back (const_cast <SILInstruction *>(&*it));
3093
+ }
3094
+ for (auto *oldInst : initBlockInsts) {
3095
+ if (auto *f = dyn_cast<FunctionRefInst>(oldInst)) {
3096
+ SILBuilder builder (&global);
3097
+ auto *newInst = builder.createFunctionRef (
3098
+ f->getLoc (), f->getInitiallyReferencedFunction (), f->getKind ());
3099
+ f->replaceAllUsesWith (newInst);
3100
+ global.unsafeRemove (f, *getModule ());
3101
+ } else if (auto *cvt = dyn_cast<ConvertFunctionInst>(oldInst)) {
3102
+ SILType currSILType = cvt->getType ();
3103
+ auto currSILFunctionType = currSILType.castTo <SILFunctionType>();
3104
+ GenericEnvironment *genEnv =
3105
+ getSubstGenericEnvironment (currSILFunctionType);
3106
+ auto newType = SILType::getPrimitiveObjectType (
3107
+ MapperCache.getNewSILFunctionType (genEnv, silFnTy, *currIRMod));
3108
+
3109
+ SILBuilder builder (&global);
3110
+ auto *newInst = builder.createConvertFunction (
3111
+ cvt->getLoc (), cvt->getOperand (), newType,
3112
+ cvt->withoutActuallyEscaping ());
3113
+ cvt->replaceAllUsesWith (newInst);
3114
+ global.unsafeRemove (cvt, *getModule ());
3115
+ } else if (auto *thinToThick =
3116
+ dyn_cast<ThinToThickFunctionInst>(oldInst)) {
3117
+ SILType currSILType = thinToThick->getType ();
3118
+ auto currSILFunctionType = currSILType.castTo <SILFunctionType>();
3119
+ GenericEnvironment *genEnv =
3120
+ getSubstGenericEnvironment (currSILFunctionType);
3121
+ auto newType = SILType::getPrimitiveObjectType (
3122
+ MapperCache.getNewSILFunctionType (genEnv, silFnTy, *currIRMod));
3123
+ SILBuilder builder (&global);
3124
+ auto *newInstr = builder.createThinToThickFunction (
3125
+ thinToThick->getLoc (), thinToThick->getOperand (), newType);
3126
+ thinToThick->replaceAllUsesWith (newInstr);
3127
+ global.unsafeRemove (thinToThick, *getModule ());
3128
+ } else {
3129
+ auto *sv = cast<SingleValueInstruction>(oldInst);
3130
+ auto *newInst = cast<SingleValueInstruction>(oldInst->clone ());
3131
+ global.unsafeAppend (newInst);
3132
+ sv->replaceAllUsesWith (newInst);
3133
+ global.unsafeRemove (oldInst, *getModule ());
3134
+ }
3135
+ }
3136
+ }
3137
+ }
3138
+
3139
+ // Rewrite global variable users.
3140
+ for (auto *inst : globalRefs) {
3141
+ if (auto *allocGlobal = dyn_cast<AllocGlobalInst>(inst)) {
3142
+ // alloc_global produces no results.
3143
+ SILBuilderWithScope builder (inst);
3144
+ builder.createAllocGlobal (allocGlobal->getLoc (),
3145
+ allocGlobal->getReferencedGlobal ());
3146
+ allocGlobal->eraseFromParent ();
3147
+ } else if (auto *globalAddr = dyn_cast<GlobalAddrInst>(inst)) {
3148
+ SILBuilderWithScope builder (inst);
3149
+ auto *newInst = builder.createGlobalAddr (
3150
+ globalAddr->getLoc (), globalAddr->getReferencedGlobal ());
3151
+ globalAddr->replaceAllUsesWith (newInst);
3152
+ globalAddr->eraseFromParent ();
3153
+ } else if (auto *globalVal = dyn_cast<GlobalValueInst>(inst)) {
3154
+ SILBuilderWithScope builder (inst);
3155
+ auto *newInst = builder.createGlobalValue (
3156
+ globalVal->getLoc (), globalVal->getReferencedGlobal ());
3157
+ globalVal->replaceAllUsesWith (newInst);
3158
+ globalVal->eraseFromParent ();
3159
+ }
3160
+ }
3161
+
3035
3162
// Update all references:
3036
3163
// Note: We don't need to update the witness tables and vtables
3037
3164
// They just contain a pointer to the function
0 commit comments