@@ -241,6 +241,19 @@ getAssociatedFunctionGenericSignature(SILDifferentiableAttr *attr,
241
241
GenericSignatureBuilder::FloatingRequirementSource::forAbstract ();
242
242
for (auto &req : attr->getRequirements ())
243
243
builder.addRequirement (req, source, original->getModule ().getSwiftModule ());
244
+ // Constrain all wrt parameters to conform to `Differentiable`.
245
+ auto &ctx = original->getASTContext ();
246
+ auto *diffableProto = ctx.getProtocol (KnownProtocolKind::Differentiable);
247
+ auto paramIndexSet = attr->getIndices ().parameters ;
248
+ for (unsigned paramIdx : paramIndexSet->getIndices ()) {
249
+ if (!paramIndexSet->contains (paramIdx))
250
+ continue ;
251
+ auto paramType =
252
+ original->getConventions ().getSILArgumentType (paramIdx).getASTType ();
253
+ Requirement req (RequirementKind::Conformance, paramType,
254
+ diffableProto->getDeclaredType ());
255
+ builder.addRequirement (req, source, original->getModule ().getSwiftModule ());
256
+ }
244
257
return std::move (builder)
245
258
.computeGenericSignature (SourceLoc (), /* allowConcreteGenericParams=*/ true )
246
259
->getCanonicalSignature ();
@@ -2874,11 +2887,17 @@ class VJPEmitter final
2874
2887
auto origTy = original->getLoweredFunctionType ();
2875
2888
auto lookupConformance = LookUpConformanceInModule (module .getSwiftModule ());
2876
2889
2890
+ auto pbGenericSig = getAssociatedFunctionGenericSignature (attr, original);
2891
+
2877
2892
// RAII that pushes the original function's generic signature to
2878
2893
// `module.Types` so that the calls `module.Types.getTypeLowering()` below
2879
- // will know the original function 's generic parameter types.
2894
+ // will know the pullback 's generic parameter types.
2880
2895
Lowering::GenericContextScope genericContextScope (
2881
- module .Types , origTy->getGenericSignature ());
2896
+ module .Types , pbGenericSig);
2897
+
2898
+ auto *pbGenericEnv = pbGenericSig
2899
+ ? pbGenericSig->createGenericEnvironment ()
2900
+ : nullptr ;
2882
2901
2883
2902
// Given a type, returns its formal SIL parameter info.
2884
2903
auto getTangentParameterInfoForOriginalResult = [&](
@@ -2970,10 +2989,6 @@ class VJPEmitter final
2970
2989
mangler.mangleAutoDiffLinearMapHelper (
2971
2990
original->getName (), AutoDiffLinearMapKind::Pullback,
2972
2991
indices)).str ();
2973
- auto pbGenericSig = getAssociatedFunctionGenericSignature (attr, original);
2974
- auto *pbGenericEnv = pbGenericSig
2975
- ? pbGenericSig->createGenericEnvironment ()
2976
- : nullptr ;
2977
2992
auto pbType = SILFunctionType::get (
2978
2993
pbGenericSig, origTy->getExtInfo (), origTy->getCoroutineKind (),
2979
2994
origTy->getCalleeConvention (), pbParams, {}, adjResults, None,
@@ -3296,7 +3311,7 @@ class VJPEmitter final
3296
3311
auto original = getOpValue (ai->getCallee ());
3297
3312
auto functionSource = original;
3298
3313
SILValue vjpValue;
3299
- // If functionSource is a @differentiable function, just extract it.
3314
+ // If ` functionSource` is a ` @differentiable` function, just extract it.
3300
3315
auto originalFnTy = original->getType ().castTo <SILFunctionType>();
3301
3316
if (originalFnTy->isDifferentiable ()) {
3302
3317
auto paramIndices = originalFnTy->getDifferentiationParameterIndices ();
@@ -3531,12 +3546,6 @@ class JVPEmitter final
3531
3546
auto origTy = original->getLoweredFunctionType ();
3532
3547
auto lookupConformance = LookUpConformanceInModule (module .getSwiftModule ());
3533
3548
3534
- // RAII that pushes the original function's generic signature to
3535
- // `module.Types` so that the calls `module.Types.getTypeLowering()` below
3536
- // will know the original function's generic parameter types.
3537
- Lowering::GenericContextScope genericContextScope (
3538
- module .Types , origTy->getGenericSignature ());
3539
-
3540
3549
SmallVector<SILParameterInfo, 8 > diffParams;
3541
3550
SmallVector<SILResultInfo, 8 > diffResults;
3542
3551
auto origParams = origTy->getParameters ();
@@ -3563,6 +3572,13 @@ class JVPEmitter final
3563
3572
original->getName (), AutoDiffLinearMapKind::Differential,
3564
3573
indices)).str ();
3565
3574
auto diffGenericSig = getAssociatedFunctionGenericSignature (attr, original);
3575
+
3576
+ // RAII that pushes the original function's generic signature to
3577
+ // `module.Types` so that the calls `module.Types.getTypeLowering()` below
3578
+ // will know the differential's generic parameter types.
3579
+ Lowering::GenericContextScope genericContextScope (
3580
+ module .Types , diffGenericSig);
3581
+
3566
3582
auto *diffGenericEnv = diffGenericSig
3567
3583
? diffGenericSig->createGenericEnvironment ()
3568
3584
: nullptr ;
@@ -5985,7 +6001,7 @@ static SILFunction *createEmptyJVP(
5985
6001
5986
6002
// RAII that pushes the original function's generic signature to
5987
6003
// `module.Types` so that the calls `module.Types.getTypeLowering()` below
5988
- // will know the VJP 's generic parameter types.
6004
+ // will know the JVP 's generic parameter types.
5989
6005
Lowering::GenericContextScope genericContextScope (
5990
6006
module .Types , jvpGenericSig);
5991
6007
0 commit comments