Skip to content

Commit f6d6d5c

Browse files
authored
[AutoDiff] Improve getConstrainedDerivativeGenericSignature helper. (#29621)
Move `getConstrainedDerivativeGenericSignature` under `autodiff` namespace. Improve naming and documentation.
1 parent 6b8983d commit f6d6d5c

File tree

4 files changed

+47
-69
lines changed

4 files changed

+47
-69
lines changed

include/swift/AST/AutoDiff.h

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ namespace swift {
3131
class AnyFunctionType;
3232
class TupleType;
3333
struct SILAutoDiffIndices;
34-
34+
class SILFunctionType;
3535

3636
/// A function type differentiability kind.
3737
enum class DifferentiabilityKind : uint8_t {
@@ -240,6 +240,18 @@ void getSubsetParameterTypes(IndexSubset *indices, AnyFunctionType *type,
240240
SmallVectorImpl<Type> &results,
241241
bool reverseCurryLevels = false);
242242

243+
/// "Constrained" derivative generic signatures require all differentiability
244+
/// parameters to conform to the `Differentiable` protocol.
245+
///
246+
/// Returns the "constrained" derivative generic signature given:
247+
/// - An original SIL function type.
248+
/// - Differentiability parameter indices.
249+
/// - A possibly "unconstrained" derivative generic signature.
250+
GenericSignature
251+
getConstrainedDerivativeGenericSignature(SILFunctionType *originalFnTy,
252+
IndexSubset *diffParamIndices,
253+
GenericSignature derivativeGenSig);
254+
243255
} // end namespace autodiff
244256

245257
} // end namespace swift
@@ -342,7 +354,6 @@ namespace swift {
342354

343355
class ASTContext;
344356
class AnyFunctionType;
345-
class SILFunctionType;
346357
typedef CanTypeWrapper<SILFunctionType> CanSILFunctionType;
347358
enum class SILLinkage : uint8_t;
348359

lib/AST/AutoDiff.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@
1010
//
1111
//===----------------------------------------------------------------------===//
1212

13+
#include "swift/AST/ASTContext.h"
1314
#include "swift/AST/AutoDiff.h"
15+
#include "swift/AST/TypeCheckRequests.h"
1416
#include "swift/AST/Types.h"
1517

1618
using namespace swift;
@@ -67,6 +69,31 @@ void autodiff::getSubsetParameterTypes(IndexSubset *subset,
6769
}
6870
}
6971

72+
GenericSignature autodiff::getConstrainedDerivativeGenericSignature(
73+
SILFunctionType *originalFnTy, IndexSubset *diffParamIndices,
74+
GenericSignature derivativeGenSig) {
75+
if (!derivativeGenSig)
76+
derivativeGenSig = originalFnTy->getSubstGenericSignature();
77+
if (!derivativeGenSig)
78+
return nullptr;
79+
// Constrain all differentiability parameters to `Differentiable`.
80+
auto &ctx = originalFnTy->getASTContext();
81+
auto *diffableProto = ctx.getProtocol(KnownProtocolKind::Differentiable);
82+
SmallVector<Requirement, 4> requirements;
83+
for (unsigned paramIdx : diffParamIndices->getIndices()) {
84+
auto paramType = originalFnTy->getParameters()[paramIdx].getInterfaceType();
85+
Requirement req(RequirementKind::Conformance, paramType,
86+
diffableProto->getDeclaredType());
87+
requirements.push_back(req);
88+
}
89+
return evaluateOrDefault(
90+
ctx.evaluator,
91+
AbstractGenericSignatureRequest{derivativeGenSig.getPointer(),
92+
/*addedGenericParams*/ {},
93+
std::move(requirements)},
94+
nullptr);
95+
}
96+
7097
// SWIFT_ENABLE_TENSORFLOW
7198
// Not-yet-upstreamed `tensorflow` branch additions are below.
7299

lib/SIL/SILFunctionType.cpp

Lines changed: 6 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -226,35 +226,6 @@ CanSILFunctionType SILFunctionType::getWithoutDifferentiability() {
226226
isGenericSignatureImplied(), getASTContext());
227227
}
228228

