@@ -403,14 +403,17 @@ SILParameterInfo LargeSILTypeMapper::getNewParameter(GenericEnvironment *env,
403
403
} else if (isLargeLoadableType (env, storageType, IGM)) {
404
404
if (param.getConvention () == ParameterConvention::Direct_Guaranteed)
405
405
return SILParameterInfo (storageType.getASTType (),
406
- ParameterConvention::Indirect_In_Guaranteed);
406
+ ParameterConvention::Indirect_In_Guaranteed,
407
+ param.getDifferentiability ());
407
408
else
408
409
return SILParameterInfo (storageType.getASTType (),
409
- ParameterConvention::Indirect_In_Constant);
410
+ ParameterConvention::Indirect_In_Constant,
411
+ param.getDifferentiability ());
410
412
} else {
411
413
auto newType = getNewSILType (env, storageType, IGM);
412
414
return SILParameterInfo (newType.getASTType (),
413
- param.getConvention ());
415
+ param.getConvention (),
416
+ param.getDifferentiability ());
414
417
}
415
418
}
416
419
@@ -1704,6 +1707,9 @@ class LoadableByAddress : public SILModuleTransform {
1704
1707
bool fixStoreToBlockStorageInstr (SILInstruction &I,
1705
1708
SmallVectorImpl<SILInstruction *> &Delete);
1706
1709
1710
+ bool recreateDifferentiabilityWitnessFunction (
1711
+ SILInstruction &I, SmallVectorImpl<SILInstruction *> &Delete);
1712
+
1707
1713
private:
1708
1714
llvm::SetVector<SILFunction *> modFuncs;
1709
1715
llvm::SetVector<SingleValueInstruction *> conversionInstrs;
@@ -2708,6 +2714,33 @@ bool LoadableByAddress::fixStoreToBlockStorageInstr(
2708
2714
return true ;
2709
2715
}
2710
2716
2717
+ bool LoadableByAddress::recreateDifferentiabilityWitnessFunction (
2718
+ SILInstruction &I, SmallVectorImpl<SILInstruction *> &Delete) {
2719
+ auto *instr = dyn_cast<DifferentiabilityWitnessFunctionInst>(&I);
2720
+ if (!instr)
2721
+ return false ;
2722
+
2723
+ // Check if we need to recreate the instruction.
2724
+ auto *currIRMod = getIRGenModule ()->IRGen .getGenModule (instr->getFunction ());
2725
+ auto resultFnTy = instr->getType ().castTo <SILFunctionType>();
2726
+ auto genSig = resultFnTy->getSubstGenericSignature ();
2727
+ GenericEnvironment *genEnv = nullptr ;
2728
+ if (genSig)
2729
+ genEnv = genSig->getGenericEnvironment ();
2730
+ auto newResultFnTy =
2731
+ MapperCache.getNewSILFunctionType (genEnv, resultFnTy, *currIRMod);
2732
+ if (resultFnTy == newResultFnTy)
2733
+ return true ;
2734
+
2735
+ SILBuilderWithScope builder (instr);
2736
+ auto *newInstr = builder.createDifferentiabilityWitnessFunction (
2737
+ instr->getLoc (), instr->getWitnessKind (), instr->getWitness (),
2738
+ SILType::getPrimitiveObjectType (newResultFnTy));
2739
+ instr->replaceAllUsesWith (newInstr);
2740
+ Delete.push_back (instr);
2741
+ return true ;
2742
+ }
2743
+
2711
2744
bool LoadableByAddress::recreateTupleInstr (
2712
2745
SILInstruction &I, SmallVectorImpl<SILInstruction *> &Delete) {
2713
2746
auto *tupleInstr = dyn_cast<TupleInst>(&I);
@@ -2750,6 +2783,19 @@ bool LoadableByAddress::recreateConvInstr(SILInstruction &I,
2750
2783
auto currSILFunctionType = currSILType.castTo <SILFunctionType>();
2751
2784
GenericEnvironment *genEnv =
2752
2785
getSubstGenericEnvironment (convInstr->getFunction ());
2786
+ // Differentiable function conversion instructions can happen while the
2787
+ // function is still generic. In that case, we must calculate the new type
2788
+ // using the converted function's generic environment rather than the
2789
+ // converting function's generic environment.
2790
+ //
2791
+ // This happens in witness thunks for default implementations of derivative
2792
+ // requirements.
2793
+ if (convInstr->getKind () == SILInstructionKind::DifferentiableFunctionInst ||
2794
+ convInstr->getKind () == SILInstructionKind::DifferentiableFunctionExtractInst ||
2795
+ convInstr->getKind () == SILInstructionKind::LinearFunctionInst ||
2796
+ convInstr->getKind () == SILInstructionKind::LinearFunctionExtractInst)
2797
+ if (auto genSig = currSILFunctionType->getSubstGenericSignature ())
2798
+ genEnv = genSig->getGenericEnvironment ();
2753
2799
CanSILFunctionType newFnType = MapperCache.getNewSILFunctionType (
2754
2800
genEnv, currSILFunctionType, *currIRMod);
2755
2801
SILType newType = SILType::getPrimitiveObjectType (newFnType);
@@ -2790,6 +2836,34 @@ bool LoadableByAddress::recreateConvInstr(SILInstruction &I,
2790
2836
instr->getLoc (), instr->getValue (), instr->getBase ());
2791
2837
break ;
2792
2838
}
2839
+ case SILInstructionKind::DifferentiableFunctionInst: {
2840
+ auto instr = cast<DifferentiableFunctionInst>(convInstr);
2841
+ newInstr = convBuilder.createDifferentiableFunction (
2842
+ instr->getLoc (), instr->getParameterIndices (),
2843
+ instr->getOriginalFunction (),
2844
+ instr->getOptionalDerivativeFunctionPair ());
2845
+ break ;
2846
+ }
2847
+ case SILInstructionKind::DifferentiableFunctionExtractInst: {
2848
+ auto instr = cast<DifferentiableFunctionExtractInst>(convInstr);
2849
+ // Rewrite `differentiable_function_extract` with explicit extractee type.
2850
+ newInstr = convBuilder.createDifferentiableFunctionExtract (
2851
+ instr->getLoc (), instr->getExtractee (), instr->getOperand (), newType);
2852
+ break ;
2853
+ }
2854
+ case SILInstructionKind::LinearFunctionInst: {
2855
+ auto instr = cast<LinearFunctionInst>(convInstr);
2856
+ newInstr = convBuilder.createLinearFunction (
2857
+ instr->getLoc (), instr->getParameterIndices (),
2858
+ instr->getOriginalFunction (), instr->getOptionalTransposeFunction ());
2859
+ break ;
2860
+ }
2861
+ case SILInstructionKind::LinearFunctionExtractInst: {
2862
+ auto instr = cast<LinearFunctionExtractInst>(convInstr);
2863
+ newInstr = convBuilder.createLinearFunctionExtract (
2864
+ instr->getLoc (), instr->getExtractee (), instr->getFunctionOperand ());
2865
+ break ;
2866
+ }
2793
2867
default :
2794
2868
llvm_unreachable (" Unexpected conversion instruction" );
2795
2869
}
@@ -2878,7 +2952,11 @@ void LoadableByAddress::run() {
2878
2952
case SILInstructionKind::ConvertEscapeToNoEscapeInst:
2879
2953
case SILInstructionKind::MarkDependenceInst:
2880
2954
case SILInstructionKind::ThinFunctionToPointerInst:
2881
- case SILInstructionKind::ThinToThickFunctionInst: {
2955
+ case SILInstructionKind::ThinToThickFunctionInst:
2956
+ case SILInstructionKind::DifferentiableFunctionInst:
2957
+ case SILInstructionKind::LinearFunctionInst:
2958
+ case SILInstructionKind::LinearFunctionExtractInst:
2959
+ case SILInstructionKind::DifferentiableFunctionExtractInst: {
2882
2960
conversionInstrs.insert (
2883
2961
cast<SingleValueInstruction>(currInstr));
2884
2962
break ;
@@ -2945,6 +3023,11 @@ void LoadableByAddress::run() {
2945
3023
if (modApplies.count (PAI) == 0 ) {
2946
3024
modApplies.insert (PAI);
2947
3025
}
3026
+ } else if (isa<DifferentiableFunctionInst>(&I) ||
3027
+ isa<LinearFunctionInst>(&I) ||
3028
+ isa<DifferentiableFunctionExtractInst>(&I) ||
3029
+ isa<LinearFunctionExtractInst>(&I)) {
3030
+ conversionInstrs.insert (cast<SingleValueInstruction>(&I));
2948
3031
}
2949
3032
}
2950
3033
}
@@ -2988,6 +3071,8 @@ void LoadableByAddress::run() {
2988
3071
continue ;
2989
3072
else if (recreateApply (I, Delete))
2990
3073
continue ;
3074
+ else if (recreateDifferentiabilityWitnessFunction (I, Delete))
3075
+ continue ;
2991
3076
else
2992
3077
fixStoreToBlockStorageInstr (I, Delete);
2993
3078
}
0 commit comments