Skip to content

Commit efb3a49

Browse files
authored
Merge pull request #40063 from rxwei/rdar84716758
[AutoDiff] Plumb witness derivative generic signatures through SILGen.
2 parents 2b41bee + 9bcba98 commit efb3a49

17 files changed

+200
-144
lines changed

include/swift/AST/Witness.h

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,9 @@ class Witness {
9494
ConcreteDeclRef declRef;
9595
GenericEnvironment *syntheticEnvironment;
9696
SubstitutionMap reqToSyntheticEnvSubs;
97+
/// The derivative generic signature, when the requirement is a derivative
98+
/// function.
99+
GenericSignature derivativeGenSig;
97100
};
98101

99102
llvm::PointerUnion<ValueDecl *, StoredWitness *> storage;
@@ -124,7 +127,8 @@ class Witness {
124127
static Witness forDeserialized(ValueDecl *decl,
125128
SubstitutionMap substitutions) {
126129
// TODO: It's probably a good idea to have a separate 'deserialized' bit.
127-
return Witness(decl, substitutions, nullptr, SubstitutionMap());
130+
return Witness(
131+
decl, substitutions, nullptr, SubstitutionMap(), CanGenericSignature());
128132
}
129133

130134
/// Create a witness that requires substitutions.
@@ -138,10 +142,14 @@ class Witness {
138142
///
139143
/// \param reqToSyntheticEnvSubs The mapping from the interface types of the
140144
/// requirement into the interface types of the synthetic environment.
145+
///
146+
/// \param derivativeGenSig The derivative generic signature, when the
147+
/// requirement is a derivative function.
141148
Witness(ValueDecl *decl,
142149
SubstitutionMap substitutions,
143150
GenericEnvironment *syntheticEnv,
144-
SubstitutionMap reqToSyntheticEnvSubs);
151+
SubstitutionMap reqToSyntheticEnvSubs,
152+
GenericSignature derivativeGenSig);
145153

146154
/// Retrieve the witness declaration reference, which includes the
147155
/// substitutions needed to use the witness from the synthetic environment
@@ -183,6 +191,13 @@ class Witness {
183191
return {};
184192
}
185193

194+
/// Retrieve the derivative generic signature.
195+
GenericSignature getDerivativeGenericSignature() const {
196+
if (auto *storedWitness = storage.dyn_cast<StoredWitness *>())
197+
return storedWitness->derivativeGenSig;
198+
return GenericSignature();
199+
}
200+
186201
SWIFT_DEBUG_DUMP;
187202

188203
void dump(llvm::raw_ostream &out) const;

lib/AST/ProtocolConformance.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,8 @@ using namespace swift;
4242

4343
Witness::Witness(ValueDecl *decl, SubstitutionMap substitutions,
4444
GenericEnvironment *syntheticEnv,
45-
SubstitutionMap reqToSynthesizedEnvSubs) {
45+
SubstitutionMap reqToSynthesizedEnvSubs,
46+
GenericSignature derivativeGenSig) {
4647
if (!syntheticEnv && substitutions.empty() &&
4748
reqToSynthesizedEnvSubs.empty()) {
4849
storage = decl;
@@ -53,7 +54,8 @@ Witness::Witness(ValueDecl *decl, SubstitutionMap substitutions,
5354
auto declRef = ConcreteDeclRef(decl, substitutions);
5455
auto storedMem = ctx.Allocate(sizeof(StoredWitness), alignof(StoredWitness));
5556
auto stored = new (storedMem) StoredWitness{declRef, syntheticEnv,
56-
reqToSynthesizedEnvSubs};
57+
reqToSynthesizedEnvSubs,
58+
derivativeGenSig};
5759

5860
storage = stored;
5961
}
@@ -892,7 +894,8 @@ NormalProtocolConformance::getWitnessUncached(ValueDecl *requirement) const {
892894
}
893895

894896
Witness SelfProtocolConformance::getWitness(ValueDecl *requirement) const {
895-
return Witness(requirement, SubstitutionMap(), nullptr, SubstitutionMap());
897+
return Witness(requirement, SubstitutionMap(), nullptr, SubstitutionMap(),
898+
GenericSignature());
896899
}
897900

898901
ConcreteDeclRef

lib/SIL/IR/SILFunctionType.cpp

Lines changed: 75 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -408,6 +408,32 @@ static CanGenericSignature buildDifferentiableGenericSignature(CanGenericSignatu
408408
return buildGenericSignature(ctx, sig, {}, reqs).getCanonicalSignature();
409409
}
410410

411+
/// Given an original type, computes its tangent type for the purpose of
412+
/// building a linear map using this type. When the original type is an
413+
/// archetype or contains a type parameter, appends a new generic parameter and
414+
/// a corresponding replacement type to the given containers.
415+
static CanType getAutoDiffTangentTypeForLinearMap(
416+
Type originalType,
417+
LookupConformanceFn lookupConformance,
418+
SmallVectorImpl<GenericTypeParamType *> &substGenericParams,
419+
SmallVectorImpl<Type> &substReplacements,
420+
ASTContext &context
421+
) {
422+
auto maybeTanType = originalType->getAutoDiffTangentSpace(lookupConformance);
423+
assert(maybeTanType && "Type does not have a tangent space?");
424+
auto tanType = maybeTanType->getCanonicalType();
425+
// If concrete, the tangent type is concrete.
426+
if (!tanType->hasArchetype() && !tanType->hasTypeParameter())
427+
return tanType;
428+
// Otherwise, the tangent type is a new generic parameter substituted for the
429+
// tangent type.
430+
auto gpIndex = substGenericParams.size();
431+
auto gpType = CanGenericTypeParamType::get(0, gpIndex, context);
432+
substGenericParams.push_back(gpType);
433+
substReplacements.push_back(tanType);
434+
return gpType;
435+
}
436+
411437
/// Returns the differential type for the given original function type,
412438
/// parameter indices, and result index.
413439
static CanSILFunctionType getAutoDiffDifferentialType(
@@ -484,45 +510,32 @@ static CanSILFunctionType getAutoDiffDifferentialType(
484510
getDifferentiabilityParameters(originalFnTy, parameterIndices, diffParams);
485511
SmallVector<SILParameterInfo, 8> differentialParams;
486512
for (auto &param : diffParams) {
487-
auto paramTan =
488-
param.getInterfaceType()->getAutoDiffTangentSpace(lookupConformance);
489-
assert(paramTan && "Parameter type does not have a tangent space?");
490-
auto paramTanType = paramTan->getCanonicalType();
491-
auto paramConv = getTangentParameterConvention(paramTanType,
492-
param.getConvention());
493-
if (!paramTanType->hasArchetype() && !paramTanType->hasTypeParameter()) {
494-
differentialParams.push_back(
495-
{paramTan->getCanonicalType(), paramConv});
496-
} else {
497-
auto gpIndex = substGenericParams.size();
498-
auto gpType = CanGenericTypeParamType::get(0, gpIndex, ctx);
499-
substGenericParams.push_back(gpType);
500-
substReplacements.push_back(paramTanType);
501-
differentialParams.push_back({gpType, paramConv});
502-
}
513+
auto paramTanType = getAutoDiffTangentTypeForLinearMap(
514+
param.getInterfaceType(), lookupConformance,
515+
substGenericParams, substReplacements, ctx);
516+
auto paramConv = getTangentParameterConvention(
517+
// FIXME(rdar://82549134): Use `resultTanType` to compute it instead.
518+
param.getInterfaceType()
519+
->getAutoDiffTangentSpace(lookupConformance)
520+
->getCanonicalType(),
521+
param.getConvention());
522+
differentialParams.push_back({paramTanType, paramConv});
503523
}
504524
SmallVector<SILResultInfo, 1> differentialResults;
505525
for (auto resultIndex : resultIndices->getIndices()) {
506526
// Handle formal original result.
507527
if (resultIndex < originalFnTy->getNumResults()) {
508528
auto &result = originalResults[resultIndex];
509-
auto resultTan =
510-
result.getInterfaceType()->getAutoDiffTangentSpace(lookupConformance);
511-
assert(resultTan && "Result type does not have a tangent space?");
512-
auto resultTanType = resultTan->getCanonicalType();
513-
auto resultConv =
514-
getTangentResultConvention(resultTanType, result.getConvention());
515-
if (!resultTanType->hasArchetype() &&
516-
!resultTanType->hasTypeParameter()) {
517-
differentialResults.push_back(
518-
{resultTan->getCanonicalType(), resultConv});
519-
} else {
520-
auto gpIndex = substGenericParams.size();
521-
auto gpType = CanGenericTypeParamType::get(0, gpIndex, ctx);
522-
substGenericParams.push_back(gpType);
523-
substReplacements.push_back(resultTanType);
524-
differentialResults.push_back({gpType, resultConv});
525-
}
529+
auto resultTanType = getAutoDiffTangentTypeForLinearMap(
530+
result.getInterfaceType(), lookupConformance,
531+
substGenericParams, substReplacements, ctx);
532+
auto resultConv = getTangentResultConvention(
533+
// FIXME(rdar://82549134): Use `resultTanType` to compute it instead.
534+
result.getInterfaceType()
535+
->getAutoDiffTangentSpace(lookupConformance)
536+
->getCanonicalType(),
537+
result.getConvention());
538+
differentialResults.push_back({resultTanType, resultConv});
526539
continue;
527540
}
528541
// Handle original `inout` parameter.
@@ -537,11 +550,11 @@ static CanSILFunctionType getAutoDiffDifferentialType(
537550
if (parameterIndices->contains(paramIndex))
538551
continue;
539552
auto inoutParam = originalFnTy->getParameters()[paramIndex];
540-
auto paramTan = inoutParam.getInterfaceType()->getAutoDiffTangentSpace(
541-
lookupConformance);
542-
assert(paramTan && "Parameter type does not have a tangent space?");
553+
auto inoutParamTanType = getAutoDiffTangentTypeForLinearMap(
554+
inoutParam.getInterfaceType(), lookupConformance,
555+
substGenericParams, substReplacements, ctx);
543556
differentialResults.push_back(
544-
{paramTan->getCanonicalType(), ResultConvention::Indirect});
557+
{inoutParamTanType, ResultConvention::Indirect});
545558
}
546559

547560
SubstitutionMap substitutions;
@@ -648,23 +661,16 @@ static CanSILFunctionType getAutoDiffPullbackType(
648661
// Handle formal original result.
649662
if (resultIndex < originalFnTy->getNumResults()) {
650663
auto &origRes = originalResults[resultIndex];
651-
auto resultTan = origRes.getInterfaceType()->getAutoDiffTangentSpace(
652-
lookupConformance);
653-
assert(resultTan && "Result type does not have a tangent space?");
654-
auto resultTanType = resultTan->getCanonicalType();
655-
auto paramTanConvention = getTangentParameterConventionForOriginalResult(
656-
resultTanType, origRes.getConvention());
657-
if (!resultTanType->hasArchetype() &&
658-
!resultTanType->hasTypeParameter()) {
659-
auto resultTanType = resultTan->getCanonicalType();
660-
pullbackParams.push_back({resultTanType, paramTanConvention});
661-
} else {
662-
auto gpIndex = substGenericParams.size();
663-
auto gpType = CanGenericTypeParamType::get(0, gpIndex, ctx);
664-
substGenericParams.push_back(gpType);
665-
substReplacements.push_back(resultTanType);
666-
pullbackParams.push_back({gpType, paramTanConvention});
667-
}
664+
auto resultTanType = getAutoDiffTangentTypeForLinearMap(
665+
origRes.getInterfaceType(), lookupConformance,
666+
substGenericParams, substReplacements, ctx);
667+
auto paramConv = getTangentParameterConventionForOriginalResult(
668+
// FIXME(rdar://82549134): Use `resultTanType` to compute it instead.
669+
origRes.getInterfaceType()
670+
->getAutoDiffTangentSpace(lookupConformance)
671+
->getCanonicalType(),
672+
origRes.getConvention());
673+
pullbackParams.push_back({resultTanType, paramConv});
668674
continue;
669675
}
670676
// Handle original `inout` parameter.
@@ -674,28 +680,18 @@ static CanSILFunctionType getAutoDiffPullbackType(
674680
auto paramIndex =
675681
std::distance(originalFnTy->getParameters().begin(), &*inoutParamIt);
676682
auto inoutParam = originalFnTy->getParameters()[paramIndex];
677-
auto paramTan = inoutParam.getInterfaceType()->getAutoDiffTangentSpace(
678-
lookupConformance);
679-
assert(paramTan && "Parameter type does not have a tangent space?");
680683
// The pullback parameter convention depends on whether the original `inout`
681684
// paramater is a differentiability parameter.
682685
// - If yes, the pullback parameter convention is `@inout`.
683686
// - If no, the pullback parameter convention is `@in_guaranteed`.
687+
auto inoutParamTanType = getAutoDiffTangentTypeForLinearMap(
688+
inoutParam.getInterfaceType(), lookupConformance,
689+
substGenericParams, substReplacements, ctx);
684690
bool isWrtInoutParameter = parameterIndices->contains(paramIndex);
685691
auto paramTanConvention = isWrtInoutParameter
686-
? inoutParam.getConvention()
687-
: ParameterConvention::Indirect_In_Guaranteed;
688-
auto paramTanType = paramTan->getCanonicalType();
689-
if (!paramTanType->hasArchetype() && !paramTanType->hasTypeParameter()) {
690-
pullbackParams.push_back(
691-
SILParameterInfo(paramTanType, paramTanConvention));
692-
} else {
693-
auto gpIndex = substGenericParams.size();
694-
auto gpType = CanGenericTypeParamType::get(0, gpIndex, ctx);
695-
substGenericParams.push_back(gpType);
696-
substReplacements.push_back(paramTanType);
697-
pullbackParams.push_back({gpType, paramTanConvention});
698-
}
692+
? inoutParam.getConvention()
693+
: ParameterConvention::Indirect_In_Guaranteed;
694+
pullbackParams.push_back({inoutParamTanType, paramTanConvention});
699695
}
700696

701697
// Collect pullback results.
@@ -707,21 +703,16 @@ static CanSILFunctionType getAutoDiffPullbackType(
707703
// and always appear as pullback parameters.
708704
if (param.isIndirectInOut())
709705
continue;
710-
auto paramTan =
711-
param.getInterfaceType()->getAutoDiffTangentSpace(lookupConformance);
712-
assert(paramTan && "Parameter type does not have a tangent space?");
713-
auto paramTanType = paramTan->getCanonicalType();
706+
auto paramTanType = getAutoDiffTangentTypeForLinearMap(
707+
param.getInterfaceType(), lookupConformance,
708+
substGenericParams, substReplacements, ctx);
714709
auto resultTanConvention = getTangentResultConventionForOriginalParameter(
715-
paramTanType, param.getConvention());
716-
if (!paramTanType->hasArchetype() && !paramTanType->hasTypeParameter()) {
717-
pullbackResults.push_back({paramTanType, resultTanConvention});
718-
} else {
719-
auto gpIndex = substGenericParams.size();
720-
auto gpType = CanGenericTypeParamType::get(0, gpIndex, ctx);
721-
substGenericParams.push_back(gpType);
722-
substReplacements.push_back(paramTanType);
723-
pullbackResults.push_back({gpType, resultTanConvention});
724-
}
710+
// FIXME(rdar://82549134): Use `resultTanType` to compute it instead.
711+
param.getInterfaceType()
712+
->getAutoDiffTangentSpace(lookupConformance)
713+
->getCanonicalType(),
714+
param.getConvention());
715+
pullbackResults.push_back({paramTanType, resultTanConvention});
725716
}
726717
SubstitutionMap substitutions;
727718
if (!substGenericParams.empty()) {

lib/SIL/IR/TypeLowering.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2631,7 +2631,8 @@ CanAnyFunctionType TypeConverter::makeConstantInterfaceType(SILDeclRef c) {
26312631
makeConstantInterfaceType(c.asAutoDiffOriginalFunction());
26322632
auto *derivativeFnTy = originalFnTy->getAutoDiffDerivativeFunctionType(
26332633
derivativeId->getParameterIndices(), derivativeId->getKind(),
2634-
LookUpConformanceInModule(&M));
2634+
LookUpConformanceInModule(&M),
2635+
derivativeId->getDerivativeGenericSignature());
26352636
return cast<AnyFunctionType>(derivativeFnTy->getCanonicalType());
26362637
}
26372638

lib/SILGen/SILGenType.cpp

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -416,7 +416,7 @@ template<typename T> class SILGenWitnessTable : public SILWitnessVisitor<T> {
416416
if (!reqAccessor) {
417417
if (auto witness = asDerived().getWitness(reqDecl)) {
418418
return addMethodImplementation(
419-
requirementRef, requirementRef.withDecl(witness.getDecl()),
419+
requirementRef, getWitnessRef(requirementRef, witness),
420420
witness);
421421
}
422422

@@ -444,7 +444,8 @@ template<typename T> class SILGenWitnessTable : public SILWitnessVisitor<T> {
444444
witnessStorage->getSynthesizedAccessor(reqAccessor->getAccessorKind());
445445

446446
return addMethodImplementation(
447-
requirementRef, requirementRef.withDecl(witnessAccessor), witness);
447+
requirementRef, getWitnessRef(requirementRef, witnessAccessor),
448+
witness);
448449
}
449450

450451
private:
@@ -458,6 +459,21 @@ template<typename T> class SILGenWitnessTable : public SILWitnessVisitor<T> {
458459
asDerived().addMethodImplementation(requirementRef, witnessRef,
459460
isFree, witness);
460461
}
462+
463+
SILDeclRef getWitnessRef(SILDeclRef requirementRef, Witness witness) {
464+
auto witnessRef = requirementRef.withDecl(witness.getDecl());
465+
// If the requirement/witness is a derivative function, we need to
466+
// substitute the witness's derivative generic signature in its derivative
467+
// function identifier.
468+
if (requirementRef.isAutoDiffDerivativeFunction()) {
469+
auto *reqrRerivativeId = requirementRef.getDerivativeFunctionIdentifier();
470+
auto *witnessDerivativeId = AutoDiffDerivativeFunctionIdentifier::get(
471+
reqrRerivativeId->getKind(), reqrRerivativeId->getParameterIndices(),
472+
witness.getDerivativeGenericSignature(), witnessRef.getASTContext());
473+
witnessRef = witnessRef.asAutoDiffDerivativeFunction(witnessDerivativeId);
474+
}
475+
return witnessRef;
476+
}
461477
};
462478

463479
static IsSerialized_t isConformanceSerialized(RootProtocolConformance *conf) {

0 commit comments

Comments
 (0)