@@ -217,19 +217,6 @@ getAssociatedFunctionGenericSignature(SILDifferentiableAttr *attr,
217
217
GenericSignatureBuilder::FloatingRequirementSource::forAbstract ();
218
218
for (auto &req : attr->getRequirements ())
219
219
builder.addRequirement (req, source, original->getModule ().getSwiftModule ());
220
- // Constrain all wrt parameters to conform to `Differentiable`.
221
- auto &ctx = original->getASTContext ();
222
- auto *diffableProto = ctx.getProtocol (KnownProtocolKind::Differentiable);
223
- auto paramIndexSet = attr->getIndices ().parameters ;
224
- for (unsigned paramIdx : paramIndexSet->getIndices ()) {
225
- if (!paramIndexSet->contains (paramIdx))
226
- continue ;
227
- auto paramType =
228
- original->getConventions ().getSILArgumentType (paramIdx).getASTType ();
229
- Requirement req (RequirementKind::Conformance, paramType,
230
- diffableProto->getDeclaredType ());
231
- builder.addRequirement (req, source, original->getModule ().getSwiftModule ());
232
- }
233
220
return std::move (builder)
234
221
.computeGenericSignature (SourceLoc (), /* allowConcreteGenericParams=*/ true )
235
222
->getCanonicalSignature ();
@@ -2863,17 +2850,11 @@ class VJPEmitter final
2863
2850
auto origTy = original->getLoweredFunctionType ();
2864
2851
auto lookupConformance = LookUpConformanceInModule (module .getSwiftModule ());
2865
2852
2866
- auto pbGenericSig = getAssociatedFunctionGenericSignature (attr, original);
2867
-
2868
2853
// RAII that pushes the original function's generic signature to
2869
2854
// `module.Types` so that the calls `module.Types.getTypeLowering()` below
2870
- // will know the pullback 's generic parameter types.
2855
+ // will know the original function 's generic parameter types.
2871
2856
Lowering::GenericContextScope genericContextScope (
2872
- module .Types , pbGenericSig);
2873
-
2874
- auto *pbGenericEnv = pbGenericSig
2875
- ? pbGenericSig->createGenericEnvironment ()
2876
- : nullptr ;
2857
+ module .Types , origTy->getGenericSignature ());
2877
2858
2878
2859
// Given a type, returns its formal SIL parameter info.
2879
2860
auto getTangentParameterInfoForOriginalResult = [&](
@@ -2965,6 +2946,10 @@ class VJPEmitter final
2965
2946
mangler.mangleAutoDiffLinearMapHelper (
2966
2947
original->getName (), AutoDiffLinearMapKind::Pullback,
2967
2948
indices)).str ();
2949
+ auto pbGenericSig = getAssociatedFunctionGenericSignature (attr, original);
2950
+ auto *pbGenericEnv = pbGenericSig
2951
+ ? pbGenericSig->createGenericEnvironment ()
2952
+ : nullptr ;
2968
2953
auto pbType = SILFunctionType::get (
2969
2954
pbGenericSig, origTy->getExtInfo (), origTy->getCoroutineKind (),
2970
2955
origTy->getCalleeConvention (), pbParams, {}, adjResults, None,
@@ -3286,7 +3271,7 @@ class VJPEmitter final
3286
3271
auto original = getOpValue (ai->getCallee ());
3287
3272
auto functionSource = original;
3288
3273
SILValue vjpValue;
3289
- // If ` functionSource` is a ` @differentiable` function, just extract it.
3274
+ // If functionSource is a @differentiable function, just extract it.
3290
3275
auto originalFnTy = original->getType ().castTo <SILFunctionType>();
3291
3276
if (originalFnTy->isDifferentiable ()) {
3292
3277
auto paramIndices = originalFnTy->getDifferentiationParameterIndices ();
@@ -3536,6 +3521,12 @@ class JVPEmitter final
3536
3521
auto origTy = original->getLoweredFunctionType ();
3537
3522
auto lookupConformance = LookUpConformanceInModule (module .getSwiftModule ());
3538
3523
3524
+ // RAII that pushes the original function's generic signature to
3525
+ // `module.Types` so that the calls `module.Types.getTypeLowering()` below
3526
+ // will know the original function's generic parameter types.
3527
+ Lowering::GenericContextScope genericContextScope (
3528
+ module .Types , origTy->getGenericSignature ());
3529
+
3539
3530
SmallVector<SILParameterInfo, 8 > diffParams;
3540
3531
SmallVector<SILResultInfo, 8 > diffResults;
3541
3532
auto origParams = origTy->getParameters ();
@@ -3562,13 +3553,6 @@ class JVPEmitter final
3562
3553
original->getName (), AutoDiffLinearMapKind::Differential,
3563
3554
indices)).str ();
3564
3555
auto diffGenericSig = getAssociatedFunctionGenericSignature (attr, original);
3565
-
3566
- // RAII that pushes the original function's generic signature to
3567
- // `module.Types` so that the calls `module.Types.getTypeLowering()` below
3568
- // will know the differential's generic parameter types.
3569
- Lowering::GenericContextScope genericContextScope (
3570
- module .Types , diffGenericSig);
3571
-
3572
3556
auto *diffGenericEnv = diffGenericSig
3573
3557
? diffGenericSig->createGenericEnvironment ()
3574
3558
: nullptr ;
@@ -5991,7 +5975,7 @@ static SILFunction *createEmptyJVP(
5991
5975
5992
5976
// RAII that pushes the original function's generic signature to
5993
5977
// `module.Types` so that the calls `module.Types.getTypeLowering()` below
5994
- // will know the JVP 's generic parameter types.
5978
+ // will know the VJP 's generic parameter types.
5995
5979
Lowering::GenericContextScope genericContextScope (
5996
5980
module .Types , jvpGenericSig);
5997
5981
0 commit comments