Skip to content

Commit 0c0018f

Browse files
authored
[AutoDiff] Improve getConstrainedDerivativeGenericSignature helper. (#29620)
Move `getConstrainedDerivativeGenericSignature` under `autodiff` namespace. Improve naming and documentation.
1 parent 7ff9ba1 commit 0c0018f

File tree

3 files changed

+42
-32
lines changed

3 files changed

+42
-32
lines changed

include/swift/AST/AutoDiff.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
namespace swift {
3030

3131
class AnyFunctionType;
32+
class SILFunctionType;
3233
class TupleType;
3334

3435
/// A function type differentiability kind.
@@ -231,6 +232,18 @@ void getSubsetParameterTypes(IndexSubset *indices, AnyFunctionType *type,
231232
SmallVectorImpl<Type> &results,
232233
bool reverseCurryLevels = false);
233234

235+
/// "Constrained" derivative generic signatures require all differentiability
236+
/// parameters to conform to the `Differentiable` protocol.
237+
///
238+
/// Returns the "constrained" derivative generic signature given:
239+
/// - An original SIL function type.
240+
/// - Differentiability parameter indices.
241+
/// - A possibly "unconstrained" derivative generic signature.
242+
GenericSignature
243+
getConstrainedDerivativeGenericSignature(SILFunctionType *originalFnTy,
244+
IndexSubset *diffParamIndices,
245+
GenericSignature derivativeGenSig);
246+
234247
} // end namespace autodiff
235248

236249
} // end namespace swift

lib/AST/AutoDiff.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
//===----------------------------------------------------------------------===//
1212

1313
#include "swift/AST/AutoDiff.h"
14+
#include "swift/AST/ASTContext.h"
15+
#include "swift/AST/TypeCheckRequests.h"
1416
#include "swift/AST/Types.h"
1517

1618
using namespace swift;
@@ -67,6 +69,31 @@ void autodiff::getSubsetParameterTypes(IndexSubset *subset,
6769
}
6870
}
6971

72+
GenericSignature autodiff::getConstrainedDerivativeGenericSignature(
73+
SILFunctionType *originalFnTy, IndexSubset *diffParamIndices,
74+
GenericSignature derivativeGenSig) {
75+
if (!derivativeGenSig)
76+
derivativeGenSig = originalFnTy->getSubstGenericSignature();
77+
if (!derivativeGenSig)
78+
return nullptr;
79+
// Constrain all differentiability parameters to `Differentiable`.
80+
auto &ctx = originalFnTy->getASTContext();
81+
auto *diffableProto = ctx.getProtocol(KnownProtocolKind::Differentiable);
82+
SmallVector<Requirement, 4> requirements;
83+
for (unsigned paramIdx : diffParamIndices->getIndices()) {
84+
auto paramType = originalFnTy->getParameters()[paramIdx].getInterfaceType();
85+
Requirement req(RequirementKind::Conformance, paramType,
86+
diffableProto->getDeclaredType());
87+
requirements.push_back(req);
88+
}
89+
return evaluateOrDefault(
90+
ctx.evaluator,
91+
AbstractGenericSignatureRequest{derivativeGenSig.getPointer(),
92+
/*addedGenericParams*/ {},
93+
std::move(requirements)},
94+
nullptr);
95+
}
96+
7097
Type TangentSpace::getType() const {
7198
switch (kind) {
7299
case Kind::TangentVector:

lib/SIL/SILFunctionType.cpp

Lines changed: 2 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -191,36 +191,6 @@ SILFunctionType::getWitnessMethodClass(SILModule &M) const {
191191
return nullptr;
192192
}
193193

194-
// Returns the canonical generic signature for an autodiff derivative function
195-
// given an existing derivative function generic signature. All
196-
// differentiability parameters are required to conform to `Differentiable`.
197-
static CanGenericSignature getAutoDiffDerivativeFunctionGenericSignature(
198-
CanGenericSignature derivativeFnGenSig,
199-
ArrayRef<SILParameterInfo> originalParameters,
200-
IndexSubset *parameterIndices, ModuleDecl *module) {
201-
if (!derivativeFnGenSig)
202-
return nullptr;
203-
auto &ctx = module->getASTContext();
204-
GenericSignatureBuilder builder(ctx);
205-
// Add derivative function generic signature.
206-
builder.addGenericSignature(derivativeFnGenSig);
207-
// All differentiability parameters are required to conform to
208-
// `Differentiable`.
209-
auto source =
210-
GenericSignatureBuilder::FloatingRequirementSource::forAbstract();
211-
auto *differentiableProtocol =
212-
ctx.getProtocol(KnownProtocolKind::Differentiable);
213-
for (unsigned paramIdx : parameterIndices->getIndices()) {
214-
auto paramType = originalParameters[paramIdx].getInterfaceType();
215-
Requirement req(RequirementKind::Conformance, paramType,
216-
differentiableProtocol->getDeclaredType());
217-
builder.addRequirement(req, source, module);
218-
}
219-
return std::move(builder)
220-
.computeGenericSignature(SourceLoc(), /*allowConcreteGenericParams*/ true)
221-
->getCanonicalSignature();
222-
}
223-
224194
CanSILFunctionType SILFunctionType::getAutoDiffDerivativeFunctionType(
225195
IndexSubset *parameterIndices, unsigned resultIndex,
226196
AutoDiffDerivativeFunctionKind kind, TypeConverter &TC,
@@ -243,8 +213,8 @@ CanSILFunctionType SILFunctionType::getAutoDiffDerivativeFunctionType(
243213
// Get the canonical derivative function generic signature.
244214
if (!derivativeFnGenSig)
245215
derivativeFnGenSig = getSubstGenericSignature();
246-
derivativeFnGenSig = getAutoDiffDerivativeFunctionGenericSignature(
247-
derivativeFnGenSig, getParameters(), parameterIndices, &TC.M);
216+
derivativeFnGenSig = autodiff::getConstrainedDerivativeGenericSignature(
217+
this, parameterIndices, derivativeFnGenSig).getCanonicalSignature();
248218

249219
// Given a type, returns its formal SIL parameter info.
250220
auto getTangentParameterInfoForOriginalResult =

0 commit comments

Comments
 (0)