Skip to content

Commit ead5f4d

Browse files
authored
[AutoDiff] Constrain wrt parameters to conform to Differentiable. (#26426)
Constrain all wrt parameters to conform to `Differentiable` when computing AD associated function generic signatures. This fixes crashes when differentiating generic original functions that do not constrain parameters to be `Differentiable`, e.g. an unconstrained identity function. Gardening included: - Remove unused `isSerialized` flag from `SILGenModule::getOrCreateAutoDiffAssociatedFunctionThunk`. - Rename `whereClauseGenericSignature` in SIL to `associatedFunctionGenericSignature`. - The generic signature does not necessarily come from the `where` clause of a `[differentiable]` attribute. Resolves TF-691 and TF-697.
1 parent 9fceb38 commit ead5f4d

File tree

7 files changed

+81
-42
lines changed

7 files changed

+81
-42
lines changed

include/swift/AST/Types.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4228,11 +4228,11 @@ class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode,
42284228
AutoDiffIndexSubset *parameterIndices, unsigned resultIndex,
42294229
unsigned differentiationOrder, AutoDiffAssociatedFunctionKind kind,
42304230
SILModule &module, LookupConformanceFn lookupConformance,
4231-
CanGenericSignature whereClauseGenericSignature = nullptr);
4231+
CanGenericSignature associatedFunctionGenericSignature = nullptr);
42324232

42334233
/// Returns a bit vector that specifices which parameters you can
42344234
/// differentiate with respect to for this differentiable function type. (e.g.
4235-
/// which parameters are not @nondiff). The function type must be
4235+
/// which parameters are not `@nondiff`). The function type must be
42364236
/// differentiable.
42374237
AutoDiffIndexSubset *getDifferentiationParameterIndices();
42384238

lib/SIL/SILFunctionType.cpp

Lines changed: 56 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
#include "swift/AST/DiagnosticsSIL.h"
2323
#include "swift/AST/ForeignInfo.h"
2424
#include "swift/AST/GenericEnvironment.h"
25+
// SWIFT_ENABLE_TENSORFLOW
26+
#include "swift/AST/GenericSignatureBuilder.h"
2527
#include "swift/AST/Module.h"
2628
#include "swift/AST/ProtocolConformance.h"
2729
#include "swift/SIL/SILModule.h"
@@ -148,22 +150,68 @@ CanSILFunctionType SILFunctionType::getWithoutDifferentiability() {
148150
getOptionalErrorResult(), getASTContext());
149151
}
150152

