Skip to content

Commit 19ecff5

Browse files
committed
Remove getter differentiation pass, fix tests.
1 parent e9fb45f commit 19ecff5

12 files changed

+58
-267
lines changed

include/swift/AST/DiagnosticsSema.def

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2734,7 +2734,7 @@ ERROR(differentiable_attr_unsupported_req_kind,none,
27342734
ERROR(differentiable_attr_class_unsupported,none,
27352735
"class members cannot be marked with '@differentiable'", ())
27362736
ERROR(differentiable_attr_stored_property_variable_unsupported,none,
2737-
"stored properties/variables cannot be marked with '@differentiable' with a custom VJP/JVP", ())
2737+
"'jvp:' or 'vjp:' cannot be specified for stored properties", ())
27382738
NOTE(protocol_witness_missing_specific_differentiable_attr,none,
27392739
"candidate is missing attribute '%0'", (StringRef))
27402740

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 9 additions & 153 deletions
Original file line numberDiff line numberDiff line change
@@ -737,11 +737,7 @@ enum class StructExtractDifferentiationStrategy {
737737
// that is zero except along the direction of the corresponding field.
738738
//
739739
// Fields correspond by matching name.
740-
Fieldwise,
741-
742-
// Differentiate the `struct_extract` by looking up the corresponding getter
743-
// and using its VJP.
744-
Getter
740+
Fieldwise
745741
};
746742

747743
static inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
@@ -3232,59 +3228,10 @@ class VJPEmitter final
32323228
SILClonerWithScopes::visitStructExtractInst(sei);
32333229
return;
32343230
}
3235-
// This instruction is active. Determine the appropriate differentiation
3236-
// strategy, and use it.
3237-
// Find the corresponding getter.
3238-
auto *getterDecl = sei->getField()->getGetter();
3239-
assert(getterDecl);
3240-
auto *getterFn = getModule().lookUpFunction(
3241-
SILDeclRef(getterDecl, SILDeclRef::Kind::Func));
3242-
auto *structDecl = sei->getStructDecl();
3243-
if (!getterFn ||
3244-
structDecl->getAttrs().hasAttribute<FieldwiseDifferentiableAttr>()) {
3245-
strategies[sei] = StructExtractDifferentiationStrategy::Fieldwise;
3246-
SILClonerWithScopes::visitStructExtractInst(sei);
3247-
return;
3248-
}
3249-
// The FieldwiseProductSpace strategy is not appropriate, so use the Getter
3250-
// strategy.
3251-
assert(getterFn);
3252-
strategies[sei] = StructExtractDifferentiationStrategy::Getter;
3253-
SILAutoDiffIndices indices(/*source*/ 0,
3254-
AutoDiffIndexSubset::getDefault(getASTContext(), 1, true));
3255-
auto *attr = context.lookUpDifferentiableAttr(getterFn, indices);
3256-
if (!attr) {
3257-
context.emitNondifferentiabilityError(
3258-
sei, invoker, diag::autodiff_property_not_differentiable);
3259-
errorOccurred = true;
3260-
return;
3261-
}
3262-
// Reference and apply the VJP.
3263-
auto loc = sei->getLoc();
3264-
auto *getterVJP = getAssociatedFunction(
3265-
context, getterFn, attr, AutoDiffAssociatedFunctionKind::VJP,
3266-
attr->getVJPName());
3267-
assert(getterVJP && "Expected to find getter VJP");
3268-
auto *getterVJPRef = getBuilder().createFunctionRef(loc, getterVJP);
3269-
auto *getterVJPApply = getBuilder().createApply(
3270-
loc, getterVJPRef,
3271-
getOpSubstitutionMap(getterVJP->getForwardingSubstitutionMap()),
3272-
/*args*/ {getOpValue(sei->getOperand())}, /*isNonThrowing*/ false);
3273-
// Extract direct results from `getterVJPApply`.
3274-
SmallVector<SILValue, 8> vjpDirectResults;
3275-
extractAllElements(getterVJPApply, getBuilder(), vjpDirectResults);
3276-
// Map original result.
3277-
auto originalDirectResults =
3278-
ArrayRef<SILValue>(vjpDirectResults).drop_back(1);
3279-
auto originalDirectResult = joinElements(originalDirectResults,
3280-
getBuilder(),
3281-
getterVJPApply->getLoc());
3282-
mapValue(sei, originalDirectResult);
3283-
// Checkpoint the pullback.
3284-
auto pullback = vjpDirectResults.back();
3285-
// TODO: Check whether it's necessary to reabstract getter pullbacks.
3286-
pullbackInfo.addPullbackDecl(sei, getOpType(pullback->getType()));
3287-
pullbackValues[sei->getParent()].push_back(pullback);
3231+
// This instruction is active. Use the field wise differentiation strategy
3232+
// to differentiate the struct extract instruction.
3233+
strategies[sei] = StructExtractDifferentiationStrategy::Fieldwise;
3234+
SILClonerWithScopes::visitStructExtractInst(sei);
32883235
}
32893236

