Skip to content

Commit 02cd429

Browse files
authored
[AutoDiff] automatically handle fieldwise product spaces (#21575)
1 parent 988d1e3 commit 02cd429

File tree

5 files changed

+217
-23
lines changed

5 files changed

+217
-23
lines changed

include/swift/AST/Attr.def

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -394,6 +394,8 @@ SIMPLE_DECL_ATTR(TensorFlowGraph, TensorFlowGraph,
394394
OnFunc, 82)
395395
SIMPLE_DECL_ATTR(TFParameter, TFParameter,
396396
OnVar, 83)
397+
SIMPLE_DECL_ATTR(_fieldwiseProductSpace, FieldwiseProductSpace,
398+
OnTypeAlias | UserInaccessible, 84)
397399

398400
#undef TYPE_ATTR
399401
#undef DECL_ATTR_ALIAS

lib/SILOptimizer/Mandatory/TFDifferentiation.cpp

Lines changed: 147 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -669,6 +669,24 @@ struct NestedApplyActivity {
669669
SILAutoDiffIndices indices;
670670
};
671671

672+
/// Specifies how we should differentiate a `struct_extract` instruction.
673+
enum class StructExtractDifferentiationStrategy {
674+
// The `struct_extract` is not active, so do not differentiate it.
675+
Inactive,
676+
677+
// The `struct_extract` is extracting a field from a Differentiable struct
678+
// with @_fieldwiseProductSpace cotangent space. Therefore, differentiate the
679+
// `struct_extract` by setting the adjoint to a vector in the cotangent space
680+
// that is zero except along the direction of the corresponding field.
681+
//
682+
// Fields correspond by matching name.
683+
FieldwiseProductSpace,
684+
685+
// Differentiate the `struct_extract` by looking up the corresponding getter
686+
// and using its VJP.
687+
Getter
688+
};
689+
672690
/// A differentiation task, specifying the original function and the
673691
/// `[differentiable]` attribute on the function. PrimalGen and AdjointGen
674692
/// will synthesize the primal and the adjoint for this task, filling the primal
@@ -714,6 +732,10 @@ class DifferentiationTask {
714732
/// Note: This is only used when `DifferentiationUseVJP`.
715733
DenseMap<ApplyInst *, NestedApplyActivity> nestedApplyActivities;
716734

735+
/// Mapping from original `struct_extract` instructions to their strategies.
736+
DenseMap<StructExtractInst *, StructExtractDifferentiationStrategy>
737+
structExtractDifferentiationStrategies;
738+
717739
/// Cache for associated functions.
718740
SILFunction *primal = nullptr;
719741
SILFunction *adjoint = nullptr;
@@ -810,6 +832,11 @@ class DifferentiationTask {
810832
return nestedApplyActivities;
811833
}
812834

835+
DenseMap<StructExtractInst *, StructExtractDifferentiationStrategy> &
836+
getStructExtractDifferentiationStrategies() {
837+
return structExtractDifferentiationStrategies;
838+
}
839+
813840
bool isEqual(const DifferentiationTask &other) const {
814841
return original == other.original && attr == other.attr;
815842
}
@@ -2228,16 +2255,42 @@ class PrimalGenCloner final : public SILClonerWithScopes<PrimalGenCloner> {
22282255
}
22292256

22302257
void visitStructExtractInst(StructExtractInst *sei) {
2258+
auto &astCtx = getContext().getASTContext();
2259+
auto &structExtractDifferentiationStrategies =
2260+
getDifferentiationTask()->getStructExtractDifferentiationStrategies();
2261+
22312262
// Special handling logic only applies when the `struct_extract` is active.
22322263
// If not, just do standard cloning.
22332264
if (!activityInfo.isActive(sei, synthesis.indices)) {
22342265
LLVM_DEBUG(getADDebugStream() << "Not active:\n" << *sei << '\n');
2266+
structExtractDifferentiationStrategies.insert(
2267+
{sei, StructExtractDifferentiationStrategy::Inactive});
22352268
SILClonerWithScopes::visitStructExtractInst(sei);
22362269
return;
22372270
}
22382271

2239-
// This instruction is active. Replace it with a call to the corresponding
2240-
// getter's VJP.
2272+
// This instruction is active. Determine the appropriate differentiation
2273+
// strategy, and use it.
2274+
2275+
// Use the FieldwiseProductSpace strategy, if appropriate.
2276+
auto *structDecl = sei->getStructDecl();
2277+
auto aliasLookup = structDecl->lookupDirect(astCtx.Id_CotangentVector);
2278+
if (aliasLookup.size() >= 1) {
2279+
assert(aliasLookup.size() == 1);
2280+
assert(isa<TypeAliasDecl>(aliasLookup[0]));
2281+
auto *aliasDecl = cast<TypeAliasDecl>(aliasLookup[0]);
2282+
if (aliasDecl->getAttrs().hasAttribute<FieldwiseProductSpaceAttr>()) {
2283+
structExtractDifferentiationStrategies.insert(
2284+
{sei, StructExtractDifferentiationStrategy::FieldwiseProductSpace});
2285+
SILClonerWithScopes::visitStructExtractInst(sei);
2286+
return;
2287+
}
2288+
}
2289+
2290+
// The FieldwiseProductSpace strategy is not appropriate, so use the Getter
2291+
// strategy.
2292+
structExtractDifferentiationStrategies.insert(
2293+
{sei, StructExtractDifferentiationStrategy::Getter});
22412294

22422295
// Find the corresponding getter and its VJP.
22432296
auto *getterDecl = sei->getField()->getGetter();
@@ -3596,17 +3649,103 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
35963649
}
35973650

35983651
void visitStructExtractInst(StructExtractInst *sei) {
3599-
// Replace a `struct_extract` with a call to its pullback.
36003652
auto loc = remapLocation(sei->getLoc());
3653+
auto &astCtx = getContext().getASTContext();
36013654

3602-
// Get the pullback.
3603-
auto *pullbackField = getPrimalInfo().lookUpPullbackDecl(sei);
3604-
if (!pullbackField) {
3605-
// Inactive `struct_extract` instructions don't need to be cloned into the
3606-
// adjoint.
3655+
auto &differentiationStrategies =
3656+
getDifferentiationTask()->getStructExtractDifferentiationStrategies();
3657+
auto differentiationStrategyLookUp = differentiationStrategies.find(sei);
3658+
assert(differentiationStrategyLookUp != differentiationStrategies.end());
3659+
auto differentiationStrategy = differentiationStrategyLookUp->second;
3660+
3661+
if (differentiationStrategy ==
3662+
StructExtractDifferentiationStrategy::Inactive) {
36073663
assert(!activityInfo.isActive(sei, synthesis.indices));
36083664
return;
36093665
}
3666+
3667+
if (differentiationStrategy ==
3668+
StructExtractDifferentiationStrategy::FieldwiseProductSpace) {
3669+
// Compute adjoint as follows:
3670+
// y = struct_extract <key>, x
3671+
// adj[x] = struct (0, ..., key': adj[y], ..., 0)
3672+
// where `key'` is the field in the cotangent space corresponding to
3673+
// `key`.
3674+
3675+
// Find the decl of the cotangent space type.
3676+
auto *structDecl = sei->getStructDecl();
3677+
auto aliasLookup = structDecl->lookupDirect(astCtx.Id_CotangentVector);
3678+
assert(aliasLookup.size() == 1);
3679+
assert(isa<TypeAliasDecl>(aliasLookup[0]));
3680+
auto *aliasDecl = cast<TypeAliasDecl>(aliasLookup[0]);
3681+
assert(aliasDecl->getAttrs().hasAttribute<FieldwiseProductSpaceAttr>());
3682+
auto cotangentVectorTy =
3683+
aliasDecl->getUnderlyingTypeLoc().getType()->getCanonicalType();
3684+
assert(!getModule()
3685+
.Types.getTypeLowering(cotangentVectorTy)
3686+
.isAddressOnly());
3687+
auto cotangentVectorSILTy =
3688+
SILType::getPrimitiveObjectType(cotangentVectorTy);
3689+
auto *cotangentVectorDecl =
3690+
cotangentVectorTy->getStructOrBoundGenericStruct();
3691+
assert(cotangentVectorDecl);
3692+
3693+
// Find the corresponding field in the cotangent space.
3694+
VarDecl *correspondingField = nullptr;
3695+
if (cotangentVectorDecl == structDecl)
3696+
correspondingField = sei->getField();
3697+
else {
3698+
auto correspondingFieldLookup =
3699+
cotangentVectorDecl->lookupDirect(sei->getField()->getName());
3700+
assert(correspondingFieldLookup.size() == 1);
3701+
assert(isa<VarDecl>(correspondingFieldLookup[0]));
3702+
correspondingField = cast<VarDecl>(correspondingFieldLookup[0]);
3703+
}
3704+
assert(correspondingField);
3705+
3706+
#ifndef NDEBUG
3707+
unsigned numMatchingStoredProperties = 0;
3708+
for (auto *storedProperty : cotangentVectorDecl->getStoredProperties())
3709+
if (storedProperty == correspondingField)
3710+
numMatchingStoredProperties += 1;
3711+
assert(numMatchingStoredProperties == 1);
3712+
#endif
3713+
3714+
// Compute adjoint.
3715+
auto av = getAdjointValue(sei);
3716+
switch (av.getKind()) {
3717+
case AdjointValue::Kind::Zero:
3718+
addAdjointValue(sei->getOperand(),
3719+
AdjointValue::getZero(cotangentVectorSILTy));
3720+
break;
3721+
case AdjointValue::Kind::Materialized:
3722+
case AdjointValue::Kind::Aggregate: {
3723+
SmallVector<AdjointValue, 8> eltVals;
3724+
for (auto *field : cotangentVectorDecl->getStoredProperties()) {
3725+
if (field == correspondingField)
3726+
eltVals.push_back(av);
3727+
else
3728+
eltVals.push_back(
3729+
AdjointValue::getZero(SILType::getPrimitiveObjectType(
3730+
field->getType()->getCanonicalType())));
3731+
}
3732+
addAdjointValue(sei->getOperand(),
3733+
AdjointValue::getAggregate(cotangentVectorSILTy,
3734+
eltVals, allocator));
3735+
}
3736+
}
3737+
3738+
return;
3739+
}
3740+
3741+
// The only remaining strategy is the getter strategy.
3742+
// Replace the `struct_extract` with a call to its pullback.
3743+
assert(differentiationStrategy ==
3744+
StructExtractDifferentiationStrategy::Getter);
3745+
3746+
// Get the pullback.
3747+
auto *pullbackField = getPrimalInfo().lookUpPullbackDecl(sei);
3748+
assert(pullbackField);
36103749
SILValue pullback = builder.createStructExtract(loc,
36113750
primalValueAggregateInAdj,
36123751
pullbackField);

lib/Sema/TypeCheckAttr.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ class AttributeEarlyChecker : public AttributeVisitor<AttributeEarlyChecker> {
126126
IGNORED_ATTR(CompilerEvaluable)
127127
IGNORED_ATTR(TensorFlowGraph)
128128
IGNORED_ATTR(TFParameter)
129+
IGNORED_ATTR(FieldwiseProductSpace)
129130
#undef IGNORED_ATTR
130131

131132
// @noreturn has been replaced with a 'Never' return type.
@@ -884,6 +885,7 @@ class AttributeChecker : public AttributeVisitor<AttributeChecker> {
884885
void visitCompilerEvaluableAttr(CompilerEvaluableAttr *attr);
885886
void visitTensorFlowGraphAttr(TensorFlowGraphAttr *attr);
886887
void visitTFParameterAttr(TFParameterAttr *attr);
888+
void visitFieldwiseProductSpaceAttr(FieldwiseProductSpaceAttr *attr);
887889
};
888890
} // end anonymous namespace
889891

@@ -2705,6 +2707,19 @@ void AttributeChecker::visitTFParameterAttr(TFParameterAttr *attr) {
27052707
}
27062708
}
27072709

2710+
void AttributeChecker::visitFieldwiseProductSpaceAttr(
2711+
FieldwiseProductSpaceAttr *attr) {
2712+
// If we make this attribute user-facing, we'll need to do various checks.
2713+
// - check that this attribute is on a Tangent/Cotangent type alias
2714+
// - check that we can access the raw fields of the Tangent/Cotangent from
2715+
// this module (e.g. the Tangent can't be a public resilient struct
2716+
// defined in a different module).
2717+
// - check that the stored properties of the Tangent/Cotangent match
2718+
//
2719+
// If we don't make this attribute user-facing, we can avoid doing checks
2720+
// here: the assertions in TFDifferentiation suffice.
2721+
}
2722+
27082723
void TypeChecker::checkDeclAttributes(Decl *D) {
27092724
AttributeChecker Checker(*this, D);
27102725

lib/Sema/TypeCheckDeclOverride.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1218,6 +1218,7 @@ namespace {
12181218
UNINTERESTING_ATTR(CompilerEvaluable)
12191219
UNINTERESTING_ATTR(TensorFlowGraph)
12201220
UNINTERESTING_ATTR(TFParameter)
1221+
UNINTERESTING_ATTR(FieldwiseProductSpace)
12211222

12221223
// These can't appear on overridable declarations.
12231224
UNINTERESTING_ATTR(Prefix)

test/AutoDiff/e2e_differentiable_property.swift

Lines changed: 52 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -8,26 +8,13 @@ import StdlibUnittest
88

99
var E2EDifferentiablePropertyTests = TestSuite("E2EDifferentiableProperty")
1010

11-
struct TangentSpace {
11+
struct TangentSpace : VectorNumeric {
1212
let dx, dy: Float
1313
}
1414

15-
extension TangentSpace : Differentiable, VectorNumeric {
15+
extension TangentSpace : Differentiable {
1616
typealias TangentVector = TangentSpace
1717
typealias CotangentVector = TangentSpace
18-
typealias Scalar = Float
19-
static var zero: TangentSpace {
20-
return TangentSpace(dx: 0, dy: 0)
21-
}
22-
static func + (lhs: TangentSpace, rhs: TangentSpace) -> TangentSpace {
23-
return TangentSpace(dx: lhs.dx + rhs.dx, dy: lhs.dy + rhs.dy)
24-
}
25-
static func - (lhs: TangentSpace, rhs: TangentSpace) -> TangentSpace {
26-
return TangentSpace(dx: lhs.dx - rhs.dx, dy: lhs.dy - rhs.dy)
27-
}
28-
static func * (lhs: Float, rhs: TangentSpace) -> TangentSpace {
29-
return TangentSpace(dx: lhs * rhs.dx, dy: lhs * rhs.dy)
30-
}
3118
}
3219

3320
struct Space {
@@ -83,4 +70,54 @@ E2EDifferentiablePropertyTests.test("stored property") {
8370
expectEqual(expectedGrad, actualGrad)
8471
}
8572

73+
struct ProductSpaceSelfTangent : VectorNumeric {
74+
let x, y: Float
75+
}
76+
77+
extension ProductSpaceSelfTangent : Differentiable {
78+
@_fieldwiseProductSpace
79+
typealias TangentVector = ProductSpaceSelfTangent
80+
@_fieldwiseProductSpace
81+
typealias CotangentVector = ProductSpaceSelfTangent
82+
}
83+
84+
E2EDifferentiablePropertyTests.test("fieldwise product space, self tangent") {
85+
let actualGrad = gradient(at: ProductSpaceSelfTangent(x: 0, y: 0)) { (point: ProductSpaceSelfTangent) -> Float in
86+
return 5 * point.y
87+
}
88+
let expectedGrad = ProductSpaceSelfTangent(x: 0, y: 5)
89+
expectEqual(expectedGrad, actualGrad)
90+
}
91+
92+
struct ProductSpaceOtherTangentTangentSpace : VectorNumeric {
93+
let x, y: Float
94+
}
95+
96+
extension ProductSpaceOtherTangentTangentSpace : Differentiable {
97+
typealias TangentVector = ProductSpaceOtherTangentTangentSpace
98+
typealias CotangentVector = ProductSpaceOtherTangentTangentSpace
99+
}
100+
101+
struct ProductSpaceOtherTangent {
102+
let x, y: Float
103+
}
104+
105+
extension ProductSpaceOtherTangent : Differentiable {
106+
@_fieldwiseProductSpace
107+
typealias TangentVector = ProductSpaceOtherTangentTangentSpace
108+
@_fieldwiseProductSpace
109+
typealias CotangentVector = ProductSpaceOtherTangentTangentSpace
110+
func moved(along: ProductSpaceOtherTangentTangentSpace) -> ProductSpaceOtherTangent {
111+
return ProductSpaceOtherTangent(x: x + along.x, y: y + along.y)
112+
}
113+
}
114+
115+
E2EDifferentiablePropertyTests.test("fieldwise product space, other tangent") {
116+
let actualGrad = gradient(at: ProductSpaceOtherTangent(x: 0, y: 0)) { (point: ProductSpaceOtherTangent) -> Float in
117+
return 7 * point.y
118+
}
119+
let expectedGrad = ProductSpaceOtherTangentTangentSpace(x: 0, y: 7)
120+
expectEqual(expectedGrad, actualGrad)
121+
}
122+
86123
runAllTests()

0 commit comments

Comments
 (0)