Skip to content

Commit ba9b37e

Browse files
authored
differentiate struct_extract instructions using VJP (#21567)
1 parent 0e7438e commit ba9b37e

File tree

9 files changed

+208
-42
lines changed

9 files changed

+208
-42
lines changed

include/swift/AST/DiagnosticsSIL.def

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -374,6 +374,8 @@ ERROR(autodiff_unsupported_type,none,
374374
"differentiating '%0' is not supported yet", (Type))
375375
ERROR(autodiff_function_not_differentiable,none,
376376
"function is not differentiable", ())
377+
ERROR(autodiff_property_not_differentiable,none,
378+
"property is not differentiable", ())
377379
NOTE(autodiff_function_generic_functions_unsupported,none,
378380
"differentiating generic functions is not supported yet", ())
379381
NOTE(autodiff_value_defined_here,none,

lib/SILOptimizer/Mandatory/TFDifferentiation.cpp

Lines changed: 108 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -472,9 +472,9 @@ class PrimalInfo {
472472
/// corresponding tape of its type.
473473
DenseMap<ApplyInst *, VarDecl *> nestedStaticPrimalValueMap;
474474

475-
/// Mapping from `apply` instructions in the original function to the
476-
/// corresponding pullback decl in the primal struct.
477-
DenseMap<ApplyInst *, VarDecl *> pullbackValueMap;
475+
/// Mapping from `apply` and `struct_extract` instructions in the original
476+
/// function to the corresponding pullback decl in the primal struct.
477+
DenseMap<SILInstruction *, VarDecl *> pullbackValueMap;
478478

479479
/// Mapping from types of control-dependent nested primal values to district
480480
/// tapes.
@@ -573,7 +573,7 @@ class PrimalInfo {
573573
}
574574

575575
/// Add a pullback to the primal value struct.
576-
VarDecl *addPullbackDecl(ApplyInst *inst, Type pullbackType) {
576+
VarDecl *addPullbackDecl(SILInstruction *inst, Type pullbackType) {
577577
// Decls must have AST types (not `SILFunctionType`), so we convert the
578578
// `SILFunctionType` of the pullback to a `FunctionType` with the same
579579
// parameters and results.
@@ -605,9 +605,9 @@ class PrimalInfo {
605605
: lookup->getSecond();
606606
}
607607

608-
/// Finds the pullback decl in the primal value struct for an `apply` in the
609-
/// original function.
610-
VarDecl *lookUpPullbackDecl(ApplyInst *inst) {
608+
/// Finds the pullback decl in the primal value struct for an `apply` or
609+
/// `struct_extract` in the original function.
610+
VarDecl *lookUpPullbackDecl(SILInstruction *inst) {
611611
auto lookup = pullbackValueMap.find(inst);
612612
return lookup == pullbackValueMap.end() ? nullptr
613613
: lookup->getSecond();
@@ -2227,6 +2227,79 @@ class PrimalGenCloner final : public SILClonerWithScopes<PrimalGenCloner> {
22272227
SILClonerWithScopes::visitReleaseValueInst(rvi);
22282228
}
22292229

2230+
void visitStructExtractInst(StructExtractInst *sei) {
2231+
// Special handling logic only applies when the `struct_extract` is active.
2232+
// If not, just do standard cloning.
2233+
if (!activityInfo.isActive(sei, synthesis.indices)) {
2234+
LLVM_DEBUG(getADDebugStream() << "Not active:\n" << *sei << '\n');
2235+
SILClonerWithScopes::visitStructExtractInst(sei);
2236+
return;
2237+
}
2238+
2239+
// This instruction is active. Replace it with a call to the corresponding
2240+
// getter's VJP.
2241+
2242+
// Find the corresponding getter and its VJP.
2243+
auto *getterDecl = sei->getField()->getGetter();
2244+
assert(getterDecl);
2245+
auto *getterFn = getContext().getModule().lookUpFunction(
2246+
SILDeclRef(getterDecl, SILDeclRef::Kind::Func));
2247+
if (!getterFn) {
2248+
getContext().emitNondifferentiabilityError(
2249+
sei, synthesis.task, diag::autodiff_property_not_differentiable);
2250+
errorOccurred = true;
2251+
return;
2252+
}
2253+
auto getterDiffAttrs = getterFn->getDifferentiableAttrs();
2254+
if (getterDiffAttrs.size() < 1) {
2255+
getContext().emitNondifferentiabilityError(
2256+
sei, synthesis.task, diag::autodiff_property_not_differentiable);
2257+
errorOccurred = true;
2258+
return;
2259+
}
2260+
auto *getterDiffAttr = getterDiffAttrs[0];
2261+
if (!getterDiffAttr->hasVJP()) {
2262+
getContext().emitNondifferentiabilityError(
2263+
sei, synthesis.task, diag::autodiff_property_not_differentiable);
2264+
errorOccurred = true;
2265+
return;
2266+
}
2267+
assert(getterDiffAttr->getIndices() ==
2268+
SILAutoDiffIndices(/*source*/ 0, /*parameters*/{0}));
2269+
auto *getterVJP = lookUpOrLinkFunction(getterDiffAttr->getVJPName(),
2270+
getContext().getModule());
2271+
2272+
// Reference and apply the VJP.
2273+
auto loc = sei->getLoc();
2274+
auto *getterVJPRef = getBuilder().createFunctionRef(loc, getterVJP);
2275+
auto *getterVJPApply = getBuilder().createApply(
2276+
loc, getterVJPRef, /*substitutionMap*/ {},
2277+
/*args*/ {getMappedValue(sei->getOperand())}, /*isNonThrowing*/ false);
2278+
2279+
// Get the VJP results (original results and pullback).
2280+
SmallVector<SILValue, 8> vjpDirectResults;
2281+
extractAllElements(getterVJPApply, getBuilder(), vjpDirectResults);
2282+
ArrayRef<SILValue> originalDirectResults =
2283+
ArrayRef<SILValue>(vjpDirectResults).drop_back(1);
2284+
SILValue originalDirectResult = joinElements(originalDirectResults,
2285+
getBuilder(),
2286+
getterVJPApply->getLoc());
2287+
SILValue pullback = vjpDirectResults.back();
2288+
2289+
// Store the original result to the value map.
2290+
mapValue(sei, originalDirectResult);
2291+
2292+
// Checkpoint the original results.
2293+
getPrimalInfo().addStaticPrimalValueDecl(sei);
2294+
getBuilder().createRetainValue(loc, originalDirectResult,
2295+
getBuilder().getDefaultAtomicity());
2296+
staticPrimalValues.push_back(originalDirectResult);
2297+
2298+
// Checkpoint the pullback.
2299+
getPrimalInfo().addPullbackDecl(sei, pullback->getType().getASTType());
2300+
staticPrimalValues.push_back(pullback);
2301+
}
2302+
22302303
void visitApplyInst(ApplyInst *ai) {
22312304
if (DifferentiationUseVJP)
22322305
visitApplyInstWithVJP(ai);
@@ -3522,33 +3595,36 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
35223595
}
35233596
}
35243597

3525-
/// Handle `struct_extract` instruction.
3526-
/// y = struct_extract <key>, x
3527-
/// adj[x] = struct (0, ..., key: adj[y], ..., 0)
35283598
void visitStructExtractInst(StructExtractInst *sei) {
3529-
auto *structDecl = sei->getStructDecl();
3530-
auto av = getAdjointValue(sei);
3531-
switch (av.getKind()) {
3532-
case AdjointValue::Kind::Zero:
3533-
addAdjointValue(sei->getOperand(),
3534-
AdjointValue::getZero(sei->getOperand()->getType()));
3535-
break;
3536-
case AdjointValue::Kind::Materialized:
3537-
case AdjointValue::Kind::Aggregate: {
3538-
SmallVector<AdjointValue, 8> eltVals;
3539-
for (auto *field : structDecl->getStoredProperties()) {
3540-
if (field == sei->getField())
3541-
eltVals.push_back(av);
3542-
else
3543-
eltVals.push_back(AdjointValue::getZero(
3544-
SILType::getPrimitiveObjectType(
3545-
field->getType()->getCanonicalType())));
3546-
}
3547-
addAdjointValue(sei->getOperand(),
3548-
AdjointValue::getAggregate(sei->getOperand()->getType(),
3549-
eltVals, allocator));
3550-
}
3599+
// Replace a `struct_extract` with a call to its pullback.
3600+
auto loc = remapLocation(sei->getLoc());
3601+
3602+
// Get the pullback.
3603+
auto *pullbackField = getPrimalInfo().lookUpPullbackDecl(sei);
3604+
if (!pullbackField) {
3605+
// Inactive `struct_extract` instructions don't need to be cloned into the
3606+
// adjoint.
3607+
assert(!activityInfo.isActive(sei, synthesis.indices));
3608+
return;
35513609
}
3610+
SILValue pullback = builder.createStructExtract(loc,
3611+
primalValueAggregateInAdj,
3612+
pullbackField);
3613+
3614+
// Construct the pullback arguments.
3615+
SmallVector<SILValue, 8> args;
3616+
auto seed = getAdjointValue(sei);
3617+
assert(seed.getType().isObject());
3618+
args.push_back(materializeAdjointDirect(seed, loc));
3619+
3620+
// Call the pullback.
3621+
auto *pullbackCall = builder.createApply(loc, pullback, SubstitutionMap(),
3622+
args, /*isNonThrowing*/ false);
3623+
assert(!pullbackCall->hasIndirectResults());
3624+
3625+
// Set adjoint for the `struct_extract` operand.
3626+
addAdjointValue(sei->getOperand(),
3627+
AdjointValue::getMaterialized(pullbackCall));
35523628
}
35533629

35543630
/// Handle `tuple` instruction.

stdlib/public/core/FloatingPointTypes.swift.gyb

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1582,11 +1582,25 @@ extension ${Self} {
15821582

15831583
extension ${Self} {
15841584
@_transparent
1585+
// SWIFT_ENABLE_TENSORFLOW
1586+
@differentiable(adjoint: _adjointNegate)
15851587
public static prefix func - (x: ${Self}) -> ${Self} {
15861588
return ${Self}(Builtin.fneg_FPIEEE${bits}(x._value))
15871589
}
15881590
}
15891591

1592+
// SWIFT_ENABLE_TENSORFLOW
1593+
extension ${Self} {
1594+
@usableFromInline
1595+
@_transparent
1596+
// SWIFT_ENABLE_TENSORFLOW
1597+
static func _adjointNegate(
1598+
seed: ${Self}, originalValue: ${Self}, x: ${Self}
1599+
) -> ${Self} {
1600+
return -seed
1601+
}
1602+
}
1603+
15901604
//===----------------------------------------------------------------------===//
15911605
// Explicit conversions between types.
15921606
//===----------------------------------------------------------------------===//

test/AutoDiff/autodiff_diagnostics.swift

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,28 @@ func generic<T: Differentiable & FloatingPoint>(_ x: T) -> T {
3131
return x + 1
3232
}
3333

34+
//===----------------------------------------------------------------------===//
35+
// Non-differentiable stored properties
36+
//===----------------------------------------------------------------------===//
37+
38+
struct S {
39+
let p: Float
40+
}
41+
42+
extension S : Differentiable, VectorNumeric {
43+
static var zero: S { return S(p: 0) }
44+
typealias Scalar = Float
45+
static func + (lhs: S, rhs: S) -> S { return S(p: lhs.p + rhs.p) }
46+
static func - (lhs: S, rhs: S) -> S { return S(p: lhs.p - rhs.p) }
47+
static func * (lhs: Float, rhs: S) -> S { return S(p: lhs * rhs.p) }
48+
49+
typealias TangentVector = S
50+
typealias CotangentVector = S
51+
}
52+
53+
// expected-error @+1 {{property is not differentiable}}
54+
_ = gradient(at: S(p: 0)) { s in 2 * s.p }
55+
3456
//===----------------------------------------------------------------------===//
3557
// Function composition
3658
//===----------------------------------------------------------------------===//

test/AutoDiff/e2e_differentiable_property.swift

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -75,14 +75,12 @@ E2EDifferentiablePropertyTests.test("computed property") {
7575
expectEqual(expectedGrad, actualGrad)
7676
}
7777

78-
// FIXME: The AD pass cannot differentiate this because it sees
79-
// `struct_extract`s instead of calls to getters.
80-
// E2EDifferentiablePropertyTests.test("stored property") {
81-
// let actualGrad = gradient(at: Space(x: 0, y: 0)) { (point: Space) -> Float in
82-
// return 3 * point.y
83-
// }
84-
// let expectedGrad = TangentSpace(dx: 0, dy: 3)
85-
// expectEqual(expectedGrad, actualGrad)
86-
// }
78+
E2EDifferentiablePropertyTests.test("stored property") {
79+
let actualGrad = gradient(at: Space(x: 0, y: 0)) { (point: Space) -> Float in
80+
return 3 * point.y
81+
}
82+
let expectedGrad = TangentSpace(dx: 0, dy: 3)
83+
expectEqual(expectedGrad, actualGrad)
84+
}
8785

8886
runAllTests()

test/AutoDiff/method.swift

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,12 @@ var MethodTests = TestSuite("Method")
99
// ==== Tests with generated adjoint ====
1010

1111
struct Parameter : Equatable {
12+
@differentiable(wrt: (self), vjp: vjpX)
1213
let x: Float
14+
15+
func vjpX() -> (Float, (Float) -> Parameter) {
16+
return (x, { dx in Parameter(x: dx) } )
17+
}
1318
}
1419

1520
extension Parameter {
@@ -132,7 +137,12 @@ MethodTests.test("static method with generated adjoint, wrt all params") {
132137
// ==== Tests with custom adjoint ====
133138

134139
struct CustomParameter : Equatable {
140+
@differentiable(wrt: (self), vjp: vjpX)
135141
let x: Float
142+
143+
func vjpX() -> (Float, (Float) -> CustomParameter) {
144+
return (x, { dx in CustomParameter(x: dx) })
145+
}
136146
}
137147

138148
extension CustomParameter : Differentiable, VectorNumeric {

test/AutoDiff/protocol_requirement_autodiff.swift

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,24 @@ struct Quadratic : DiffReq, Equatable {
2727
typealias TangentVector = Quadratic
2828
typealias CotangentVector = Quadratic
2929

30-
let a, b, c: Float
30+
@differentiable(wrt: (self), vjp: vjpA)
31+
let a: Float
32+
func vjpA() -> (Float, (Float) -> Quadratic) {
33+
return (a, { da in Quadratic(da, 0, 0) } )
34+
}
35+
36+
@differentiable(wrt: (self), vjp: vjpB)
37+
let b: Float
38+
func vjpB() -> (Float, (Float) -> Quadratic) {
39+
return (b, { db in Quadratic(0, db, 0) } )
40+
}
41+
42+
@differentiable(wrt: (self), vjp: vjpC)
43+
let c: Float
44+
func vjpC() -> (Float, (Float) -> Quadratic) {
45+
return (c, { dc in Quadratic(0, 0, dc) } )
46+
}
47+
3148
init(_ a: Float, _ b: Float, _ c: Float) {
3249
self.a = a
3350
self.b = b

test/AutoDiff/simple_model.swift

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,17 @@ import StdlibUnittest
77
var SimpleModelTests = TestSuite("SimpleModel")
88

99
struct DenseLayer : Equatable {
10+
@differentiable(wrt: (self), vjp: vjpW)
1011
let w: Float
12+
func vjpW() -> (Float, (Float) -> DenseLayer) {
13+
return (w, { dw in DenseLayer(w: dw, b: 0) } )
14+
}
15+
16+
@differentiable(wrt: (self), vjp: vjpB)
1117
let b: Float
18+
func vjpB() -> (Float, (Float) -> DenseLayer) {
19+
return (b, { db in DenseLayer(w: 0, b: db) } )
20+
}
1221
}
1322

1423
extension DenseLayer : Differentiable, VectorNumeric {
@@ -39,9 +48,23 @@ extension DenseLayer {
3948
}
4049

4150
struct Model : Equatable {
51+
@differentiable(wrt: (self), vjp: vjpL1)
4252
let l1: DenseLayer
53+
func vjpL1() -> (DenseLayer, (DenseLayer) -> Model) {
54+
return (l1, { dl1 in Model(l1: dl1, l2: DenseLayer.zero, l3: DenseLayer.zero) } )
55+
}
56+
57+
@differentiable(wrt: (self), vjp: vjpL2)
4358
let l2: DenseLayer
59+
func vjpL2() -> (DenseLayer, (DenseLayer) -> Model) {
60+
return (l2, { dl2 in Model(l1: DenseLayer.zero, l2: dl2, l3: DenseLayer.zero) } )
61+
}
62+
63+
@differentiable(wrt: (self), vjp: vjpL3)
4464
let l3: DenseLayer
65+
func vjpL3() -> (DenseLayer, (DenseLayer) -> Model) {
66+
return (l3, { dl3 in Model(l1: DenseLayer.zero, l2: DenseLayer.zero, l3: dl3) } )
67+
}
4568
}
4669

4770
extension Model : Differentiable, VectorNumeric {

test/AutoDiff/witness_table_silgen.swift

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,11 @@ struct S : Proto, VectorNumeric {
2121
typealias TangentVector = S
2222
typealias CotangentVector = S
2323

24+
@differentiable(wrt: (self), vjp: vjpP)
2425
let p: Float
26+
func vjpP() -> (Float, (Float) -> S) {
27+
return (p, { dp in S(p: dp) })
28+
}
2529

2630
func function1(_ x: Float, _ y: Float) -> Float {
2731
return x + y + p

0 commit comments

Comments
 (0)