32903237
void visitStructElementAddrInst(StructElementAddrInst *seai) {
@@ -3297,78 +3244,10 @@ class VJPEmitter final
32973244
SILClonerWithScopes::visitStructElementAddrInst(seai);
32983245
return;
32993246
}
3300-
// This instruction is active. Determine the appropriate differentiation
3301-
// strategy, and use it.
3302-
// Find the corresponding getter.
3303-
auto *getterDecl = seai->getField()->getGetter();
3304-
assert(getterDecl);
3305-
auto *getterFn = getModule().lookUpFunction(
3306-
SILDeclRef(getterDecl, SILDeclRef::Kind::Func));
3307-
auto *structDecl = seai->getStructDecl();
3308-
if (!getterFn ||
3309-
structDecl->getAttrs().hasAttribute<FieldwiseDifferentiableAttr>()) {
3310-
strategies[seai] = StructExtractDifferentiationStrategy::Fieldwise;
3311-
SILClonerWithScopes::visitStructElementAddrInst(seai);
3312-
return;
3313-
}
3314-
// The FieldwiseProductSpace strategy is not appropriate, so use the Getter
3315-
// strategy.
3316-
assert(getterFn);
3317-
strategies[seai] = StructExtractDifferentiationStrategy::Getter;
3318-
SILAutoDiffIndices indices(/*source*/ 0,
3319-
AutoDiffIndexSubset::getDefault(getASTContext(), 1, true));
3320-
auto *attr = context.lookUpDifferentiableAttr(getterFn, indices);
3321-
if (!attr) {
3322-
context.emitNondifferentiabilityError(
3323-
seai, invoker, diag::autodiff_property_not_differentiable);
3324-
errorOccurred = true;
3325-
return;
3326-
}
3327-
// Set generic context scope before getting VJP function type.
3328-
auto vjpGenSig = SubsMap.getGenericSignature()
3329-
? SubsMap.getGenericSignature()->getCanonicalSignature()
3330-
: nullptr;
3331-
Lowering::GenericContextScope genericContextScope(
3332-
context.getTypeConverter(), vjpGenSig);
3333-
// Reference the getter VJP.
3334-
auto loc = seai->getLoc();
3335-
auto *getterVJP = getModule().lookUpFunction(attr->getVJPName());
3336-
assert(getterVJP && "Expected to find getter VJP");
3337-
auto vjpFnTy = getterVJP->getLoweredFunctionType();
3338-
auto *getterVJPRef = getBuilder().createFunctionRef(loc, getterVJP);
3339-
// Store getter VJP arguments and indirect result buffers.
3340-
SmallVector<SILValue, 8> vjpArgs;
3341-
SmallVector<AllocStackInst *, 8> vjpIndirectResults;
3342-
for (auto indRes : vjpFnTy->getIndirectFormalResults()) {
3343-
auto *alloc = getBuilder().createAllocStack(
3344-
loc, getOpType(indRes.getSILStorageType()));
3345-
vjpArgs.push_back(alloc);
3346-
vjpIndirectResults.push_back(alloc);
3347-
}
3348-
vjpArgs.push_back(getOpValue(seai->getOperand()));
3349-
// Apply the getter VJP.
3350-
auto *getterVJPApply = getBuilder().createApply(
3351-
loc, getterVJPRef,
3352-
getOpSubstitutionMap(getterVJP->getForwardingSubstitutionMap()),
3353-
vjpArgs, /*isNonThrowing*/ false);
3354-
// Collect all results from `getterVJPApply` in type-defined order.
3355-
SmallVector<SILValue, 8> vjpDirectResults;
3356-
extractAllElements(getterVJPApply, getBuilder(), vjpDirectResults);
3357-
SmallVector<SILValue, 8> allResults;
3358-
collectAllActualResultsInTypeOrder(
3359-
getterVJPApply, vjpDirectResults,
3360-
getterVJPApply->getIndirectSILResults(), allResults);
3361-
// Deallocate VJP indirect results.
3362-
for (auto alloc : vjpIndirectResults)
3363-
getBuilder().createDeallocStack(loc, alloc);
3364-
auto originalDirectResult = allResults[indices.source];
3365-
// Map original result.
3366-
mapValue(seai, originalDirectResult);
3367-
// Checkpoint the pullback.
3368-
SILValue pullback = vjpDirectResults.back();
3369-
// TODO: Check whether it's necessary to reabstract getter pullbacks.
3370-
pullbackInfo.addPullbackDecl(seai, getOpType(pullback->getType()));
3371-
pullbackValues[seai->getParent()].push_back(pullback);
3247+
// This instruction is active. Use the field wise differentiation strategy
3248+
// to differentiate the struct extract instruction.
3249+
strategies[seai] = StructExtractDifferentiationStrategy::Fieldwise;
3250+
SILClonerWithScopes::visitStructElementAddrInst(seai);
33723251
}
33733252

