Skip to content

[AutoDiff] automatically handle fieldwise product spaces #21575

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Dec 31, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions include/swift/AST/Attr.def
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,8 @@ SIMPLE_DECL_ATTR(TensorFlowGraph, TensorFlowGraph,
OnFunc, 82)
SIMPLE_DECL_ATTR(TFParameter, TFParameter,
OnVar, 83)
SIMPLE_DECL_ATTR(_fieldwiseProductSpace, FieldwiseProductSpace,
OnTypeAlias | UserInaccessible, 84)

#undef TYPE_ATTR
#undef DECL_ATTR_ALIAS
Expand Down
155 changes: 147 additions & 8 deletions lib/SILOptimizer/Mandatory/TFDifferentiation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -669,6 +669,24 @@ struct NestedApplyActivity {
SILAutoDiffIndices indices;
};

/// Specifies how we should differentiate a `struct_extract` instruction.
enum class StructExtractDifferentiationStrategy {
// The `struct_extract` is not active, so do not differentiate it.
Inactive,

// The `struct_extract` is extracting a field from a Differentiable struct
// with @_fieldwiseProductSpace cotangent space. Therefore, differentiate the
// `struct_extract` by setting the adjoint to a vector in the cotangent space
// that is zero except along the direction of the corresponding field.
//
// Fields correspond by matching name.
FieldwiseProductSpace,

// Differentiate the `struct_extract` by looking up the corresponding getter
// and using its VJP.
Getter
};

