@@ -1711,6 +1711,8 @@ class LoadableByAddress : public SILModuleTransform {
1711
1711
bool recreateDifferentiabilityWitnessFunction (
1712
1712
SILInstruction &I, SmallVectorImpl<SILInstruction *> &Delete);
1713
1713
1714
+ bool shouldTransformGlobal (SILGlobalVariable *global);
1715
+
1714
1716
private:
1715
1717
llvm::SetVector<SILFunction *> modFuncs;
1716
1718
llvm::SetVector<SingleValueInstruction *> conversionInstrs;
@@ -2907,6 +2909,24 @@ void LoadableByAddress::updateLoweredTypes(SILFunction *F) {
2907
2909
F->rewriteLoweredTypeUnsafe (newFuncTy);
2908
2910
}
2909
2911
2912
+ bool LoadableByAddress::shouldTransformGlobal (SILGlobalVariable *global) {
2913
+ SILInstruction *init = global->getStaticInitializerValue ();
2914
+ if (!init)
2915
+ return false ;
2916
+ auto silTy = global->getLoweredType ();
2917
+ if (!isa<SILFunctionType>(silTy.getASTType ()))
2918
+ return false ;
2919
+
2920
+ auto *decl = global->getDecl ();
2921
+ IRGenModule *currIRMod = getIRGenModule ()->IRGen .getGenModule (
2922
+ decl ? decl->getDeclContext () : nullptr );
2923
+ auto silFnTy = global->getLoweredFunctionType ();
2924
+ GenericEnvironment *genEnv = getSubstGenericEnvironment (silFnTy);
2925
+ if (MapperCache.shouldTransformFunctionType (genEnv, silFnTy, *currIRMod))
2926
+ return true ;
2927
+ return false ;
2928
+ }
2929
+
2910
2930
// / The entry point to this function transformation.
2911
2931
void LoadableByAddress::run () {
2912
2932
// Set the SIL state before the PassManager has a chance to run
@@ -2922,10 +2942,23 @@ void LoadableByAddress::run() {
2922
2942
2923
2943
// Scan the module for all references of the modified functions:
2924
2944
llvm::SetVector<FunctionRefBaseInst *> funcRefs;
2945
+ llvm::SetVector<SILInstruction *> globalRefs;
2925
2946
for (SILFunction &CurrF : *getModule ()) {
2926
2947
for (SILBasicBlock &BB : CurrF) {
2927
2948
for (SILInstruction &I : BB) {
2928
- if (auto *FRI = dyn_cast<FunctionRefBaseInst>(&I)) {
2949
+ if (auto *allocGlobal = dyn_cast<AllocGlobalInst>(&I)) {
2950
+ auto *global = allocGlobal->getReferencedGlobal ();
2951
+ if (shouldTransformGlobal (global))
2952
+ globalRefs.insert (allocGlobal);
2953
+ } else if (auto *globalAddr = dyn_cast<GlobalAddrInst>(&I)) {
2954
+ auto *global = globalAddr->getReferencedGlobal ();
2955
+ if (shouldTransformGlobal (global))
2956
+ globalRefs.insert (globalAddr);
2957
+ } else if (auto *globalVal = dyn_cast<GlobalValueInst>(&I)) {
2958
+ auto *global = globalVal->getReferencedGlobal ();
2959
+ if (shouldTransformGlobal (global))
2960
+ globalRefs.insert (globalVal);
2961
+ } else if (auto *FRI = dyn_cast<FunctionRefBaseInst>(&I)) {
2929
2962
SILFunction *RefF = FRI->getInitiallyReferencedFunction ();
2930
2963
if (modFuncs.count (RefF) != 0 ) {
2931
2964
// Go over the uses and add them to lists to modify
@@ -2954,7 +2987,7 @@ void LoadableByAddress::run() {
2954
2987
case SILInstructionKind::LinearFunctionExtractInst:
2955
2988
case SILInstructionKind::DifferentiableFunctionExtractInst: {
2956
2989
conversionInstrs.insert (
2957
- cast<SingleValueInstruction>(currInstr));
2990
+ cast<SingleValueInstruction>(currInstr));
2958
2991
break ;
2959
2992
}
2960
2993
case SILInstructionKind::BuiltinInst: {
@@ -3032,6 +3065,99 @@ void LoadableByAddress::run() {
3032
3065
updateLoweredTypes (F);
3033
3066
}
3034
3067
3068
+ auto computeNewResultType = [&](SILType ty, IRGenModule *mod) -> SILType {
3069
+ auto currSILFunctionType = ty.castTo <SILFunctionType>();
3070
+ GenericEnvironment *genEnv =
3071
+ getSubstGenericEnvironment (currSILFunctionType);
3072
+ return SILType::getPrimitiveObjectType (
3073
+ MapperCache.getNewSILFunctionType (genEnv, currSILFunctionType, *mod));
3074
+ };
3075
+
3076
+ // Update globals' initializer.
3077
+ SmallVector<SILGlobalVariable *, 16 > deadGlobals;
3078
+ for (SILGlobalVariable &global : getModule ()->getSILGlobals ()) {
3079
+ SILInstruction *init = global.getStaticInitializerValue ();
3080
+ if (!init)
3081
+ continue ;
3082
+ auto silTy = global.getLoweredType ();
3083
+ if (!isa<SILFunctionType>(silTy.getASTType ()))
3084
+ continue ;
3085
+ auto *decl = global.getDecl ();
3086
+ IRGenModule *currIRMod = getIRGenModule ()->IRGen .getGenModule (
3087
+ decl ? decl->getDeclContext () : nullptr );
3088
+ auto silFnTy = global.getLoweredFunctionType ();
3089
+ GenericEnvironment *genEnv = getSubstGenericEnvironment (silFnTy);
3090
+
3091
+ // Update the global's type.
3092
+ if (MapperCache.shouldTransformFunctionType (genEnv, silFnTy, *currIRMod)) {
3093
+ auto newSILFnType =
3094
+ MapperCache.getNewSILFunctionType (genEnv, silFnTy, *currIRMod);
3095
+ global.unsafeSetLoweredType (
3096
+ SILType::getPrimitiveObjectType (newSILFnType));
3097
+
3098
+ // Rewrite the init basic block...
3099
+ SmallVector<SILInstruction *, 8 > initBlockInsts;
3100
+ for (auto it = global.begin (), end = global.end (); it != end; ++it) {
3101
+ initBlockInsts.push_back (const_cast <SILInstruction *>(&*it));
3102
+ }
3103
+ for (auto *oldInst : initBlockInsts) {
3104
+ if (auto *f = dyn_cast<FunctionRefInst>(oldInst)) {
3105
+ SILBuilder builder (&global);
3106
+ auto *newInst = builder.createFunctionRef (
3107
+ f->getLoc (), f->getInitiallyReferencedFunction (), f->getKind ());
3108
+ f->replaceAllUsesWith (newInst);
3109
+ global.unsafeRemove (f, *getModule ());
3110
+ } else if (auto *cvt = dyn_cast<ConvertFunctionInst>(oldInst)) {
3111
+ auto newType = computeNewResultType (cvt->getType (), currIRMod);
3112
+ SILBuilder builder (&global);
3113
+ auto *newInst = builder.createConvertFunction (
3114
+ cvt->getLoc (), cvt->getOperand (), newType,
3115
+ cvt->withoutActuallyEscaping ());
3116
+ cvt->replaceAllUsesWith (newInst);
3117
+ global.unsafeRemove (cvt, *getModule ());
3118
+ } else if (auto *thinToThick =
3119
+ dyn_cast<ThinToThickFunctionInst>(oldInst)) {
3120
+ auto newType =
3121
+ computeNewResultType (thinToThick->getType (), currIRMod);
3122
+ SILBuilder builder (&global);
3123
+ auto *newInstr = builder.createThinToThickFunction (
3124
+ thinToThick->getLoc (), thinToThick->getOperand (), newType);
3125
+ thinToThick->replaceAllUsesWith (newInstr);
3126
+ global.unsafeRemove (thinToThick, *getModule ());
3127
+ } else {
3128
+ auto *sv = cast<SingleValueInstruction>(oldInst);
3129
+ auto *newInst = cast<SingleValueInstruction>(oldInst->clone ());
3130
+ global.unsafeAppend (newInst);
3131
+ sv->replaceAllUsesWith (newInst);
3132
+ global.unsafeRemove (oldInst, *getModule ());
3133
+ }
3134
+ }
3135
+ }
3136
+ }
3137
+
3138
+ // Rewrite global variable users.
3139
+ for (auto *inst : globalRefs) {
3140
+ if (auto *allocGlobal = dyn_cast<AllocGlobalInst>(inst)) {
3141
+ // alloc_global produces no results.
3142
+ SILBuilderWithScope builder (inst);
3143
+ builder.createAllocGlobal (allocGlobal->getLoc (),
3144
+ allocGlobal->getReferencedGlobal ());
3145
+ allocGlobal->eraseFromParent ();
3146
+ } else if (auto *globalAddr = dyn_cast<GlobalAddrInst>(inst)) {
3147
+ SILBuilderWithScope builder (inst);
3148
+ auto *newInst = builder.createGlobalAddr (
3149
+ globalAddr->getLoc (), globalAddr->getReferencedGlobal ());
3150
+ globalAddr->replaceAllUsesWith (newInst);
3151
+ globalAddr->eraseFromParent ();
3152
+ } else if (auto *globalVal = dyn_cast<GlobalValueInst>(inst)) {
3153
+ SILBuilderWithScope builder (inst);
3154
+ auto *newInst = builder.createGlobalValue (
3155
+ globalVal->getLoc (), globalVal->getReferencedGlobal ());
3156
+ globalVal->replaceAllUsesWith (newInst);
3157
+ globalVal->eraseFromParent ();
3158
+ }
3159
+ }
3160
+
3035
3161
// Update all references:
3036
3162
// Note: We don't need to update the witness tables and vtables
3037
3163
// They just contain a pointer to the function
0 commit comments