Skip to content

Commit eaf695d

Browse files
author
Marc Rasi
committed
differentiate struct_extract instructions
1 parent 0e7438e commit eaf695d

File tree

9 files changed

+222
-41
lines changed

9 files changed

+222
-41
lines changed

include/swift/AST/DiagnosticsSIL.def

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -374,6 +374,8 @@ ERROR(autodiff_unsupported_type,none,
374374
"differentiating '%0' is not supported yet", (Type))
375375
ERROR(autodiff_function_not_differentiable,none,
376376
"function is not differentiable", ())
377+
ERROR(autodiff_property_not_differentiable,none,
378+
"property is not differentiable", ())
377379
NOTE(autodiff_function_generic_functions_unsupported,none,
378380
"differentiating generic functions is not supported yet", ())
379381
NOTE(autodiff_value_defined_here,none,

lib/SILOptimizer/Mandatory/TFDifferentiation.cpp

Lines changed: 122 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -472,9 +472,9 @@ class PrimalInfo {
472472
/// corresponding tape of its type.
473473
DenseMap<ApplyInst *, VarDecl *> nestedStaticPrimalValueMap;
474474

475-
/// Mapping from `apply` instructions in the original function to the
476-
/// corresponding pullback decl in the primal struct.
477-
DenseMap<ApplyInst *, VarDecl *> pullbackValueMap;
475+
/// Mapping from `apply` and `struct_extract` instructions in the original
476+
/// function to the corresponding pullback decl in the primal struct.
477+
DenseMap<SILInstruction *, VarDecl *> pullbackValueMap;
478478

479479
/// Mapping from types of control-dependent nested primal values to district
480480
/// tapes.
@@ -573,7 +573,7 @@ class PrimalInfo {
573573
}
574574

575575
/// Add a pullback to the primal value struct.
576-
VarDecl *addPullbackDecl(ApplyInst *inst, Type pullbackType) {
576+
VarDecl *addPullbackDecl(SILInstruction *inst, Type pullbackType) {
577577
// Decls must have AST types (not `SILFunctionType`), so we convert the
578578
// `SILFunctionType` of the pullback to a `FunctionType` with the same
579579
// parameters and results.
@@ -605,9 +605,9 @@ class PrimalInfo {
605605
: lookup->getSecond();
606606
}
607607

608-
/// Finds the pullback decl in the primal value struct for an `apply` in the
609-
/// original function.
610-
VarDecl *lookUpPullbackDecl(ApplyInst *inst) {
608+
/// Finds the pullback decl in the primal value struct for an `apply` or
609+
/// `struct_extract` in the original function.
610+
VarDecl *lookUpPullbackDecl(SILInstruction *inst) {
611611
auto lookup = pullbackValueMap.find(inst);
612612
return lookup == pullbackValueMap.end() ? nullptr
613613
: lookup->getSecond();
@@ -714,6 +714,9 @@ class DifferentiationTask {
714714
/// Note: This is only used when `DifferentiationUseVJP`.
715715
DenseMap<ApplyInst *, NestedApplyActivity> nestedApplyActivities;
716716

717+
DenseMap<StructExtractInst *, NestedStructExtractStrategy>
718+
nestedStructExtractStrategies;
719+
717720
/// Cache for associated functions.
718721
SILFunction *primal = nullptr;
719722
SILFunction *adjoint = nullptr;
@@ -2227,6 +2230,77 @@ class PrimalGenCloner final : public SILClonerWithScopes<PrimalGenCloner> {
22272230
SILClonerWithScopes::visitReleaseValueInst(rvi);
22282231
}
22292232

2233+
void visitStructExtractInst(StructExtractInst *sei) {
2234+
// Special handling logic only applies when the `struct_extract` is active.
2235+
// If not, just do standard cloning.
2236+
if (!activityInfo.isActive(sei, synthesis.indices)) {
2237+
LLVM_DEBUG(getADDebugStream() << "Not active:\n" << *sei << '\n');
2238+
SILClonerWithScopes::visitStructExtractInst(sei);
2239+
return;
2240+
}
2241+
2242+
// This instruction is active. Replace it with a call to the corresponding
2243+
// getter's VJP.
2244+
2245+
// Find the corresponding getter and its VJP.
2246+
auto *getterDecl = sei->getField()->getGetter();
2247+
assert(getterDecl);
2248+
auto *getterFn = getContext().getModule().lookUpFunction(
2249+
SILDeclRef(getterDecl, SILDeclRef::Kind::Func));
2250+
if (!getterFn) {
2251+
getContext().emitNondifferentiabilityError(
2252+
sei, synthesis.task, diag::autodiff_property_not_differentiable);
2253+
errorOccurred = true;
2254+
return;
2255+
}
2256+
auto getterDiffAttrs = getterFn->getDifferentiableAttrs();
2257+
if (getterDiffAttrs.size() < 1) {
2258+
getContext().emitNondifferentiabilityError(
2259+
sei, synthesis.task, diag::autodiff_property_not_differentiable);
2260+
errorOccurred = true;
2261+
return;
2262+
}
2263+
auto *getterDiffAttr = getterDiffAttrs[0];
2264+
if (!getterDiffAttr->hasVJP()) {
2265+
getContext().emitNondifferentiabilityError(
2266+
sei, synthesis.task, diag::autodiff_property_not_differentiable);
2267+
errorOccurred = true;
2268+
return;
2269+
}
2270+
assert(getterDiffAttr->getIndices() ==
2271+
SILAutoDiffIndices(/*source*/ 0, /*parameters*/{0}));
2272+
auto *getterVJP = lookUpOrLinkFunction(getterDiffAttr->getVJPName(),
2273+
getContext().getModule());
2274+
2275+
// Reference and apply the VJP.
2276+
auto loc = sei->getLoc();
2277+
auto *getterVJPRef = getBuilder().createFunctionRef(loc, getterVJP);
2278+
auto *getterVJPApply = getBuilder().createApply(
2279+
loc, getterVJPRef, /*substitutionMap*/ {},
2280+
/*args*/ {getMappedValue(sei->getOperand())}, /*isNonThrowing*/ false);
2281+
2282+
// Get the VJP results (original results and pullback)
2283+
SmallVector<SILValue, 8> vjpDirectResults;
2284+
extractAllElements(getterVJPApply, getBuilder(), vjpDirectResults);
2285+
ArrayRef<SILValue> originalDirectResults =
2286+
ArrayRef<SILValue>(vjpDirectResults).drop_back(1);
2287+
SILValue originalDirectResult = joinElements(originalDirectResults,
2288+
getBuilder(),
2289+
getterVJPApply->getLoc());
2290+
SILValue pullback = vjpDirectResults.back();
2291+
2292+
// Store the original result to the value map.
2293+
mapValue(sei, originalDirectResult);
2294+
2295+
// Checkpoint the original results.
2296+
getPrimalInfo().addStaticPrimalValueDecl(sei);
2297+
staticPrimalValues.push_back(originalDirectResult);
2298+
2299+
// Checkpoint the pullback.
2300+
getPrimalInfo().addPullbackDecl(sei, pullback->getType().getASTType());
2301+
staticPrimalValues.push_back(pullback);
2302+
}
2303+
22302304
void visitApplyInst(ApplyInst *ai) {
22312305
if (DifferentiationUseVJP)
22322306
visitApplyInstWithVJP(ai);
@@ -3522,33 +3596,50 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
35223596
}
35233597
}
35243598

3525-
/// Handle `struct_extract` instruction.
3526-
/// y = struct_extract <key>, x
3527-
/// adj[x] = struct (0, ..., key: adj[y], ..., 0)
35283599
void visitStructExtractInst(StructExtractInst *sei) {
3529-
auto *structDecl = sei->getStructDecl();
3530-
auto av = getAdjointValue(sei);
3531-
switch (av.getKind()) {
3532-
case AdjointValue::Kind::Zero:
3533-
addAdjointValue(sei->getOperand(),
3534-
AdjointValue::getZero(sei->getOperand()->getType()));
3535-
break;
3536-
case AdjointValue::Kind::Materialized:
3537-
case AdjointValue::Kind::Aggregate: {
3538-
SmallVector<AdjointValue, 8> eltVals;
3539-
for (auto *field : structDecl->getStoredProperties()) {
3540-
if (field == sei->getField())
3541-
eltVals.push_back(av);
3542-
else
3543-
eltVals.push_back(AdjointValue::getZero(
3544-
SILType::getPrimitiveObjectType(
3545-
field->getType()->getCanonicalType())));
3546-
}
3547-
addAdjointValue(sei->getOperand(),
3548-
AdjointValue::getAggregate(sei->getOperand()->getType(),
3549-
eltVals, allocator));
3600+
// Replace a `struct_extract` with a call to its pullback.
3601+
auto loc = remapLocation(sei->getLoc());
3602+
3603+
// Get the pullback.
3604+
auto *pullbackField = getPrimalInfo().lookUpPullbackDecl(sei);
3605+
if (!pullbackField) {
3606+
// Inactive `struct_extract` instructions don't need to be cloned into the
3607+
// adjoint.
3608+
assert(!activityInfo.isActive(sei, synthesis.indices));
3609+
return;
35503610
}
3611+
SILValue pullback = builder.createStructExtract(loc,
3612+
primalValueAggregateInAdj,
3613+
pullbackField);
3614+
3615+
// Construct the pullback arguments.
3616+
SmallVector<SILValue, 8> args;
3617+
auto seed = getAdjointValue(sei);
3618+
auto *seedBuf = builder.createAllocStack(loc, seed.getType());
3619+
materializeAdjointIndirectHelper(seed, seedBuf);
3620+
if (seed.getType().isAddressOnly(getModule()))
3621+
args.push_back(seedBuf);
3622+
else {
3623+
auto access = builder.createBeginAccess(
3624+
loc, seedBuf, SILAccessKind::Read, SILAccessEnforcement::Static,
3625+
/*noNestedConflict*/ true,
3626+
/*fromBuiltin*/ false);
3627+
args.push_back(builder.createLoad(
3628+
loc, access, getBufferLOQ(seed.getSwiftType(), getAdjoint())));
3629+
builder.createEndAccess(loc, access, /*aborted*/ false);
35513630
}
3631+
3632+
// Call the pullback.
3633+
auto *pullbackCall = builder.createApply(loc, pullback, SubstitutionMap(),
3634+
args, /*isNonThrowing*/ false);
3635+
assert(!pullbackCall->hasIndirectResults());
3636+
3637+
// Clean up seed allocation.
3638+
builder.createDeallocStack(loc, seedBuf);
3639+
3640+
// Set adjoint for the `struct_extract` operand.
3641+
addAdjointValue(sei->getOperand(),
3642+
AdjointValue::getMaterialized(pullbackCall));
35523643
}
35533644

35543645
/// Handle `tuple` instruction.

stdlib/public/core/FloatingPointTypes.swift.gyb

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1582,11 +1582,25 @@ extension ${Self} {
15821582

15831583
extension ${Self} {
15841584
@_transparent
1585+
// SWIFT_ENABLE_TENSORFLOW
1586+
@differentiable(adjoint: _adjointNegate)
15851587
public static prefix func - (x: ${Self}) -> ${Self} {
15861588
return ${Self}(Builtin.fneg_FPIEEE${bits}(x._value))
15871589
}
15881590
}
15891591

1592+
// SWIFT_ENABLE_TENSORFLOW
1593+
extension ${Self} {
1594+
@inlinable
1595+
@_transparent
1596+
// SWIFT_ENABLE_TENSORFLOW
1597+
static func _adjointNegate(
1598+
seed: ${Self}, originalValue: ${Self}, x: ${Self}
1599+
) -> ${Self} {
1600+
return -seed
1601+
}
1602+
}
1603+
15901604
//===----------------------------------------------------------------------===//
15911605
// Explicit conversions between types.
15921606
//===----------------------------------------------------------------------===//

test/AutoDiff/autodiff_diagnostics.swift

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,28 @@ func generic<T: Differentiable & FloatingPoint>(_ x: T) -> T {
3131
return x + 1
3232
}
3333

34+
//===----------------------------------------------------------------------===//
35+
// Non-differentiable stored properties
36+
//===----------------------------------------------------------------------===//
37+
38+
struct S {
39+
let p: Float
40+
}
41+
42+
extension S : Differentiable, VectorNumeric {
43+
static var zero: S { return S(p: 0) }
44+
typealias Scalar = Float
45+
static func + (lhs: S, rhs: S) -> S { return S(p: lhs.p + rhs.p) }
46+
static func - (lhs: S, rhs: S) -> S { return S(p: lhs.p - rhs.p) }
47+
static func * (lhs: Float, rhs: S) -> S { return S(p: lhs * rhs.p) }
48+
49+
typealias TangentVector = S
50+
typealias CotangentVector = S
51+
}
52+
53+
// expected-error @+1 {{property is not differentiable}}
54+
_ = gradient(at: S(p: 0)) { s in 2 * s.p }
55+
3456
//===----------------------------------------------------------------------===//
3557
// Function composition
3658
//===----------------------------------------------------------------------===//

test/AutoDiff/e2e_differentiable_property.swift

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -75,14 +75,12 @@ E2EDifferentiablePropertyTests.test("computed property") {
7575
expectEqual(expectedGrad, actualGrad)
7676
}
7777

78-
// FIXME: The AD pass cannot differentiate this because it sees
79-
// `struct_extract`s instead of calls to getters.
80-
// E2EDifferentiablePropertyTests.test("stored property") {
81-
// let actualGrad = gradient(at: Space(x: 0, y: 0)) { (point: Space) -> Float in
82-
// return 3 * point.y
83-
// }
84-
// let expectedGrad = TangentSpace(dx: 0, dy: 3)
85-
// expectEqual(expectedGrad, actualGrad)
86-
// }
78+
E2EDifferentiablePropertyTests.test("stored property") {
79+
let actualGrad = gradient(at: Space(x: 0, y: 0)) { (point: Space) -> Float in
80+
return 3 * point.y
81+
}
82+
let expectedGrad = TangentSpace(dx: 0, dy: 3)
83+
expectEqual(expectedGrad, actualGrad)
84+
}
8785

8886
runAllTests()

test/AutoDiff/method.swift

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,12 @@ var MethodTests = TestSuite("Method")
99
// ==== Tests with generated adjoint ====
1010

1111
struct Parameter : Equatable {
12+
@differentiable(wrt: (self), vjp: vjpX)
1213
let x: Float
14+
15+
func vjpX() -> (Float, (Float) -> Parameter) {
16+
return (x, { dx in Parameter(x: dx) } )
17+
}
1318
}
1419

1520
extension Parameter {
@@ -132,7 +137,12 @@ MethodTests.test("static method with generated adjoint, wrt all params") {
132137
// ==== Tests with custom adjoint ====
133138

134139
struct CustomParameter : Equatable {
140+
@differentiable(wrt: (self), vjp: vjpX)
135141
let x: Float
142+
143+
func vjpX() -> (Float, (Float) -> CustomParameter) {
144+
return (x, { dx in CustomParameter(x: dx) })
145+
}
136146
}
137147

138148
extension CustomParameter : Differentiable, VectorNumeric {

test/AutoDiff/protocol_requirement_autodiff.swift

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,24 @@ struct Quadratic : DiffReq, Equatable {
2727
typealias TangentVector = Quadratic
2828
typealias CotangentVector = Quadratic
2929

30-
let a, b, c: Float
30+
@differentiable(wrt: (self), vjp: vjpA)
31+
let a: Float
32+
func vjpA() -> (Float, (Float) -> Quadratic) {
33+
return (a, { da in Quadratic(da, 0, 0) } )
34+
}
35+
36+
@differentiable(wrt: (self), vjp: vjpB)
37+
let b: Float
38+
func vjpB() -> (Float, (Float) -> Quadratic) {
39+
return (b, { db in Quadratic(0, db, 0) } )
40+
}
41+
42+
@differentiable(wrt: (self), vjp: vjpC)
43+
let c: Float
44+
func vjpC() -> (Float, (Float) -> Quadratic) {
45+
return (c, { dc in Quadratic(0, 0, dc) } )
46+
}
47+
3148
init(_ a: Float, _ b: Float, _ c: Float) {
3249
self.a = a
3350
self.b = b

test/AutoDiff/simple_model.swift

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,17 @@ import StdlibUnittest
77
var SimpleModelTests = TestSuite("SimpleModel")
88

99
struct DenseLayer : Equatable {
10+
@differentiable(wrt: (self), vjp: vjpW)
1011
let w: Float
12+
func vjpW() -> (Float, (Float) -> DenseLayer) {
13+
return (w, { dw in DenseLayer(w: dw, b: 0) } )
14+
}
15+
16+
@differentiable(wrt: (self), vjp: vjpB)
1117
let b: Float
18+
func vjpB() -> (Float, (Float) -> DenseLayer) {
19+
return (b, { db in DenseLayer(w: 0, b: db) } )
20+
}
1221
}
1322

1423
extension DenseLayer : Differentiable, VectorNumeric {
@@ -39,9 +48,23 @@ extension DenseLayer {
3948
}
4049

4150
struct Model : Equatable {
51+
@differentiable(wrt: (self), vjp: vjpL1)
4252
let l1: DenseLayer
53+
func vjpL1() -> (DenseLayer, (DenseLayer) -> Model) {
54+
return (l1, { dl1 in Model(l1: dl1, l2: DenseLayer.zero, l3: DenseLayer.zero) } )
55+
}
56+
57+
@differentiable(wrt: (self), vjp: vjpL2)
4358
let l2: DenseLayer
59+
func vjpL2() -> (DenseLayer, (DenseLayer) -> Model) {
60+
return (l2, { dl2 in Model(l1: DenseLayer.zero, l2: dl2, l3: DenseLayer.zero) } )
61+
}
62+
63+
@differentiable(wrt: (self), vjp: vjpL3)
4464
let l3: DenseLayer
65+
func vjpL3() -> (DenseLayer, (DenseLayer) -> Model) {
66+
return (l3, { dl3 in Model(l1: DenseLayer.zero, l2: DenseLayer.zero, l3: dl3) } )
67+
}
4568
}
4669

4770
extension Model : Differentiable, VectorNumeric {

test/AutoDiff/witness_table_silgen.swift

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,11 @@ struct S : Proto, VectorNumeric {
2121
typealias TangentVector = S
2222
typealias CotangentVector = S
2323

24+
@differentiable(wrt: (self), vjp: vjpP)
2425
let p: Float
26+
func vjpP() -> (Float, (Float) -> S) {
27+
return (p, { dp in S(p: dp) })
28+
}
2529

2630
func function1(_ x: Float, _ y: Float) -> Float {
2731
return x + y + p

0 commit comments

Comments
 (0)