Skip to content

Commit 51e64db

Browse files
authored
enable @differentiable on stored and computed properties (#21556)
1 parent 247bb4b commit 51e64db

File tree

6 files changed

+365
-57
lines changed

6 files changed

+365
-57
lines changed

include/swift/AST/Attr.def

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -386,7 +386,7 @@ SIMPLE_DECL_ATTR(_nonoverride, NonOverride,
386386

387387
// SWIFT_ENABLE_TENSORFLOW
388388
DECL_ATTR(differentiable, Differentiable,
389-
OnFunc | LongAttribute, 80)
389+
OnAccessor | OnFunc | OnVar | LongAttribute, 80)
390390
SIMPLE_DECL_ATTR(compilerEvaluable, CompilerEvaluable,
391391
OnAccessor | OnFunc | OnConstructor | OnSubscript,
392392
/* Not serialized */ 81)

lib/SILGen/SILGen.cpp

Lines changed: 50 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -719,51 +719,62 @@ void SILGenModule::emitAbstractFuncDecl(AbstractFunctionDecl *AFD) {
719719
}
720720

721721
// SWIFT_ENABLE_TENSORFLOW
722-
// [differentiable] attributes only make sense on functions with
723-
// bodies, because [differentiable] attributes declare actual primals
724-
// and adjoints corresponding to the function body.
722+
// [differentiable] attributes only make sense on functions with bodies,
723+
// because [differentiable] attributes declare actual associated functions
724+
// corresponding to the function body.
725725
if (!AFD->hasBody())
726726
return;
727727

728-
// If the declaration has a @differentiable(reverse) attribute, turn it into a
729-
// SIL [differentiable] attribute with lowered associated function names and
730-
// lowered differentiation parameter indices.
731-
//
728+
// Look for a @differentiable attribute on the decl.
732729
// FIXME: Handle multiple @differentiable attributes.
733-
if (auto *diffAttr = cast_or_null<DifferentiableAttr>(
734-
AFD->getAttrs().getAttribute(DeclAttrKind::DAK_Differentiable))) {
735-
auto silOriginalFn = getFunction(SILDeclRef(AFD), ForDefinition);
736-
// Either only adjoint is specified, or both primal and adjoint are
737-
// spcified.
738-
StringRef primName, adjName, jvpName, vjpName;
739-
bool hasPrimitiveAdjoint = false;
740-
if (auto *primFn = diffAttr->getPrimalFunction())
741-
primName = getFunction(SILDeclRef(primFn), ForDefinition)->getName();
742-
if (auto *adjointFn = diffAttr->getAdjointFunction()) {
743-
// If the adjoint is specified but the primal is not, then we treat the
744-
// original as the primal.
745-
if (primName.empty())
746-
primName = silOriginalFn->getName();
747-
adjName = getFunction(SILDeclRef(adjointFn), ForDefinition)->getName();
748-
hasPrimitiveAdjoint = true;
749-
}
750-
else {
751-
assert(primName.empty() &&
752-
"Primal cannot be present if adjoint is not");
730+
DifferentiableAttr *diffAttr = nullptr;
731+
if (AFD->getAttrs().hasAttribute<DifferentiableAttr>())
732+
diffAttr = AFD->getAttrs().getAttribute<DifferentiableAttr>();
733+
// If the AFD is the getter for a storage decl, also look for a
734+
// @differentiable attribute on the storage decl, because @differentiable
735+
// attributes on storage decls modify the getter.
736+
if (auto *accessor = dyn_cast<AccessorDecl>(AFD)) {
737+
if (accessor->isGetter()) {
738+
auto &storageAttrs = accessor->getStorage()->getAttrs();
739+
if (storageAttrs.hasAttribute<DifferentiableAttr>())
740+
diffAttr = storageAttrs.getAttribute<DifferentiableAttr>();
753741
}
754-
if (auto *jvpFn = diffAttr->getJVPFunction())
755-
jvpName = getFunction(SILDeclRef(jvpFn), ForDefinition)->getName();
756-
if (auto *vjpFn = diffAttr->getVJPFunction())
757-
vjpName = getFunction(SILDeclRef(vjpFn), ForDefinition)->getName();
758-
// Get lowered argument indices.
759-
auto paramIndices = diffAttr->getCheckedParameterIndices()->getLowered(
760-
AFD->getInterfaceType()->castTo<AnyFunctionType>());
761-
SILAutoDiffIndices indices(/*source*/ 0, paramIndices);
762-
silOriginalFn->addDifferentiableAttr(
763-
SILDifferentiableAttr::create(
764-
M, indices, primName, adjName,
765-
/*primitive*/ hasPrimitiveAdjoint, jvpName, vjpName));
766742
}
743+
744+
if (!diffAttr)
745+
return;
746+
747+
// The declaration (or its storage decl) has a @differentiable attribute, so
748+
// turn it into a SIL [differentiable] attribute with lowered associated
749+
// function names and lowered differentiation parameter indices.
750+
auto silOriginalFn = getFunction(SILDeclRef(AFD), ForDefinition);
751+
// Either only adjoint is specified, or both primal and adjoint are
752+
// spcified.
753+
StringRef primName, adjName, jvpName, vjpName;
754+
bool hasPrimitiveAdjoint = false;
755+
if (auto *primFn = diffAttr->getPrimalFunction())
756+
primName = getFunction(SILDeclRef(primFn), ForDefinition)->getName();
757+
if (auto *adjointFn = diffAttr->getAdjointFunction()) {
758+
// If the adjoint is specified but the primal is not, then we treat the
759+
// original as the primal.
760+
if (primName.empty())
761+
primName = silOriginalFn->getName();
762+
adjName = getFunction(SILDeclRef(adjointFn), ForDefinition)->getName();
763+
hasPrimitiveAdjoint = true;
764+
} else {
765+
assert(primName.empty() && "Primal cannot be present if adjoint is not");
766+
}
767+
if (auto *jvpFn = diffAttr->getJVPFunction())
768+
jvpName = getFunction(SILDeclRef(jvpFn), ForDefinition)->getName();
769+
if (auto *vjpFn = diffAttr->getVJPFunction())
770+
vjpName = getFunction(SILDeclRef(vjpFn), ForDefinition)->getName();
771+
// Get lowered argument indices.
772+
auto paramIndices = diffAttr->getCheckedParameterIndices()->getLowered(
773+
AFD->getInterfaceType()->castTo<AnyFunctionType>());
774+
SILAutoDiffIndices indices(/*source*/ 0, paramIndices);
775+
silOriginalFn->addDifferentiableAttr(SILDifferentiableAttr::create(
776+
M, indices, primName, adjName,
777+
/*primitive*/ hasPrimitiveAdjoint, jvpName, vjpName));
767778
}
768779

769780
void SILGenModule::emitFunction(FuncDecl *fd) {

lib/Sema/TypeCheckAttr.cpp

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2190,15 +2190,28 @@ static FuncDecl *resolveAutoDiffAssociatedFunction(
21902190
}
21912191

21922192
void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) {
2193-
// '@differentiable' attribute is OnFunc only, rejected by the early checker.
2194-
auto *original = cast<FuncDecl>(D);
2195-
auto isInstanceMethod = original->isInstanceMember();
2196-
auto &ctx = original->getASTContext();
2197-
AnyFunctionType *originalFnTy =
2198-
original->getInterfaceType()->castTo<AnyFunctionType>();
2193+
auto &ctx = TC.Context;
21992194
auto lookupConformance =
22002195
LookUpConformanceInModule(D->getDeclContext()->getParentModule());
22012196

2197+
FuncDecl *original = nullptr;
2198+
if (isa<VarDecl>(D)) {
2199+
// When used on a storage decl, @differentiable refers to its getter.
2200+
original = cast<VarDecl>(D)->getGetter();
2201+
} else if (isa<FuncDecl>(D)) {
2202+
original = cast<FuncDecl>(D);
2203+
}
2204+
if (!original) {
2205+
// Global immutable vars, for example, have no getter, and therefore trigger
2206+
// this.
2207+
diagnoseAndRemoveAttr(attr, diag::invalid_decl_attribute, attr);
2208+
return;
2209+
}
2210+
2211+
TC.resolveDeclSignature(original);
2212+
auto *originalFnTy = original->getInterfaceType()->castTo<AnyFunctionType>();
2213+
auto isInstanceMethod = original->isInstanceMember();
2214+
22022215
// If the original function has no parameters or returns the empty tuple
22032216
// type, there's nothing to differentiate from or with-respect-to.
22042217
auto &originalParams = *original->getParameters();

test/AutoDiff/differentiable_attr_silgen.swift

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,3 +104,87 @@ public func dhasvjp(_ x: Float, _ y: Float) -> (Float, (Float) -> (Float, Float)
104104
}
105105

106106
// CHECK-LABEL: sil @dhasvjp
107+
108+
//===----------------------------------------------------------------------===//
109+
// Stored property
110+
//===----------------------------------------------------------------------===//
111+
112+
struct DiffStoredProp {
113+
@differentiable(wrt: (self), jvp: storedPropJVP, vjp: storedPropVJP)
114+
let storedProp: Float
115+
116+
@_silgen_name("storedPropJVP")
117+
func storedPropJVP() -> (Float, (DiffStoredProp) -> Float) {
118+
fatalError("unimplemented")
119+
}
120+
121+
@_silgen_name("storedPropVJP")
122+
func storedPropVJP() -> (Float, (Float) -> DiffStoredProp) {
123+
fatalError("unimplemented")
124+
}
125+
}
126+
127+
extension DiffStoredProp : VectorNumeric {
128+
static var zero: DiffStoredProp { fatalError("unimplemented") }
129+
static func + (lhs: DiffStoredProp, rhs: DiffStoredProp) -> DiffStoredProp {
130+
fatalError("unimplemented")
131+
}
132+
static func - (lhs: DiffStoredProp, rhs: DiffStoredProp) -> DiffStoredProp {
133+
fatalError("unimplemented")
134+
}
135+
typealias Scalar = Float
136+
static func * (lhs: Float, rhs: DiffStoredProp) -> DiffStoredProp {
137+
fatalError("unimplemented")
138+
}
139+
}
140+
141+
extension DiffStoredProp : Differentiable {
142+
typealias TangentVector = DiffStoredProp
143+
typealias CotangentVector = DiffStoredProp
144+
}
145+
146+
// CHECK-LABEL: DiffStoredProp.storedProp.getter
147+
// CHECK-NEXT: sil {{.*}} [differentiable source 0 wrt 0 jvp @storedPropJVP vjp @storedPropVJP]
148+
149+
//===----------------------------------------------------------------------===//
150+
// Computed property
151+
//===----------------------------------------------------------------------===//
152+
153+
struct DiffComputedProp {
154+
@differentiable(wrt: (self), jvp: computedPropJVP, vjp: computedPropVJP)
155+
var computedProp: Float {
156+
return 0
157+
}
158+
159+
@_silgen_name("computedPropJVP")
160+
func computedPropJVP() -> (Float, (DiffComputedProp) -> Float) {
161+
fatalError("unimplemented")
162+
}
163+
164+
@_silgen_name("computedPropVJP")
165+
func computedPropVJP() -> (Float, (Float) -> DiffComputedProp) {
166+
fatalError("unimplemented")
167+
}
168+
}
169+
170+
extension DiffComputedProp : VectorNumeric {
171+
static var zero: DiffComputedProp { fatalError("unimplemented") }
172+
static func + (lhs: DiffComputedProp, rhs: DiffComputedProp) -> DiffComputedProp {
173+
fatalError("unimplemented")
174+
}
175+
static func - (lhs: DiffComputedProp, rhs: DiffComputedProp) -> DiffComputedProp {
176+
fatalError("unimplemented")
177+
}
178+
typealias Scalar = Float
179+
static func * (lhs: Float, rhs: DiffComputedProp) -> DiffComputedProp {
180+
fatalError("unimplemented")
181+
}
182+
}
183+
184+
extension DiffComputedProp : Differentiable {
185+
typealias TangentVector = DiffComputedProp
186+
typealias CotangentVector = DiffComputedProp
187+
}
188+
189+
// CHECK-LABEL: DiffComputedProp.computedProp.getter
190+
// CHECK-NEXT: sil {{.*}} [differentiable source 0 wrt 0 jvp @computedPropJVP vjp @computedPropVJP]

0 commit comments

Comments
 (0)