153+
// Returns the canonical generic signature for an autodiff associated function
154+
// given an existing associated function generic signature. All differentiation
155+
// parameters are constrained to conform to `Differentiable`.
156+
static CanGenericSignature getAutoDiffAssociatedFunctionGenericSignature(
157+
CanGenericSignature assocFnGenSig,
158+
ArrayRef<SILParameterInfo> originalParameters,
159+
AutoDiffIndexSubset *parameterIndices, SILModule &module) {
160+
// If associated function has no
161+
if (!assocFnGenSig)
162+
return nullptr;
163+
auto &ctx = module.getASTContext();
164+
GenericSignatureBuilder builder(ctx);
165+
166+
// Add associated function generic signature.
167+
builder.addGenericSignature(assocFnGenSig);
168+
// Constrain all wrt parameters to conform to `Differentiable`.
169+
auto source =
170+
GenericSignatureBuilder::FloatingRequirementSource::forAbstract();
171+
auto *diffableProto = ctx.getProtocol(KnownProtocolKind::Differentiable);
172+
for (unsigned paramIdx : parameterIndices->getIndices()) {
173+
auto paramType = originalParameters[paramIdx].getType();
174+
Requirement req(RequirementKind::Conformance, paramType,
175+
diffableProto->getDeclaredType());
176+
builder.addRequirement(req, source, module.getSwiftModule());
177+
}
178+
return std::move(builder)
179+
.computeGenericSignature(SourceLoc(), /*allowConcreteGenericParams*/ true)
180+
->getCanonicalSignature();
181+
}
182+
151183
CanSILFunctionType SILFunctionType::getAutoDiffAssociatedFunctionType(
152184
AutoDiffIndexSubset *parameterIndices, unsigned resultIndex,
153185
unsigned differentiationOrder, AutoDiffAssociatedFunctionKind kind,
154186
SILModule &module, LookupConformanceFn lookupConformance,
155-
CanGenericSignature whereClauseGenSig) {
187+
CanGenericSignature assocFnGenSig) {
156188
// JVP: (T...) -> ((R...),
157189
// (T.TangentVector...) -> (R.TangentVector...))
158190
// VJP: (T...) -> ((R...),
159191
// (R.TangentVector...) -> (T.TangentVector...))
160192

161193
auto &ctx = getASTContext();
162194
auto &typeConverter = module.Types;
163-
if (!whereClauseGenSig)
164-
whereClauseGenSig = getGenericSignature();
165-
Lowering::GenericContextScope genericContextScope(
166-
module.Types, whereClauseGenSig);
195+
196+
// Helper function testing if we are differentiating wrt this index.
197+
auto isWrtIndex = [&](unsigned index) -> bool {
198+
return index < parameterIndices->getCapacity() &&
199+
parameterIndices->contains(index);
200+
};
201+
202+
// Calculate differentiation parameter infos.
203+
SmallVector<SILParameterInfo, 4> wrtParams;
204+
for (auto valueAndIndex : enumerate(getParameters()))
205+
if (isWrtIndex(valueAndIndex.index()))
206+
wrtParams.push_back(valueAndIndex.value());
207+
208+
// Get the canonical associated function generic signature.
209+
if (!assocFnGenSig)
210+
assocFnGenSig = getGenericSignature();
211+
assocFnGenSig = getAutoDiffAssociatedFunctionGenericSignature(
212+
assocFnGenSig, getParameters(), parameterIndices, module);
213+
Lowering::GenericContextScope genericContextScope(module.Types,
214+
assocFnGenSig);
167215

168216
// Given a type, returns its formal SIL parameter info.
169217
auto getTangentParameterInfoForOriginalResult = [&](
@@ -214,18 +262,6 @@ CanSILFunctionType SILFunctionType::getAutoDiffAssociatedFunctionType(
214262
return {tanType, conv};
215263
};
216264

217-
// Helper function testing if we are differentiating wrt this index.
218-
auto isWrtIndex = [&](unsigned index) -> bool {
219-
return index < parameterIndices->getCapacity() &&
220-
parameterIndices->contains(index);
221-
};
222-
223-
// Calculate differentiation parameter infos.
224-
SmallVector<SILParameterInfo, 4> wrtParams;
225-
for (auto valueAndIndex : enumerate(getParameters()))
226-
if (isWrtIndex(valueAndIndex.index()))
227-
wrtParams.push_back(valueAndIndex.value());
228-
229265
CanSILFunctionType closureType;
230266
switch (kind) {
231267
case AutoDiffAssociatedFunctionKind::JVP: {
@@ -280,12 +316,12 @@ CanSILFunctionType SILFunctionType::getAutoDiffAssociatedFunctionType(
280316
newResults.reserve(getNumResults() + 1);
281317
for (auto &result : getResults()) {
282318
auto mappedResult = result.getWithType(
283-
result.getType()->getCanonicalType(whereClauseGenSig));
319+
result.getType()->getCanonicalType(assocFnGenSig));
284320
newResults.push_back(mappedResult);
285321
}
286-
newResults.push_back({closureType->getCanonicalType(whereClauseGenSig),
322+
newResults.push_back({closureType->getCanonicalType(assocFnGenSig),
287323
ResultConvention::Owned});
288-
return SILFunctionType::get(whereClauseGenSig, getExtInfo(),
324+
return SILFunctionType::get(assocFnGenSig, getExtInfo(),
289325
getCoroutineKind(), getCalleeConvention(),
290326
getParameters(), getYields(), newResults,
291327
getOptionalErrorResult(), ctx,

lib/SILGen/SILGen.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -816,8 +816,7 @@ void SILGenModule::postEmitFunction(SILDeclRef constant,
816816
auto *jvpFn = getFunction(SILDeclRef(jvpDecl), NotForDefinition);
817817
if (jvpFn->getLoweredFunctionType() != expectedJVPType) {
818818
jvpThunk = getOrCreateAutoDiffAssociatedFunctionThunk(
819-
F, indices, jvpFn, AutoDiffAssociatedFunctionKind::JVP,
820-
jvpFn->isSerialized());
819+
F, indices, jvpFn, AutoDiffAssociatedFunctionKind::JVP);
821820
} else {
822821
auto *id = AutoDiffAssociatedFunctionIdentifier::get(
823822
AutoDiffAssociatedFunctionKind::JVP, /*differentiationOrder*/ 1,
@@ -836,8 +835,7 @@ void SILGenModule::postEmitFunction(SILDeclRef constant,
836835
auto *vjpFn = getFunction(SILDeclRef(vjpDecl), NotForDefinition);
837836
if (vjpFn->getLoweredFunctionType() != expectedVJPType) {
838837
vjpThunk = getOrCreateAutoDiffAssociatedFunctionThunk(
839-
F, indices, vjpFn, AutoDiffAssociatedFunctionKind::VJP,
840-
vjpFn->isSerialized());
838+
F, indices, vjpFn, AutoDiffAssociatedFunctionKind::VJP);
841839
} else {
842840
auto *id = AutoDiffAssociatedFunctionIdentifier::get(
843841
AutoDiffAssociatedFunctionKind::VJP, /*differentiationOrder*/ 1,

lib/SILGen/SILGen.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -187,8 +187,7 @@ class LLVM_LIBRARY_VISIBILITY SILGenModule : public ASTVisitor<SILGenModule> {
187187
/// - The last result in the returned pullback.
188188
SILFunction *getOrCreateAutoDiffAssociatedFunctionThunk(
189189
SILFunction *original, SILAutoDiffIndices &indices,
190-
SILFunction *assocFn, AutoDiffAssociatedFunctionKind assocFnKind,
191-
IsSerialized_t isSerialized);
190+
SILFunction *assocFn, AutoDiffAssociatedFunctionKind assocFnKind);
192191

193192
/// Determine whether the given class has any instance variables that
194193
/// need to be destroyed.

lib/SILGen/SILGenPoly.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3687,8 +3687,7 @@ static void forwardFunctionArgumentsConvertingOwnership(
36873687
SILFunction *
36883688
SILGenModule::getOrCreateAutoDiffAssociatedFunctionThunk(
36893689
SILFunction *original, SILAutoDiffIndices &indices,
3690-
SILFunction *assocFn, AutoDiffAssociatedFunctionKind assocFnKind,
3691-
IsSerialized_t isSerialized) {
3690+
SILFunction *assocFn, AutoDiffAssociatedFunctionKind assocFnKind) {
36923691
auto assocFnType = assocFn->getLoweredFunctionType();
36933692

36943693
Mangle::ASTMangler mangler;

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -199,26 +199,38 @@ static LoadOwnershipQualifier getBufferLOQ(Type type, SILFunction &fn) {
199199
return LoadOwnershipQualifier::Unqualified;
200200
}
201201

202-
// Return the expected generic signature for autodiff associated functions given
203-
// a SILDifferentiableAttr. The expected generic signature is built from the
204-
// original generic signature and the attribute's requirements.
202+
// Returns the generic signature for an autodiff associated function given a
203+
// `SILDifferentiableAttr` and the original function. The associated function's
204+
// generic signature is built from the original function's generic signature and
205+
// the attribute's requirements. All differentiation parameters are constrained
206+
// to conform to `Differentiable`.
205207
static CanGenericSignature
206208
getAssociatedFunctionGenericSignature(SILDifferentiableAttr *attr,
207209
SILFunction *original) {
208-
auto originalGenSig =
209-
original->getLoweredFunctionType()->getGenericSignature();
210+
auto originalFnTy = original->getLoweredFunctionType();
211+
auto originalGenSig = originalFnTy->getGenericSignature();
210212
if (!originalGenSig)
211213
return nullptr;
212-
GenericSignatureBuilder builder(original->getASTContext());
214+
auto &ctx = original->getASTContext();
215+
GenericSignatureBuilder builder(ctx);
213216
// Add original generic signature.
214217
builder.addGenericSignature(originalGenSig);
215218
// Add where clause requirements.
216219
auto source =
217220
GenericSignatureBuilder::FloatingRequirementSource::forAbstract();
218221
for (auto &req : attr->getRequirements())
219222
builder.addRequirement(req, source, original->getModule().getSwiftModule());
223+
// Constrain all wrt parameters to conform to `Differentiable`.
224+
auto *diffableProto = ctx.getProtocol(KnownProtocolKind::Differentiable);
225+
auto paramIndexSet = attr->getIndices().parameters;
226+
for (unsigned paramIdx : paramIndexSet->getIndices()) {
227+
auto paramType = originalFnTy->getParameters()[paramIdx].getType();
228+
Requirement req(RequirementKind::Conformance, paramType,
229+
diffableProto->getDeclaredType());
230+
builder.addRequirement(req, source, original->getModule().getSwiftModule());
231+
}
220232
return std::move(builder)
221-
.computeGenericSignature(SourceLoc(), /*allowConcreteGenericParams=*/true)
233+
.computeGenericSignature(SourceLoc(), /*allowConcreteGenericParams*/ true)
222234
->getCanonicalSignature();
223235
}
224236

test/AutoDiff/autodiff_indirect_diagnostics.swift

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -89,10 +89,6 @@ let _: @differentiable (Float) -> TF_687<Any> = { x in TF_687<Any>(x, dummy: x)
8989
// Add `Differentiable` conformance for generic wrt parameters
9090
//===----------------------------------------------------------------------===//
9191

92-
// FIXME(TF-697): The tests below were fixed by
93-
// https://github.com/apple/swift/pull/26406, which was reverted because it
94-
// introduced TF-697.
95-
/*
9692
func id<T>(_ x: T) -> T { x }
9793
let _: @differentiable (Float) -> Float = { x in id(x) }
9894

@@ -107,4 +103,3 @@ extension TF_691: Differentiable where Scalar: Differentiable {}
107103
func identity<T>(_ x: TF_691<T>) -> TF_691<T> { x }
108104
let _: @differentiable (Float) -> TF_691<Float> = { x in identity(TF_691(x)) }
109105
let _: @differentiable (Float) -> TF_691<Float> = { x in id(TF_691(x)) }
110-
*/

0 commit comments

Comments
 (0)