Skip to content

Commit 6f58fd4

Browse files
authored
[AutoDiff] Constrain wrt parameters to conform to Differentiable. (#26406)
Constrain all wrt parameters to conform to `Differentiable` when computing AD associated function generic signatures. This fixes crashes when differentiating generic original functions that do not constrain parameters to be `Differentiable`, e.g. an unconstrained identity function. Resolves TF-691.
1 parent deb70d3 commit 6f58fd4

File tree

2 files changed

+52
-15
lines changed

2 files changed

+52
-15
lines changed

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,19 @@ getAssociatedFunctionGenericSignature(SILDifferentiableAttr *attr,
241241
GenericSignatureBuilder::FloatingRequirementSource::forAbstract();
242242
for (auto &req : attr->getRequirements())
243243
builder.addRequirement(req, source, original->getModule().getSwiftModule());
244+
// Constrain all wrt parameters to conform to `Differentiable`.
245+
auto &ctx = original->getASTContext();
246+
auto *diffableProto = ctx.getProtocol(KnownProtocolKind::Differentiable);
247+
auto paramIndexSet = attr->getIndices().parameters;
248+
for (unsigned paramIdx : paramIndexSet->getIndices()) {
249+
if (!paramIndexSet->contains(paramIdx))
250+
continue;
251+
auto paramType =
252+
original->getConventions().getSILArgumentType(paramIdx).getASTType();
253+
Requirement req(RequirementKind::Conformance, paramType,
254+
diffableProto->getDeclaredType());
255+
builder.addRequirement(req, source, original->getModule().getSwiftModule());
256+
}
244257
return std::move(builder)
245258
.computeGenericSignature(SourceLoc(), /*allowConcreteGenericParams=*/true)
246259
->getCanonicalSignature();
@@ -2874,11 +2887,17 @@ class VJPEmitter final
28742887
auto origTy = original->getLoweredFunctionType();
28752888
auto lookupConformance = LookUpConformanceInModule(module.getSwiftModule());
28762889

2890+
auto pbGenericSig = getAssociatedFunctionGenericSignature(attr, original);
2891+
28772892
// RAII that pushes the original function's generic signature to
28782893
// `module.Types` so that the calls `module.Types.getTypeLowering()` below
2879-
// will know the original function's generic parameter types.
2894+
// will know the pullback's generic parameter types.
28802895
Lowering::GenericContextScope genericContextScope(
2881-
module.Types, origTy->getGenericSignature());
2896+
module.Types, pbGenericSig);
2897+
2898+
auto *pbGenericEnv = pbGenericSig
2899+
? pbGenericSig->createGenericEnvironment()
2900+
: nullptr;
28822901

28832902
// Given a type, returns its formal SIL parameter info.
28842903
auto getTangentParameterInfoForOriginalResult = [&](
@@ -2970,10 +2989,6 @@ class VJPEmitter final
29702989
mangler.mangleAutoDiffLinearMapHelper(
29712990
original->getName(), AutoDiffLinearMapKind::Pullback,
29722991
indices)).str();
2973-
auto pbGenericSig = getAssociatedFunctionGenericSignature(attr, original);
2974-
auto *pbGenericEnv = pbGenericSig
2975-
? pbGenericSig->createGenericEnvironment()
2976-
: nullptr;
29772992
auto pbType = SILFunctionType::get(
29782993
pbGenericSig, origTy->getExtInfo(), origTy->getCoroutineKind(),
29792994
origTy->getCalleeConvention(), pbParams, {}, adjResults, None,
@@ -3296,7 +3311,7 @@ class VJPEmitter final
32963311
auto original = getOpValue(ai->getCallee());
32973312
auto functionSource = original;
32983313
SILValue vjpValue;
3299-
// If functionSource is a @differentiable function, just extract it.
3314+
// If `functionSource` is a `@differentiable` function, just extract it.
33003315
auto originalFnTy = original->getType().castTo<SILFunctionType>();
33013316
if (originalFnTy->isDifferentiable()) {
33023317
auto paramIndices = originalFnTy->getDifferentiationParameterIndices();
@@ -3531,12 +3546,6 @@ class JVPEmitter final
35313546
auto origTy = original->getLoweredFunctionType();
35323547
auto lookupConformance = LookUpConformanceInModule(module.getSwiftModule());
35333548

3534-
// RAII that pushes the original function's generic signature to
3535-
// `module.Types` so that the calls `module.Types.getTypeLowering()` below
3536-
// will know the original function's generic parameter types.
3537-
Lowering::GenericContextScope genericContextScope(
3538-
module.Types, origTy->getGenericSignature());
3539-
35403549
SmallVector<SILParameterInfo, 8> diffParams;
35413550
SmallVector<SILResultInfo, 8> diffResults;
35423551
auto origParams = origTy->getParameters();
@@ -3563,6 +3572,13 @@ class JVPEmitter final
35633572
original->getName(), AutoDiffLinearMapKind::Differential,
35643573
indices)).str();
35653574
auto diffGenericSig = getAssociatedFunctionGenericSignature(attr, original);
3575+
3576+
// RAII that pushes the original function's generic signature to
3577+
// `module.Types` so that the calls `module.Types.getTypeLowering()` below
3578+
// will know the differential's generic parameter types.
3579+
Lowering::GenericContextScope genericContextScope(
3580+
module.Types, diffGenericSig);
3581+
35663582
auto *diffGenericEnv = diffGenericSig
35673583
? diffGenericSig->createGenericEnvironment()
35683584
: nullptr;
@@ -5985,7 +6001,7 @@ static SILFunction *createEmptyJVP(
59856001

59866002
// RAII that pushes the original function's generic signature to
59876003
// `module.Types` so that the calls `module.Types.getTypeLowering()` below
5988-
// will know the VJP's generic parameter types.
6004+
// will know the JVP's generic parameter types.
59896005
Lowering::GenericContextScope genericContextScope(
59906006
module.Types, jvpGenericSig);
59916007

test/AutoDiff/autodiff_indirect_diagnostics.swift

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@ func generic<T: Differentiable & FloatingPoint>(_ x: T) -> T {
1212
}
1313
_ = gradient(at: 1.0, in: generic)
1414

15-
// Test unmet generic requirements.
15+
//===----------------------------------------------------------------------===//
16+
// Unmet generic requirements
17+
//===----------------------------------------------------------------------===//
1618

1719
@differentiable(
1820
vjp: vjpWeirdExtraRequirements
@@ -67,3 +69,22 @@ struct TF8Struct<Scalar> : TF8Proto where Scalar : FloatingPoint & Differentiabl
6769
}
6870

6971
_ = gradient(at: 1.0, in: { x in x.squareRoot() })
72+
73+
//===----------------------------------------------------------------------===//
74+
// Add `Differentiable` conformance for generic wrt parameters
75+
//===----------------------------------------------------------------------===//
76+
77+
func id<T>(_ x: T) -> T { x }
78+
let _: @differentiable (Float) -> Float = { x in id(x) }
79+
80+
struct TF_691<Scalar> {
81+
var x: Scalar
82+
init(_ x: Scalar) {
83+
self.x = x
84+
}
85+
}
86+
extension TF_691: Differentiable where Scalar: Differentiable {}
87+
88+
func identity<T>(_ x: TF_691<T>) -> TF_691<T> { x }
89+
let _: @differentiable (Float) -> TF_691<Float> = { x in identity(TF_691(x)) }
90+
let _: @differentiable (Float) -> TF_691<Float> = { x in id(TF_691(x)) }

0 commit comments

Comments
 (0)