Skip to content

differentiate struct_extract instructions using VJP #21567

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Dec 31, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions include/swift/AST/DiagnosticsSIL.def
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,8 @@ ERROR(autodiff_unsupported_type,none,
"differentiating '%0' is not supported yet", (Type))
ERROR(autodiff_function_not_differentiable,none,
"function is not differentiable", ())
ERROR(autodiff_property_not_differentiable,none,
"property is not differentiable", ())
NOTE(autodiff_function_generic_functions_unsupported,none,
"differentiating generic functions is not supported yet", ())
NOTE(autodiff_value_defined_here,none,
Expand Down
140 changes: 108 additions & 32 deletions lib/SILOptimizer/Mandatory/TFDifferentiation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -472,9 +472,9 @@ class PrimalInfo {
/// corresponding tape of its type.
DenseMap<ApplyInst *, VarDecl *> nestedStaticPrimalValueMap;

/// Mapping from `apply` instructions in the original function to the
/// corresponding pullback decl in the primal struct.
DenseMap<ApplyInst *, VarDecl *> pullbackValueMap;
/// Mapping from `apply` and `struct_extract` instructions in the original
/// function to the corresponding pullback decl in the primal struct.
DenseMap<SILInstruction *, VarDecl *> pullbackValueMap;

/// Mapping from types of control-dependent nested primal values to district
/// tapes.
Expand Down Expand Up @@ -573,7 +573,7 @@ class PrimalInfo {
}

/// Add a pullback to the primal value struct.
VarDecl *addPullbackDecl(ApplyInst *inst, Type pullbackType) {
VarDecl *addPullbackDecl(SILInstruction *inst, Type pullbackType) {
// Decls must have AST types (not `SILFunctionType`), so we convert the
// `SILFunctionType` of the pullback to a `FunctionType` with the same
// parameters and results.
Expand Down Expand Up @@ -605,9 +605,9 @@ class PrimalInfo {
: lookup->getSecond();
}

/// Finds the pullback decl in the primal value struct for an `apply` in the
/// original function.
VarDecl *lookUpPullbackDecl(ApplyInst *inst) {
/// Finds the pullback decl in the primal value struct for an `apply` or
/// `struct_extract` in the original function.
VarDecl *lookUpPullbackDecl(SILInstruction *inst) {
auto lookup = pullbackValueMap.find(inst);
return lookup == pullbackValueMap.end() ? nullptr
: lookup->getSecond();
Expand Down Expand Up @@ -2227,6 +2227,79 @@ class PrimalGenCloner final : public SILClonerWithScopes<PrimalGenCloner> {
SILClonerWithScopes::visitReleaseValueInst(rvi);
}

void visitStructExtractInst(StructExtractInst *sei) {
// Special handling logic only applies when the `struct_extract` is active.
// If not, just do standard cloning.
if (!activityInfo.isActive(sei, synthesis.indices)) {
LLVM_DEBUG(getADDebugStream() << "Not active:\n" << *sei << '\n');
SILClonerWithScopes::visitStructExtractInst(sei);
return;
}

// This instruction is active. Replace it with a call to the corresponding
// getter's VJP.

// Find the corresponding getter and its VJP.
auto *getterDecl = sei->getField()->getGetter();
assert(getterDecl);
auto *getterFn = getContext().getModule().lookUpFunction(
SILDeclRef(getterDecl, SILDeclRef::Kind::Func));
if (!getterFn) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When will a getter not exist?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One example is the case that I added to test/AutoDiff/autodiff_diagnostics.swift.

I don't know why the getter doesn't exist in that case, and if I put the same code in a file an compile it with a plain call to swiftc, the getter does exist and the code ends up triggering the case where the getter exists but doesn't have a VJP.

I haven't investigated any farther than that, so I don't really know what's going on but the logic seems to handle all the cases I can think of correctly.

getContext().emitNondifferentiabilityError(
sei, synthesis.task, diag::autodiff_property_not_differentiable);
errorOccurred = true;
return;
}
auto getterDiffAttrs = getterFn->getDifferentiableAttrs();
if (getterDiffAttrs.size() < 1) {
getContext().emitNondifferentiabilityError(
sei, synthesis.task, diag::autodiff_property_not_differentiable);
errorOccurred = true;
return;
}
auto *getterDiffAttr = getterDiffAttrs[0];
if (!getterDiffAttr->hasVJP()) {
getContext().emitNondifferentiabilityError(
sei, synthesis.task, diag::autodiff_property_not_differentiable);
errorOccurred = true;
return;
}
assert(getterDiffAttr->getIndices() ==
SILAutoDiffIndices(/*source*/ 0, /*parameters*/{0}));
auto *getterVJP = lookUpOrLinkFunction(getterDiffAttr->getVJPName(),
getContext().getModule());

// Reference and apply the VJP.
auto loc = sei->getLoc();
auto *getterVJPRef = getBuilder().createFunctionRef(loc, getterVJP);
auto *getterVJPApply = getBuilder().createApply(
loc, getterVJPRef, /*substitutionMap*/ {},
/*args*/ {getMappedValue(sei->getOperand())}, /*isNonThrowing*/ false);

// Get the VJP results (original results and pullback).
SmallVector<SILValue, 8> vjpDirectResults;
extractAllElements(getterVJPApply, getBuilder(), vjpDirectResults);
ArrayRef<SILValue> originalDirectResults =
ArrayRef<SILValue>(vjpDirectResults).drop_back(1);
SILValue originalDirectResult = joinElements(originalDirectResults,
getBuilder(),
getterVJPApply->getLoc());
SILValue pullback = vjpDirectResults.back();

// Store the original result to the value map.
mapValue(sei, originalDirectResult);

// Checkpoint the original results.
getPrimalInfo().addStaticPrimalValueDecl(sei);
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

reminder to self: must add a retain here

    getBuilder().createRetainValue(loc, originalDirectResult,
                                   getBuilder().getDefaultAtomicity());

getBuilder().createRetainValue(loc, originalDirectResult,
getBuilder().getDefaultAtomicity());
staticPrimalValues.push_back(originalDirectResult);

// Checkpoint the pullback.
getPrimalInfo().addPullbackDecl(sei, pullback->getType().getASTType());
staticPrimalValues.push_back(pullback);
}

void visitApplyInst(ApplyInst *ai) {
if (DifferentiationUseVJP)
visitApplyInstWithVJP(ai);
Expand Down Expand Up @@ -3522,33 +3595,36 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
}
}

/// Handle `struct_extract` instruction.
/// y = struct_extract <key>, x
/// adj[x] = struct (0, ..., key: adj[y], ..., 0)
void visitStructExtractInst(StructExtractInst *sei) {
auto *structDecl = sei->getStructDecl();
Copy link
Contributor

@rxwei rxwei Dec 30, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When we have derived conformances to Differentiable, it'd be much more efficient to use the original struct_extract differentiation logic when we know that the cotangent type equals the original type and that the property vjp is compiler-synthesized (this can be checked via Decl::isImplicit(), I think).

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In fact, this gives me an idea. The original struct_extract differentiation logic does not need getter VJPs. So if I can get the original struct_extract differentiation logic working on structs with derived conformance to Differentiable (even those with cotangent type not equal to self), then I don't need to write any code that synthesizes getter VJPs.

I think that the original struct_extract logic can be made to work even on derived conformances to Differentiable with cotangent type not equal to self, because it should be possible to determine which fields correspond.

I will try out this approach.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds reasonable for internal structs. Public structs need property VJPs in any case because of resilience requirements.

auto av = getAdjointValue(sei);
switch (av.getKind()) {
case AdjointValue::Kind::Zero:
addAdjointValue(sei->getOperand(),
AdjointValue::getZero(sei->getOperand()->getType()));
break;
case AdjointValue::Kind::Materialized:
case AdjointValue::Kind::Aggregate: {
SmallVector<AdjointValue, 8> eltVals;
for (auto *field : structDecl->getStoredProperties()) {
if (field == sei->getField())
eltVals.push_back(av);
else
eltVals.push_back(AdjointValue::getZero(
SILType::getPrimitiveObjectType(
field->getType()->getCanonicalType())));
}
addAdjointValue(sei->getOperand(),
AdjointValue::getAggregate(sei->getOperand()->getType(),
eltVals, allocator));
}
// Replace a `struct_extract` with a call to its pullback.
auto loc = remapLocation(sei->getLoc());

// Get the pullback.
auto *pullbackField = getPrimalInfo().lookUpPullbackDecl(sei);
if (!pullbackField) {
// Inactive `struct_extract` instructions don't need to be cloned into the
// adjoint.
assert(!activityInfo.isActive(sei, synthesis.indices));
return;
}
SILValue pullback = builder.createStructExtract(loc,
primalValueAggregateInAdj,
pullbackField);

// Construct the pullback arguments.
SmallVector<SILValue, 8> args;
auto seed = getAdjointValue(sei);
assert(seed.getType().isObject());
args.push_back(materializeAdjointDirect(seed, loc));

// Call the pullback.
auto *pullbackCall = builder.createApply(loc, pullback, SubstitutionMap(),
args, /*isNonThrowing*/ false);
assert(!pullbackCall->hasIndirectResults());

// Set adjoint for the `struct_extract` operand.
addAdjointValue(sei->getOperand(),
AdjointValue::getMaterialized(pullbackCall));
}

/// Handle `tuple` instruction.
Expand Down
14 changes: 14 additions & 0 deletions stdlib/public/core/FloatingPointTypes.swift.gyb
Original file line number Diff line number Diff line change
Expand Up @@ -1582,11 +1582,25 @@ extension ${Self} {

extension ${Self} {
@_transparent
// SWIFT_ENABLE_TENSORFLOW
@differentiable(adjoint: _adjointNegate)
public static prefix func - (x: ${Self}) -> ${Self} {
return ${Self}(Builtin.fneg_FPIEEE${bits}(x._value))
}
}

// SWIFT_ENABLE_TENSORFLOW
extension ${Self} {
@usableFromInline
@_transparent
// SWIFT_ENABLE_TENSORFLOW
static func _adjointNegate(
seed: ${Self}, originalValue: ${Self}, x: ${Self}
) -> ${Self} {
return -seed
}
}

//===----------------------------------------------------------------------===//
// Explicit conversions between types.
//===----------------------------------------------------------------------===//
Expand Down
22 changes: 22 additions & 0 deletions test/AutoDiff/autodiff_diagnostics.swift
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,28 @@ func generic<T: Differentiable & FloatingPoint>(_ x: T) -> T {
return x + 1
}

//===----------------------------------------------------------------------===//
// Non-differentiable stored properties
//===----------------------------------------------------------------------===//

struct S {
let p: Float
}

extension S : Differentiable, VectorNumeric {
static var zero: S { return S(p: 0) }
typealias Scalar = Float
static func + (lhs: S, rhs: S) -> S { return S(p: lhs.p + rhs.p) }
static func - (lhs: S, rhs: S) -> S { return S(p: lhs.p - rhs.p) }
static func * (lhs: Float, rhs: S) -> S { return S(p: lhs * rhs.p) }

typealias TangentVector = S
typealias CotangentVector = S
}

// expected-error @+1 {{property is not differentiable}}
_ = gradient(at: S(p: 0)) { s in 2 * s.p }

//===----------------------------------------------------------------------===//
// Function composition
//===----------------------------------------------------------------------===//
Expand Down
16 changes: 7 additions & 9 deletions test/AutoDiff/e2e_differentiable_property.swift
Original file line number Diff line number Diff line change
Expand Up @@ -75,14 +75,12 @@ E2EDifferentiablePropertyTests.test("computed property") {
expectEqual(expectedGrad, actualGrad)
}

// FIXME: The AD pass cannot differentiate this because it sees
// `struct_extract`s instead of calls to getters.
// E2EDifferentiablePropertyTests.test("stored property") {
// let actualGrad = gradient(at: Space(x: 0, y: 0)) { (point: Space) -> Float in
// return 3 * point.y
// }
// let expectedGrad = TangentSpace(dx: 0, dy: 3)
// expectEqual(expectedGrad, actualGrad)
// }
E2EDifferentiablePropertyTests.test("stored property") {
let actualGrad = gradient(at: Space(x: 0, y: 0)) { (point: Space) -> Float in
return 3 * point.y
}
let expectedGrad = TangentSpace(dx: 0, dy: 3)
expectEqual(expectedGrad, actualGrad)
}

runAllTests()
10 changes: 10 additions & 0 deletions test/AutoDiff/method.swift
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,12 @@ var MethodTests = TestSuite("Method")
// ==== Tests with generated adjoint ====

struct Parameter : Equatable {
@differentiable(wrt: (self), vjp: vjpX)
let x: Float

func vjpX() -> (Float, (Float) -> Parameter) {
return (x, { dx in Parameter(x: dx) } )
}
}

extension Parameter {
Expand Down Expand Up @@ -132,7 +137,12 @@ MethodTests.test("static method with generated adjoint, wrt all params") {
// ==== Tests with custom adjoint ====

struct CustomParameter : Equatable {
@differentiable(wrt: (self), vjp: vjpX)
let x: Float

func vjpX() -> (Float, (Float) -> CustomParameter) {
return (x, { dx in CustomParameter(x: dx) })
}
}

extension CustomParameter : Differentiable, VectorNumeric {
Expand Down
19 changes: 18 additions & 1 deletion test/AutoDiff/protocol_requirement_autodiff.swift
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,24 @@ struct Quadratic : DiffReq, Equatable {
typealias TangentVector = Quadratic
typealias CotangentVector = Quadratic

let a, b, c: Float
@differentiable(wrt: (self), vjp: vjpA)
let a: Float
func vjpA() -> (Float, (Float) -> Quadratic) {
return (a, { da in Quadratic(da, 0, 0) } )
}

@differentiable(wrt: (self), vjp: vjpB)
let b: Float
func vjpB() -> (Float, (Float) -> Quadratic) {
return (b, { db in Quadratic(0, db, 0) } )
}

@differentiable(wrt: (self), vjp: vjpC)
let c: Float
func vjpC() -> (Float, (Float) -> Quadratic) {
return (c, { dc in Quadratic(0, 0, dc) } )
}

init(_ a: Float, _ b: Float, _ c: Float) {
self.a = a
self.b = b
Expand Down
23 changes: 23 additions & 0 deletions test/AutoDiff/simple_model.swift
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,17 @@ import StdlibUnittest
var SimpleModelTests = TestSuite("SimpleModel")

struct DenseLayer : Equatable {
@differentiable(wrt: (self), vjp: vjpW)
let w: Float
func vjpW() -> (Float, (Float) -> DenseLayer) {
return (w, { dw in DenseLayer(w: dw, b: 0) } )
}

@differentiable(wrt: (self), vjp: vjpB)
let b: Float
func vjpB() -> (Float, (Float) -> DenseLayer) {
return (b, { db in DenseLayer(w: 0, b: db) } )
}
}

extension DenseLayer : Differentiable, VectorNumeric {
Expand Down Expand Up @@ -39,9 +48,23 @@ extension DenseLayer {
}

struct Model : Equatable {
@differentiable(wrt: (self), vjp: vjpL1)
let l1: DenseLayer
func vjpL1() -> (DenseLayer, (DenseLayer) -> Model) {
return (l1, { dl1 in Model(l1: dl1, l2: DenseLayer.zero, l3: DenseLayer.zero) } )
}

@differentiable(wrt: (self), vjp: vjpL2)
let l2: DenseLayer
func vjpL2() -> (DenseLayer, (DenseLayer) -> Model) {
return (l2, { dl2 in Model(l1: DenseLayer.zero, l2: dl2, l3: DenseLayer.zero) } )
}

@differentiable(wrt: (self), vjp: vjpL3)
let l3: DenseLayer
func vjpL3() -> (DenseLayer, (DenseLayer) -> Model) {
return (l3, { dl3 in Model(l1: DenseLayer.zero, l2: DenseLayer.zero, l3: dl3) } )
}
}

extension Model : Differentiable, VectorNumeric {
Expand Down
4 changes: 4 additions & 0 deletions test/AutoDiff/witness_table_silgen.swift
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,11 @@ struct S : Proto, VectorNumeric {
typealias TangentVector = S
typealias CotangentVector = S

@differentiable(wrt: (self), vjp: vjpP)
let p: Float
func vjpP() -> (Float, (Float) -> S) {
return (p, { dp in S(p: dp) })
}

func function1(_ x: Float, _ y: Float) -> Float {
return x + y + p
Expand Down