33743253
// If an `apply` has active results or active inout parameters, replace it
@@ -4839,29 +4718,6 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
48394718
}
48404719
return;
48414720
}
4842-
case StructExtractDifferentiationStrategy::Getter: {
4843-
// Get the pullback.
4844-
auto *pullbackField = getPullbackInfo().lookUpPullbackDecl(sei);
4845-
assert(pullbackField);
4846-
auto pullback = builder.createStructExtract(
4847-
loc, getAdjointBlockPullbackStructArgument(sei->getParent()),
4848-
pullbackField);
4849-
4850-
// Construct the pullback arguments.
4851-
auto av = takeAdjointValue(sei);
4852-
auto vector = materializeAdjointDirect(std::move(av), loc);
4853-
4854-
// Call the pullback.
4855-
auto *pullbackCall = builder.createApply(
4856-
loc, pullback, SubstitutionMap(), {vector}, /*isNonThrowing*/ false);
4857-
assert(!pullbackCall->hasIndirectResults());
4858-
4859-
// Accumulate adjoint for the `struct_extract` operand.
4860-
addAdjointValue(sei->getOperand(),
4861-
makeConcreteAdjointValue(
4862-
ValueWithCleanup(pullbackCall, vector.getCleanup())));
4863-
break;
4864-
}
48654721
}
48664722
}
48674723

lib/Sema/TypeCheckAttr.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2891,7 +2891,7 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) {
28912891
(attr->getJVP() || attr->getVJP())) {
28922892
diagnoseAndRemoveAttr(attr,
28932893
diag::differentiable_attr_stored_property_variable_unsupported);
2894-
return;
2894+
return;
28952895
}
28962896
// When used directly on a storage decl (stored/computed property or
28972897
// subscript), the getter is currently inferred to be `@differentiable`.

test/AutoDiff/autodiff_diagnostics.swift

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,6 @@ extension S : Differentiable, VectorNumeric {
3838
typealias TangentVector = S
3939
}
4040

