Skip to content

Commit bcbe65e

Browse files
author
marcrasi
authored
[AutoDiff] diff wit silgen: type simplification, getters (#28398)
1 parent 3880378 commit bcbe65e

File tree

3 files changed

+25
-41
lines changed

3 files changed

+25
-41
lines changed

lib/SILGen/SILGen.cpp

Lines changed: 19 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -767,19 +767,25 @@ void SILGenModule::postEmitFunction(SILDeclRef constant,
767767
constant.kind != SILDeclRef::Kind::DefaultArgGenerator &&
768768
!constant.isThunk()) {
769769
auto *AFD = constant.getAbstractFunctionDecl();
770-
// Visit all `@differentiable` attributes.
771-
for (auto *diffAttr : AFD->getAttrs().getAttributes<DifferentiableAttr>()) {
772-
SILFunction *jvp = nullptr;
773-
SILFunction *vjp = nullptr;
774-
if (auto *jvpDecl = diffAttr->getJVPFunction())
775-
jvp = getFunction(SILDeclRef(jvpDecl), NotForDefinition);
776-
if (auto *vjpDecl = diffAttr->getVJPFunction())
777-
vjp = getFunction(SILDeclRef(vjpDecl), NotForDefinition);
778-
auto *resultIndices = IndexSubset::get(getASTContext(), 1, {0});
779-
AutoDiffConfig config(diffAttr->getParameterIndices(), resultIndices,
780-
diffAttr->getDerivativeGenericSignature().getPointer());
781-
emitDifferentiabilityWitness(AFD, F, config, jvp, vjp);
782-
}
770+
auto emitWitnesses = [&](DeclAttributes &Attrs) {
771+
for (auto *diffAttr : Attrs.getAttributes<DifferentiableAttr>()) {
772+
SILFunction *jvp = nullptr;
773+
SILFunction *vjp = nullptr;
774+
if (auto *jvpDecl = diffAttr->getJVPFunction())
775+
jvp = getFunction(SILDeclRef(jvpDecl), NotForDefinition);
776+
if (auto *vjpDecl = diffAttr->getVJPFunction())
777+
vjp = getFunction(SILDeclRef(vjpDecl), NotForDefinition);
778+
auto *resultIndices = IndexSubset::get(getASTContext(), 1, {0});
779+
AutoDiffConfig config(
780+
diffAttr->getParameterIndices(), resultIndices,
781+
diffAttr->getDerivativeGenericSignature().getPointer());
782+
emitDifferentiabilityWitness(AFD, F, config, jvp, vjp);
783+
}
784+
};
785+
if (auto *accessor = dyn_cast<AccessorDecl>(AFD))
786+
if (accessor->isGetter())
787+
emitWitnesses(accessor->getStorage()->getAttrs());
788+
emitWitnesses(AFD->getAttrs());
783789
}
784790
F->verify();
785791
}
@@ -817,21 +823,6 @@ void SILGenModule::emitDifferentiabilityWitness(
817823
CanGenericSignature derivativeCanGenSig;
818824
if (auto derivativeGenSig = config.derivativeGenericSignature)
819825
derivativeCanGenSig = derivativeGenSig->getCanonicalSignature();
820-
// TODO(TF-835): Use simpler derivative generic signature logic below when
821-
// type-checking no longer generates implicit `@differentiable` attributes.
822-
// See TF-835 for replacement code.
823-
if (jvp) {
824-
auto jvpCanGenSig = jvp->getLoweredFunctionType()->getSubstGenericSignature();
825-
if (!derivativeCanGenSig && jvpCanGenSig)
826-
derivativeCanGenSig = jvpCanGenSig;
827-
assert(derivativeCanGenSig == jvpCanGenSig);
828-
}
829-
if (vjp) {
830-
auto vjpCanGenSig = vjp->getLoweredFunctionType()->getSubstGenericSignature();
831-
if (!derivativeCanGenSig && vjpCanGenSig)
832-
derivativeCanGenSig = vjpCanGenSig;
833-
assert(derivativeCanGenSig == vjpCanGenSig);
834-
}
835826
// Create new SIL differentiability witness.
836827
// Witness JVP and VJP are set below.
837828
// TODO(TF-919): Explore creating serialized differentiability witnesses.

lib/TBDGen/TBDGen.cpp

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -195,20 +195,9 @@ void TBDGenVisitor::addDifferentiabilityWitness(
195195
attr->getParameterIndices(),
196196
original->getInterfaceType()->castTo<AnyFunctionType>());
197197

198-
GenericSignature genericSignature = attr->getDerivativeGenericSignature();
199-
if (auto *jvpDecl = attr->getJVPFunction()) {
200-
assert(!genericSignature ||
201-
jvpDecl->getGenericSignature()->isEqual(genericSignature));
202-
genericSignature = jvpDecl->getGenericSignature();
203-
}
204-
if (auto *vjpDecl = attr->getVJPFunction()) {
205-
assert(!genericSignature ||
206-
vjpDecl->getGenericSignature()->isEqual(genericSignature));
207-
genericSignature = vjpDecl->getGenericSignature();
208-
}
209-
210198
std::string originalMangledName = SILDeclRef(original).mangle();
211-
AutoDiffConfig config{loweredParamIndices, resultIndices, genericSignature};
199+
AutoDiffConfig config{loweredParamIndices, resultIndices,
200+
attr->getDerivativeGenericSignature()};
212201
SILDifferentiabilityWitnessKey key(originalMangledName, config);
213202

214203
Mangle::ASTMangler mangle;

test/AutoDiff/sil_differentiability_witness_silgen.swift

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,10 @@ func generic_vjp<T: Differentiable>(_ x: T, _ y: Float) -> (
7474
public struct Foo: Differentiable {
7575
public var x: Float
7676

77+
// CHECK-LABEL: // differentiability witness for Foo.x.getter
78+
// CHECK-NEXT: sil_differentiability_witness [parameters 0] [results 0] @$s36sil_differentiability_witness_silgen3FooV1xSfvg : $@convention(method) (Foo) -> Float {
79+
// CHECK-NEXT: }
80+
7781
@differentiable
7882
public init(_ x: Float) {
7983
self.x = x

0 commit comments

Comments
 (0)