229-
// Returns the canonical generic signature for an autodiff derivative function
230-
// given an existing derivative function generic signature. All differentiation
231-
// parameters are constrained to conform to `Differentiable`.
232-
static CanGenericSignature getAutoDiffDerivativeFunctionGenericSignature(
233-
CanGenericSignature derivativeFnGenSig,
234-
ArrayRef<SILParameterInfo> originalParameters,
235-
IndexSubset *parameterIndices, ModuleDecl *module) {
236-
if (!derivativeFnGenSig)
237-
return nullptr;
238-
auto &ctx = module->getASTContext();
239-
GenericSignatureBuilder builder(ctx);
240-
241-
// Add derivative function generic signature.
242-
builder.addGenericSignature(derivativeFnGenSig);
243-
// Constrain all wrt parameters to conform to `Differentiable`.
244-
auto source =
245-
GenericSignatureBuilder::FloatingRequirementSource::forAbstract();
246-
auto *diffableProto = ctx.getProtocol(KnownProtocolKind::Differentiable);
247-
for (unsigned paramIdx : parameterIndices->getIndices()) {
248-
auto paramType = originalParameters[paramIdx].getInterfaceType();
249-
Requirement req(RequirementKind::Conformance, paramType,
250-
diffableProto->getDeclaredType());
251-
builder.addRequirement(req, source, module);
252-
}
253-
return std::move(builder)
254-
.computeGenericSignature(SourceLoc(), /*allowConcreteGenericParams*/ true)
255-
->getCanonicalSignature();
256-
}
257-
258229
CanSILFunctionType SILFunctionType::getAutoDiffDerivativeFunctionType(
259230
IndexSubset *parameterIndices, unsigned resultIndex,
260231
AutoDiffDerivativeFunctionKind kind, TypeConverter &TC,
@@ -293,8 +264,9 @@ CanSILFunctionType SILFunctionType::getAutoDiffDerivativeFunctionType(
293264
// Get the canonical derivative function generic signature.
294265
if (!derivativeFnGenSig)
295266
derivativeFnGenSig = getSubstGenericSignature();
296-
derivativeFnGenSig = getAutoDiffDerivativeFunctionGenericSignature(
297-
derivativeFnGenSig, getParameters(), parameterIndices, &TC.M);
267+
derivativeFnGenSig = autodiff::getConstrainedDerivativeGenericSignature(
268+
this, parameterIndices, derivativeFnGenSig)
269+
.getCanonicalSignature();
298270

299271
// Given a type, returns its formal SIL parameter info.
300272
auto getTangentParameterInfoForOriginalResult =
@@ -456,8 +428,9 @@ CanSILFunctionType SILFunctionType::getAutoDiffTransposeFunctionType(
456428
// Get the canonical transpose function generic signature.
457429
if (!genSig)
458430
genSig = getSubstGenericSignature();
459-
genSig = getAutoDiffDerivativeFunctionGenericSignature(
460-
genSig, getParameters(), parameterIndices, &TC.M);
431+
genSig = autodiff::getConstrainedDerivativeGenericSignature(
432+
this, parameterIndices, genSig)
433+
.getCanonicalSignature();
461434

462435
// Given a type, returns its formal SIL parameter info.
463436
auto getParameterInfoForOriginalResult =

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 1 addition & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -84,39 +84,6 @@ template <typename T> static inline void debugDump(T &v) {
8484
<< v << "\n==== END DEBUG DUMP ====\n");
8585
}
8686

87-
/// Returns the "constrained" derivative generic signature given:
88-
/// - An original SIL function type.
89-
/// - A wrt parameter index subset.
90-
/// - A possibly uncanonical derivative generic signature (optional).
91-
/// - Additional derivative requirements (optional).
92-
/// The constrained derivative generic signature constrains all wrt parameters
93-
/// to conform to `Differentiable`.
94-
static GenericSignature
95-
getConstrainedDerivativeGenericSignature(CanSILFunctionType originalFnTy,
96-
IndexSubset *paramIndexSet,
97-
GenericSignature derivativeGenSig) {
98-
if (!derivativeGenSig)
99-
derivativeGenSig = originalFnTy->getSubstGenericSignature();
100-
if (!derivativeGenSig)
101-
return nullptr;
102-
// Constrain all wrt parameters to `Differentiable`.
103-
auto &ctx = derivativeGenSig->getASTContext();
104-
auto *diffableProto = ctx.getProtocol(KnownProtocolKind::Differentiable);
105-
SmallVector<Requirement, 4> requirements;
106-
for (unsigned paramIdx : paramIndexSet->getIndices()) {
107-
auto paramType = originalFnTy->getParameters()[paramIdx].getInterfaceType();
108-
Requirement req(RequirementKind::Conformance, paramType,
109-
diffableProto->getDeclaredType());
110-
requirements.push_back(req);
111-
}
112-
return evaluateOrDefault(
113-
ctx.evaluator,
114-
AbstractGenericSignatureRequest{derivativeGenSig.getPointer(),
115-
/*addedGenericParams*/ {},
116-
std::move(requirements)},
117-
nullptr);
118-
}
119-
12087
namespace {
12188

12289
class DifferentiationTransformer {
@@ -597,7 +564,7 @@ emitDerivativeFunctionReference(
597564
invoker.getIndirectDifferentiation()
598565
.second->getDerivativeGenericSignature();
599566
auto derivativeConstrainedGenSig =
600-
getConstrainedDerivativeGenericSignature(
567+
autodiff::getConstrainedDerivativeGenericSignature(
601568
originalFn->getLoweredFunctionType(), desiredParameterIndices,
602569
contextualDerivativeGenSig);
603570
minimalWitness = SILDifferentiabilityWitness::createDefinition(

0 commit comments

Comments
 (0)