Skip to content

Commit c2ac96f

Browse files
authored
[AutoDiff upstream] Add SIL transpose function type calculation. (#29755)
Add `SILFunctionType::getAutoDiffTransposeFunctionType`. It computes the transpose `SILFucntionType` for an original `SILFunctionType`, given: - Linearity parameter indices - Transpose function generic signature (optional) - Other auxiliary parameters Add doc comments explaining typing rules, preconditions, and other details. Add `isTranspose` flag to `autodiff::getConstrainedDerivativeGenericSignature`. Partially resolves TF-1125. Unblocks TF-1141: upstream `differentiability_witness_function` instruction.
1 parent 9092b82 commit c2ac96f

File tree

4 files changed

+146
-10
lines changed

4 files changed

+146
-10
lines changed

include/swift/AST/AutoDiff.h

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -240,14 +240,17 @@ void getSubsetParameterTypes(IndexSubset *indices, AnyFunctionType *type,
240240
/// "Constrained" derivative generic signatures require all differentiability
241241
/// parameters to conform to the `Differentiable` protocol.
242242
///
243-
/// Returns the "constrained" derivative generic signature given:
243+
/// "Constrained" transpose generic signatures additionally require all
244+
/// linearity parameters to satisfy `Self == Self.TangentVector`.
245+
///
246+
/// Returns the "constrained" derivative/transpose generic signature given:
244247
/// - An original SIL function type.
245248
/// - Differentiability parameter indices.
246249
/// - A possibly "unconstrained" derivative generic signature.
247-
GenericSignature
248-
getConstrainedDerivativeGenericSignature(SILFunctionType *originalFnTy,
249-
IndexSubset *diffParamIndices,
250-
GenericSignature derivativeGenSig);
250+
GenericSignature getConstrainedDerivativeGenericSignature(
251+
SILFunctionType *originalFnTy, IndexSubset *diffParamIndices,
252+
GenericSignature derivativeGenSig, LookupConformanceFn lookupConformance,
253+
bool isTranspose = false);
251254

252255
} // end namespace autodiff
253256

include/swift/AST/Types.h

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4542,6 +4542,42 @@ class SILFunctionType final : public TypeBase, public llvm::FoldingSetNode,
45424542
CanGenericSignature derivativeFunctionGenericSignature = nullptr,
45434543
bool isReabstractionThunk = false);
45444544

4545+
/// Returns the type of the transpose function for the given parameter
4546+
/// indices, transpose function generic signature (optional), and other
4547+
/// auxiliary parameters.
4548+
///
4549+
/// Preconditions:
4550+
/// - Linearity parameters corresponding to parameter indices must conform to
4551+
/// `Differentiable` and satisfy `Self == Self.TangentVector`.
4552+
///
4553+
/// Typing rules, given:
4554+
/// - Original function type: $(T0, T1, ...) -> (R0, R1, ...)
4555+
///
4556+
/// Transpose function type:
4557+
/// - Takes non-linearity parameters, followed by original results, as
4558+
/// parameters.
4559+
/// - Returns linearity parameters.
4560+
///
4561+
/// A "constrained transpose generic signature" is computed from
4562+
/// `transposeFunctionGenericSignature`, if specified. Otherwise, it is
4563+
/// computed from the original generic signature. A "constrained transpose
4564+
/// generic signature" requires all linearity parameters to conform to
4565+
/// `Differentiable` and to satisfy `Self == Self.TangentVector`; this is
4566+
/// important for correctness.
4567+
///
4568+
/// This "constrained transpose generic signature" is used for
4569+
/// parameter/result type lowering. It is used as the actual generic signature
4570+
/// of the transpose function type iff the original function type has a
4571+
/// generic signature and not all generic parameters are bound to concrete
4572+
/// types. Otherwise, no transpose generic signature is used.
4573+
///
4574+
/// Other properties of the original function type are copied exactly:
4575+
/// `ExtInfo`, callee convention, witness method conformance, etc.
4576+
CanSILFunctionType getAutoDiffTransposeFunctionType(
4577+
IndexSubset *parameterIndices, Lowering::TypeConverter &TC,
4578+
LookupConformanceFn lookupConformance,
4579+
CanGenericSignature transposeFunctionGenericSignature = nullptr);
4580+
45454581
ExtInfo getExtInfo() const {
45464582
return ExtInfo(Bits.SILFunctionType.ExtInfoBits, getClangFunctionType());
45474583
}