/// A differentiation task, specifying the original function and the
/// `[differentiable]` attribute on the function. PrimalGen and AdjointGen
/// will synthesize the primal and the adjoint for this task, filling the primal
Expand Down Expand Up @@ -714,6 +732,10 @@ class DifferentiationTask {
/// Note: This is only used when `DifferentiationUseVJP`.
DenseMap<ApplyInst *, NestedApplyActivity> nestedApplyActivities;

/// Mapping from original `struct_extract` instructions to their strategies.
DenseMap<StructExtractInst *, StructExtractDifferentiationStrategy>
structExtractDifferentiationStrategies;

/// Cache for associated functions.
SILFunction *primal = nullptr;
SILFunction *adjoint = nullptr;
Expand Down Expand Up @@ -810,6 +832,11 @@ class DifferentiationTask {
return nestedApplyActivities;
}

DenseMap<StructExtractInst *, StructExtractDifferentiationStrategy> &
getStructExtractDifferentiationStrategies() {
return structExtractDifferentiationStrategies;
}

bool isEqual(const DifferentiationTask &other) const {
return original == other.original && attr == other.attr;
}
Expand Down Expand Up @@ -2228,16 +2255,42 @@ class PrimalGenCloner final : public SILClonerWithScopes<PrimalGenCloner> {
}

void visitStructExtractInst(StructExtractInst *sei) {
auto &astCtx = getContext().getASTContext();
auto &structExtractDifferentiationStrategies =
getDifferentiationTask()->getStructExtractDifferentiationStrategies();

// Special handling logic only applies when the `struct_extract` is active.
// If not, just do standard cloning.
if (!activityInfo.isActive(sei, synthesis.indices)) {
LLVM_DEBUG(getADDebugStream() << "Not active:\n" << *sei << '\n');
structExtractDifferentiationStrategies.insert(
{sei, StructExtractDifferentiationStrategy::Inactive});
SILClonerWithScopes::visitStructExtractInst(sei);
return;
}

// This instruction is active. Replace it with a call to the corresponding
// getter's VJP.
// This instruction is active. Determine the appropriate differentiation
// strategy, and use it.

// Use the FieldwiseProductSpace strategy, if appropriate.
auto *structDecl = sei->getStructDecl();
auto aliasLookup = structDecl->lookupDirect(astCtx.Id_CotangentVector);
if (aliasLookup.size() >= 1) {
assert(aliasLookup.size() == 1);
assert(isa<TypeAliasDecl>(aliasLookup[0]));
auto *aliasDecl = cast<TypeAliasDecl>(aliasLookup[0]);
if (aliasDecl->getAttrs().hasAttribute<FieldwiseProductSpaceAttr>()) {
structExtractDifferentiationStrategies.insert(
{sei, StructExtractDifferentiationStrategy::FieldwiseProductSpace});
SILClonerWithScopes::visitStructExtractInst(sei);
return;
}
}

// The FieldwiseProductSpace strategy is not appropriate, so use the Getter
// strategy.
structExtractDifferentiationStrategies.insert(
{sei, StructExtractDifferentiationStrategy::Getter});

// Find the corresponding getter and its VJP.
auto *getterDecl = sei->getField()->getGetter();
Expand Down Expand Up @@ -3596,17 +3649,103 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
}

void visitStructExtractInst(StructExtractInst *sei) {
// Replace a `struct_extract` with a call to its pullback.
auto loc = remapLocation(sei->getLoc());
auto &astCtx = getContext().getASTContext();

// Get the pullback.
auto *pullbackField = getPrimalInfo().lookUpPullbackDecl(sei);
if (!pullbackField) {
// Inactive `struct_extract` instructions don't need to be cloned into the
// adjoint.
auto &differentiationStrategies =
getDifferentiationTask()->getStructExtractDifferentiationStrategies();
auto differentiationStrategyLookUp = differentiationStrategies.find(sei);
assert(differentiationStrategyLookUp != differentiationStrategies.end());
auto differentiationStrategy = differentiationStrategyLookUp->second;

if (differentiationStrategy ==
StructExtractDifferentiationStrategy::Inactive) {
assert(!activityInfo.isActive(sei, synthesis.indices));
return;
}

if (differentiationStrategy ==
StructExtractDifferentiationStrategy::FieldwiseProductSpace) {
// Compute adjoint as follows:
// y = struct_extract <key>, x
// adj[x] = struct (0, ..., key': adj[y], ..., 0)
// where `key'` is the field in the cotangent space corresponding to
// `key`.

// Find the decl of the cotangent space type.
auto *structDecl = sei->getStructDecl();
auto aliasLookup = structDecl->lookupDirect(astCtx.Id_CotangentVector);
assert(aliasLookup.size() == 1);
assert(isa<TypeAliasDecl>(aliasLookup[0]));
auto *aliasDecl = cast<TypeAliasDecl>(aliasLookup[0]);
assert(aliasDecl->getAttrs().hasAttribute<FieldwiseProductSpaceAttr>());
auto cotangentVectorTy =
aliasDecl->getUnderlyingTypeLoc().getType()->getCanonicalType();
assert(!getModule()
.Types.getTypeLowering(cotangentVectorTy)
.isAddressOnly());
auto cotangentVectorSILTy =
SILType::getPrimitiveObjectType(cotangentVectorTy);
auto *cotangentVectorDecl =
cotangentVectorTy->getStructOrBoundGenericStruct();
assert(cotangentVectorDecl);

// Find the corresponding field in the cotangent space.
VarDecl *correspondingField = nullptr;
if (cotangentVectorDecl == structDecl)
correspondingField = sei->getField();
else {
auto correspondingFieldLookup =
cotangentVectorDecl->lookupDirect(sei->getField()->getName());
assert(correspondingFieldLookup.size() == 1);
assert(isa<VarDecl>(correspondingFieldLookup[0]));
correspondingField = cast<VarDecl>(correspondingFieldLookup[0]);
}
assert(correspondingField);

#ifndef NDEBUG
unsigned numMatchingStoredProperties = 0;
for (auto *storedProperty : cotangentVectorDecl->getStoredProperties())
if (storedProperty == correspondingField)
numMatchingStoredProperties += 1;
assert(numMatchingStoredProperties == 1);
#endif

// Compute adjoint.
auto av = getAdjointValue(sei);
switch (av.getKind()) {
case AdjointValue::Kind::Zero:
addAdjointValue(sei->getOperand(),
AdjointValue::getZero(cotangentVectorSILTy));
break;
case AdjointValue::Kind::Materialized:
case AdjointValue::Kind::Aggregate: {
SmallVector<AdjointValue, 8> eltVals;
for (auto *field : cotangentVectorDecl->getStoredProperties()) {
if (field == correspondingField)
eltVals.push_back(av);
else
eltVals.push_back(
AdjointValue::getZero(SILType::getPrimitiveObjectType(
field->getType()->getCanonicalType())));
}
addAdjointValue(sei->getOperand(),
AdjointValue::getAggregate(cotangentVectorSILTy,
eltVals, allocator));
}
}

return;
}

// The only remaining strategy is the getter strategy.
// Replace the `struct_extract` with a call to its pullback.
assert(differentiationStrategy ==
StructExtractDifferentiationStrategy::Getter);

// Get the pullback.
auto *pullbackField = getPrimalInfo().lookUpPullbackDecl(sei);
assert(pullbackField);
SILValue pullback = builder.createStructExtract(loc,
primalValueAggregateInAdj,
pullbackField);
Expand Down
15 changes: 15 additions & 0 deletions lib/Sema/TypeCheckAttr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ class AttributeEarlyChecker : public AttributeVisitor<AttributeEarlyChecker> {
IGNORED_ATTR(CompilerEvaluable)
IGNORED_ATTR(TensorFlowGraph)
IGNORED_ATTR(TFParameter)
IGNORED_ATTR(FieldwiseProductSpace)
#undef IGNORED_ATTR

// @noreturn has been replaced with a 'Never' return type.
Expand Down Expand Up @@ -884,6 +885,7 @@ class AttributeChecker : public AttributeVisitor<AttributeChecker> {
void visitCompilerEvaluableAttr(CompilerEvaluableAttr *attr);
void visitTensorFlowGraphAttr(TensorFlowGraphAttr *attr);
void visitTFParameterAttr(TFParameterAttr *attr);
void visitFieldwiseProductSpaceAttr(FieldwiseProductSpaceAttr *attr);
};
} // end anonymous namespace

Expand Down Expand Up @@ -2705,6 +2707,19 @@ void AttributeChecker::visitTFParameterAttr(TFParameterAttr *attr) {
}
}

void AttributeChecker::visitFieldwiseProductSpaceAttr(
FieldwiseProductSpaceAttr *attr) {
// If we make this attribute user-facing, we'll need to do various checks.
// - check that this attribute is on a Tangent/Cotangent type alias
// - check that we can access the raw fields of the Tangent/Cotangent from
// this module (e.g. the Tangent can't be a public resilient struct
// defined in a different module).
// - check that the stored properties of the Tangent/Cotangent match
//
// If we don't make this attribute user-facing, we can avoid doing checks
// here: the assertions in TFDifferentiation suffice.
}

void TypeChecker::checkDeclAttributes(Decl *D) {
AttributeChecker Checker(*this, D);

Expand Down
1 change: 1 addition & 0 deletions lib/Sema/TypeCheckDeclOverride.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1218,6 +1218,7 @@ namespace {
UNINTERESTING_ATTR(CompilerEvaluable)
UNINTERESTING_ATTR(TensorFlowGraph)
UNINTERESTING_ATTR(TFParameter)
UNINTERESTING_ATTR(FieldwiseProductSpace)

// These can't appear on overridable declarations.
UNINTERESTING_ATTR(Prefix)
Expand Down
67 changes: 52 additions & 15 deletions test/AutoDiff/e2e_differentiable_property.swift
Original file line number Diff line number Diff line change
Expand Up @@ -8,26 +8,13 @@ import StdlibUnittest

var E2EDifferentiablePropertyTests = TestSuite("E2EDifferentiableProperty")

struct TangentSpace {
struct TangentSpace : VectorNumeric {
let dx, dy: Float
}

extension TangentSpace : Differentiable, VectorNumeric {
extension TangentSpace : Differentiable {
typealias TangentVector = TangentSpace
typealias CotangentVector = TangentSpace
typealias Scalar = Float
static var zero: TangentSpace {
return TangentSpace(dx: 0, dy: 0)
}
static func + (lhs: TangentSpace, rhs: TangentSpace) -> TangentSpace {
return TangentSpace(dx: lhs.dx + rhs.dx, dy: lhs.dy + rhs.dy)
}
static func - (lhs: TangentSpace, rhs: TangentSpace) -> TangentSpace {
return TangentSpace(dx: lhs.dx - rhs.dx, dy: lhs.dy - rhs.dy)
}
static func * (lhs: Float, rhs: TangentSpace) -> TangentSpace {
return TangentSpace(dx: lhs * rhs.dx, dy: lhs * rhs.dy)
}
}

struct Space {
Expand Down Expand Up @@ -83,4 +70,54 @@ E2EDifferentiablePropertyTests.test("stored property") {
expectEqual(expectedGrad, actualGrad)
}

struct ProductSpaceSelfTangent : VectorNumeric {
let x, y: Float
}

extension ProductSpaceSelfTangent : Differentiable {
@_fieldwiseProductSpace
typealias TangentVector = ProductSpaceSelfTangent
@_fieldwiseProductSpace
typealias CotangentVector = ProductSpaceSelfTangent
}

E2EDifferentiablePropertyTests.test("fieldwise product space, self tangent") {
let actualGrad = gradient(at: ProductSpaceSelfTangent(x: 0, y: 0)) { (point: ProductSpaceSelfTangent) -> Float in
return 5 * point.y
}
let expectedGrad = ProductSpaceSelfTangent(x: 0, y: 5)
expectEqual(expectedGrad, actualGrad)
}

struct ProductSpaceOtherTangentTangentSpace : VectorNumeric {
let x, y: Float
}

extension ProductSpaceOtherTangentTangentSpace : Differentiable {
typealias TangentVector = ProductSpaceOtherTangentTangentSpace
typealias CotangentVector = ProductSpaceOtherTangentTangentSpace
}

struct ProductSpaceOtherTangent {
let x, y: Float
}

extension ProductSpaceOtherTangent : Differentiable {
@_fieldwiseProductSpace
typealias TangentVector = ProductSpaceOtherTangentTangentSpace
@_fieldwiseProductSpace
typealias CotangentVector = ProductSpaceOtherTangentTangentSpace
func moved(along: ProductSpaceOtherTangentTangentSpace) -> ProductSpaceOtherTangent {
return ProductSpaceOtherTangent(x: x + along.x, y: y + along.y)
}
}

E2EDifferentiablePropertyTests.test("fieldwise product space, other tangent") {
let actualGrad = gradient(at: ProductSpaceOtherTangent(x: 0, y: 0)) { (point: ProductSpaceOtherTangent) -> Float in
return 7 * point.y
}
let expectedGrad = ProductSpaceOtherTangentTangentSpace(x: 0, y: 7)
expectEqual(expectedGrad, actualGrad)
}

runAllTests()