41-
// expected-error @+2 {{function is not differentiable}}
42-
// expected-note @+1 {{property is not differentiable}}
4341
_ = gradient(at: S(p: 0)) { s in 2 * s.p }
4442

4543
struct NoDerivativeProperty : Differentiable {

test/AutoDiff/differentiable_attr_silgen.swift

Lines changed: 0 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -76,43 +76,6 @@ public func dhasvjp(_ x: Float, _ y: Float) -> (Float, (Float) -> (Float, Float)
7676

7777
// CHECK-LABEL: sil [ossa] @dhasvjp
7878

79-
//===----------------------------------------------------------------------===//
80-
// Stored property
81-
//===----------------------------------------------------------------------===//
82-
83-
struct DiffStoredProp {
84-
@differentiable(wrt: (self), jvp: storedPropJVP, vjp: storedPropVJP)
85-
let storedProp: Float
86-
87-
@_silgen_name("storedPropJVP")
88-
func storedPropJVP() -> (Float, (DiffStoredProp) -> Float) {
89-
fatalError("unimplemented")
90-
}
91-
92-
@_silgen_name("storedPropVJP")
93-
func storedPropVJP() -> (Float, (Float) -> DiffStoredProp) {
94-
fatalError("unimplemented")
95-
}
96-
}
97-
98-
extension DiffStoredProp : VectorNumeric {
99-
static var zero: DiffStoredProp { fatalError("unimplemented") }
100-
static func + (lhs: DiffStoredProp, rhs: DiffStoredProp) -> DiffStoredProp {
101-
fatalError("unimplemented")
102-
}
103-
static func - (lhs: DiffStoredProp, rhs: DiffStoredProp) -> DiffStoredProp {
104-
fatalError("unimplemented")
105-
}
106-
typealias Scalar = Float
107-
static func * (lhs: Float, rhs: DiffStoredProp) -> DiffStoredProp {
108-
fatalError("unimplemented")
109-
}
110-
}
111-
112-
extension DiffStoredProp : Differentiable {
113-
typealias TangentVector = DiffStoredProp
114-
}
115-
11679
//===----------------------------------------------------------------------===//
11780
// Computed property
11881
//===----------------------------------------------------------------------===//

test/AutoDiff/differentiable_attr_type_checking.swift

Lines changed: 17 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
// RUN: %target-swift-frontend -typecheck -verify %s
22

33
@differentiable // expected-error {{'@differentiable' attribute cannot be applied to this declaration}}
4-
let global: Float = 1
4+
let globalConst: Float = 1
5+
6+
@differentiable // expected-error {{'@differentiable' attribute cannot be applied to this declaration}}
7+
var globalVar: Float = 1
58

69
func testLocalVariables() {
710
// expected-error @+1 {{'_' has no parameters to differentiate with respect to}}
@@ -225,25 +228,18 @@ class Foo {
225228
}
226229

227230
struct JVPStruct {
231+
@differentiable
228232
let p: Float
229233

230-
@differentiable(wrt: (self), jvp: storedPropJVP)
231-
let storedImmutableOk: Float
232-
233-
// expected-error @+1 {{'storedPropJVP' does not have expected type '(JVPStruct) -> () -> (Double, (JVPStruct.TangentVector) -> Double.TangentVector)' (aka '(JVPStruct) -> () -> (Double, (JVPStruct) -> Double)'}}
234-
@differentiable(wrt: (self), jvp: storedPropJVP)
235-
let storedImmutableWrongType: Double
236-
237-
@differentiable(wrt: (self), jvp: storedPropJVP)
238-
var storedMutableOk: Float
239-
240-
// expected-error @+1 {{'storedPropJVP' does not have expected type '(JVPStruct) -> () -> (Double, (JVPStruct.TangentVector) -> Double.TangentVector)' (aka '(JVPStruct) -> () -> (Double, (JVPStruct) -> Double)'}}
241-
@differentiable(wrt: (self), jvp: storedPropJVP)
242-
var storedMutableWrongType: Double
234+
// expected-error @+1 {{'funcJVP' does not have expected type '(JVPStruct) -> () -> (Double, (JVPStruct.TangentVector) -> Double.TangentVector)' (aka '(JVPStruct) -> () -> (Double, (JVPStruct) -> Double)'}}
235+
@differentiable(wrt: (self), jvp: funcJVP)
236+
func funcWrongType() -> Double {
237+
fatalError("unimplemented")
238+
}
243239
}
244240

245241
extension JVPStruct {
246-
func storedPropJVP() -> (Float, (JVPStruct) -> Float) {
242+
func funcJVP() -> (Float, (JVPStruct) -> Float) {
247243
fatalError("unimplemented")
248244
}
249245
}
@@ -383,23 +379,15 @@ func vjpNonDiffResult2(x: Float) -> (Float, Int) {
383379
struct VJPStruct {
384380
let p: Float
385381

386-
@differentiable(vjp: storedPropVJP)
387-
let storedImmutableOk: Float
388-
389-
// expected-error @+1 {{'storedPropVJP' does not have expected type '(VJPStruct) -> () -> (Double, (Double.TangentVector) -> VJPStruct.TangentVector)' (aka '(VJPStruct) -> () -> (Double, (Double) -> VJPStruct)'}}
390-
@differentiable(vjp: storedPropVJP)
391-
let storedImmutableWrongType: Double
392-
393-
@differentiable(vjp: storedPropVJP)
394-
var storedMutableOk: Float
395-
396-
// expected-error @+1 {{'storedPropVJP' does not have expected type '(VJPStruct) -> () -> (Double, (Double.TangentVector) -> VJPStruct.TangentVector)' (aka '(VJPStruct) -> () -> (Double, (Double) -> VJPStruct)'}}
397-
@differentiable(vjp: storedPropVJP)
398-
var storedMutableWrongType: Double
382+
// expected-error @+1 {{'funcVJP' does not have expected type '(VJPStruct) -> () -> (Double, (Double.TangentVector) -> VJPStruct.TangentVector)' (aka '(VJPStruct) -> () -> (Double, (Double) -> VJPStruct)'}}
383+
@differentiable(vjp: funcVJP)
384+
func funcWrongType() -> Double {
385+
fatalError("unimplemented")
386+
}
399387
}
400388

401389
extension VJPStruct {
402-
func storedPropVJP() -> (Float, (Float) -> VJPStruct) {
390+
func funcVJP() -> (Float, (Float) -> VJPStruct) {
403391
fatalError("unimplemented")
404392
}
405393
}

test/AutoDiff/differentiating_attr_type_checking.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,7 @@ func vjpConsistent(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
298298

299299
// Test usage of `@differentiable` on a stored property
300300
struct PropertyDiff : Differentiable & AdditiveArithmetic {
301-
// expected-error @+1 {{stored properties cannot be marked with '@differentiable'}}
301+
// expected-error @+1 {{'jvp:' or 'vjp:' cannot be specified for stored properties}}
302302
@differentiable(vjp: vjpPropertyA)
303303
var a: Float = 1
304304
typealias TangentVector = PropertyDiff

test/AutoDiff/e2e_differentiable_property.swift

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,9 @@ struct Space {
3030
}
3131

3232
private let storedX: Float
33-
34-
/// `y` is a stored property with a custom vjp for its getter.
35-
@differentiable(vjp: vjpY)
36-
let y: Float
33+
34+
@differentiable
35+
var y: Float
3736

3837
func vjpY() -> (Float, (Float) -> TangentSpace) {
3938
return (y, { v in TangentSpace(dx: 0, dy: v) })
@@ -70,7 +69,7 @@ E2EDifferentiablePropertyTests.test("stored property") {
7069

7170
struct GenericMemberWrapper<T : Differentiable> : Differentiable {
7271
// Stored property.
73-
@differentiable(vjp: vjpX)
72+
@differentiable
7473
var x: T
7574

7675
func vjpX() -> (T, (T.TangentVector) -> GenericMemberWrapper.TangentVector) {

0 commit comments

Comments
 (0)