lib/AST/AutoDiff.cpp

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,20 +83,29 @@ void autodiff::getSubsetParameterTypes(IndexSubset *subset,
8383

8484
GenericSignature autodiff::getConstrainedDerivativeGenericSignature(
8585
SILFunctionType *originalFnTy, IndexSubset *diffParamIndices,
86-
GenericSignature derivativeGenSig) {
86+
GenericSignature derivativeGenSig, LookupConformanceFn lookupConformance,
87+
bool isTranspose) {
8788
if (!derivativeGenSig)
8889
derivativeGenSig = originalFnTy->getSubstGenericSignature();
8990
if (!derivativeGenSig)
9091
return nullptr;
91-
// Constrain all differentiability parameters to `Differentiable`.
9292
auto &ctx = originalFnTy->getASTContext();
9393
auto *diffableProto = ctx.getProtocol(KnownProtocolKind::Differentiable);
9494
SmallVector<Requirement, 4> requirements;
9595
for (unsigned paramIdx : diffParamIndices->getIndices()) {
96+
// Require differentiability parameters to conform to `Differentiable`.
9697
auto paramType = originalFnTy->getParameters()[paramIdx].getInterfaceType();
9798
Requirement req(RequirementKind::Conformance, paramType,
9899
diffableProto->getDeclaredType());
99100
requirements.push_back(req);
101+
if (isTranspose) {
102+
// Require linearity parameters to additionally satisfy
103+
// `Self == Self.TangentVector`.
104+
auto tanSpace = paramType->getAutoDiffTangentSpace(lookupConformance);
105+
auto paramTanType = tanSpace->getCanonicalType();
106+
Requirement req(RequirementKind::SameType, paramType, paramTanType);
107+
requirements.push_back(req);
108+
}
100109
}
101110
return evaluateOrDefault(
102111
ctx.evaluator,

lib/SIL/SILFunctionType.cpp

Lines changed: 91 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -260,11 +260,13 @@ CanSILFunctionType SILFunctionType::getAutoDiffDerivativeFunctionType(
260260
if (isDiffParamIndex(valueAndIndex.index()))
261261
diffParams.push_back(valueAndIndex.value());
262262

263-
// Get the canonical derivative function generic signature.
263+
// Get the "constrained" derivative function generic signature.
264264
if (!derivativeFnGenSig)
265265
derivativeFnGenSig = getSubstGenericSignature();
266-
derivativeFnGenSig = autodiff::getConstrainedDerivativeGenericSignature(
267-
this, parameterIndices, derivativeFnGenSig).getCanonicalSignature();
266+
derivativeFnGenSig =
267+
autodiff::getConstrainedDerivativeGenericSignature(
268+
this, parameterIndices, derivativeFnGenSig, lookupConformance)
269+
.getCanonicalSignature();
268270

269271
// Given a type, returns its formal SIL parameter info.
270272
auto getTangentParameterInfoForOriginalResult =
@@ -401,6 +403,92 @@ CanSILFunctionType SILFunctionType::getAutoDiffDerivativeFunctionType(
401403
ctx, getWitnessMethodConformanceOrInvalid());
402404
}
403405

406+
CanSILFunctionType SILFunctionType::getAutoDiffTransposeFunctionType(
407+
IndexSubset *parameterIndices, Lowering::TypeConverter &TC,
408+
LookupConformanceFn lookupConformance,
409+
CanGenericSignature transposeFnGenSig) {
410+
// Get the "constrained" transpose function generic signature.
411+
if (!transposeFnGenSig)
412+
transposeFnGenSig = getSubstGenericSignature();
413+
transposeFnGenSig = autodiff::getConstrainedDerivativeGenericSignature(
414+
this, parameterIndices, transposeFnGenSig,
415+
lookupConformance, /*isLinear*/ true)
416+
.getCanonicalSignature();
417+
418+
// Given a type, returns its formal SIL parameter info.
419+
auto getParameterInfoForOriginalResult =
420+
[&](const SILResultInfo &result) -> SILParameterInfo {
421+
AbstractionPattern pattern(transposeFnGenSig, result.getInterfaceType());
422+
auto &tl = TC.getTypeLowering(pattern, result.getInterfaceType(),
423+
TypeExpansionContext::minimal());
424+
ParameterConvention newConv;
425+
switch (result.getConvention()) {
426+
case ResultConvention::Owned:
427+
case ResultConvention::Autoreleased:
428+
newConv = tl.isTrivial() ? ParameterConvention::Direct_Unowned
429+
: ParameterConvention::Direct_Guaranteed;
430+
break;
431+
case ResultConvention::Unowned:
432+
case ResultConvention::UnownedInnerPointer:
433+
newConv = ParameterConvention::Direct_Unowned;
434+
break;
435+
case ResultConvention::Indirect:
436+
newConv = ParameterConvention::Indirect_In_Guaranteed;
437+
break;
438+
}
439+
return {result.getInterfaceType()->getCanonicalType(transposeFnGenSig),
440+
newConv};
441+
};
442+
443+
// Given a type, returns its formal SIL result info.
444+
auto getResultInfoForOriginalParameter =
445+
[&](const SILParameterInfo &param) -> SILResultInfo {
446+
AbstractionPattern pattern(transposeFnGenSig, param.getInterfaceType());
447+
auto &tl = TC.getTypeLowering(pattern, param.getInterfaceType(),
448+
TypeExpansionContext::minimal());
449+
ResultConvention newConv;
450+
switch (param.getConvention()) {
451+
case ParameterConvention::Direct_Owned:
452+
case ParameterConvention::Direct_Guaranteed:
453+
case ParameterConvention::Direct_Unowned:
454+
newConv =
455+
tl.isTrivial() ? ResultConvention::Unowned : ResultConvention::Owned;
456+
break;
457+
case ParameterConvention::Indirect_In:
458+
case ParameterConvention::Indirect_Inout:
459+
case ParameterConvention::Indirect_In_Constant:
460+
case ParameterConvention::Indirect_In_Guaranteed:
461+
case ParameterConvention::Indirect_InoutAliasable:
462+
newConv = ResultConvention::Indirect;
463+
break;
464+
}
465+
return {param.getInterfaceType()->getCanonicalType(transposeFnGenSig),
466+
newConv};
467+
};
468+
469+
SmallVector<SILParameterInfo, 4> newParameters;
470+
SmallVector<SILResultInfo, 4> newResults;
471+
for (auto param : llvm::enumerate(getParameters())) {
472+
if (parameterIndices->contains(param.index()))
473+
newResults.push_back(getResultInfoForOriginalParameter(param.value()));
474+
else
475+
newParameters.push_back(param.value());
476+
}
477+
for (auto &res : getResults())
478+
newParameters.push_back(getParameterInfoForOriginalResult(res));
479+
// Transpose function type has a generic signature only if the original
480+
// function type does, and if `transposeFnGenSig` does not have all concrete
481+
// generic parameters.
482+
CanGenericSignature canGenSig;
483+
if (getSubstGenericSignature() && transposeFnGenSig &&
484+
!transposeFnGenSig->areAllParamsConcrete())
485+
canGenSig = transposeFnGenSig;
486+
return SILFunctionType::get(
487+
canGenSig, getExtInfo(), getCoroutineKind(), getCalleeConvention(),
488+
newParameters, getYields(), newResults, getOptionalErrorResult(),
489+
getSubstitutions(), isGenericSignatureImplied(), getASTContext());
490+
}
491+
404492
static CanType getKnownType(Optional<CanType> &cacheSlot, ASTContext &C,
405493
StringRef moduleName, StringRef typeName) {
406494
if (!cacheSlot) {

0 commit comments

Comments
 (0)