Skip to content

Commit 0cce4b8

Browse files
rxweidan-zheng
authored andcommitted
[AutoDiff] Use @_fieldwiseDifferentiable on struct decls and deprecate @_fieldwiseProductSpace. (#22009)
* Derived conformances for `Differentiable` puts `@_fieldwiseProductSpace` on either all associated types or none, so it is redundant to keep this attribute on every typealias in a `Differentiable` type. Instead, we move it up to the parent struct decl so that it simplifies both derivation and checking. * Simplify `struct_extract` differentiation logic and improve code style. * Remove legacy `lookUpOrLinkFunction` and have everything go through `DifferentiationTask`.
1 parent afe9307 commit 0cce4b8

11 files changed

+132
-194
lines changed

include/swift/AST/Attr.def

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -395,8 +395,8 @@ SIMPLE_DECL_ATTR(TensorFlowGraph, TensorFlowGraph,
395395
OnFunc, 82)
396396
SIMPLE_DECL_ATTR(TFParameter, TFParameter,
397397
OnVar, 83)
398-
SIMPLE_DECL_ATTR(_fieldwiseProductSpace, FieldwiseProductSpace,
399-
OnTypeAlias | OnNominalType | UserInaccessible, 84)
398+
SIMPLE_DECL_ATTR(_fieldwiseDifferentiable, FieldwiseDifferentiable,
399+
OnNominalType | UserInaccessible, 84)
400400
SIMPLE_DECL_ATTR(noDerivative, NoDerivative,
401401
OnVar, 85)
402402

include/swift/AST/DiagnosticsSema.def

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2742,6 +2742,11 @@ ERROR(noderivative_only_on_stored_properties_in_differentiable_structs,none,
27422742
"@noDerivative is only allowed on stored properties in structure types "
27432743
"that declare a conformance to 'Differentiable'", ())
27442744

2745+
// @_fieldwiseDifferentiable attribute
2746+
ERROR(fieldwise_differentiable_only_on_differentiable_structs,none,
2747+
"@_fieldwiseDifferentiable is only allowed on structure types that "
2748+
"conform to 'Differentiable'", ())
2749+
27452750
//------------------------------------------------------------------------------
27462751
// MARK: Type Check Expressions
27472752
//------------------------------------------------------------------------------
@@ -3655,11 +3660,15 @@ ERROR(unreferenced_generic_parameter,none,
36553660
// SWIFT_ENABLE_TENSORFLOW
36563661
// Function differentiability
36573662
ERROR(autodiff_attr_argument_not_differentiable,none,
3658-
"argument is not differentiable, but the enclosing function type is marked '@autodiff'; did you want to add '@nondiff' to this argument?", ())
3663+
"argument is not differentiable, but the enclosing function type is "
3664+
"marked '@autodiff'; did you want to add '@nondiff' to this argument?",
3665+
())
36593666
ERROR(autodiff_attr_result_not_differentiable,none,
3660-
"result is not differentiable, but the function type is marked '@autodiff'", ())
3667+
"result is not differentiable, but the function type is marked "
3668+
"'@autodiff'", ())
36613669
ERROR(nondiff_attr_invalid_on_nondifferentiable_function,none,
3662-
"'nondiff' cannot be applied to arguments of a non-differentiable function", ())
3670+
"'nondiff' cannot be applied to arguments of a non-differentiable "
3671+
"function", ())
36633672

36643673
// SIL
36653674
ERROR(opened_non_protocol,none,

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 60 additions & 123 deletions
Original file line numberDiff line numberDiff line change
@@ -99,16 +99,6 @@ static void createEntryArguments(SILFunction *f) {
9999
}
100100
}
101101

102-
/// Looks up a function in the current module. If it exists, returns it.
103-
/// Otherwise, attempt to link it from imported modules. Returns null if such
104-
/// function name does not exist.
105-
static SILFunction *lookUpOrLinkFunction(StringRef name, SILModule &module) {
106-
assert(!name.empty());
107-
if (auto *localFn = module.lookUpFunction(name))
108-
return localFn;
109-
return module.findFunction(name, SILLinkage::PublicExternal);
110-
}
111-
112102
/// Computes the correct linkage for functions generated by the AD pass
113103
/// associated with a function with linkage `originalLinkage`.
114104
static SILLinkage getAutoDiffFunctionLinkage(SILLinkage originalLinkage) {
@@ -528,7 +518,7 @@ enum class StructExtractDifferentiationStrategy {
528518
// that is zero except along the direction of the corresponding field.
529519
//
530520
// Fields correspond by matching name.
531-
FieldwiseProductSpace,
521+
Fieldwise,
532522

533523
// Differentiate the `struct_extract` by looking up the corresponding getter
534524
// and using its VJP.
@@ -1291,6 +1281,7 @@ void DifferentiableActivityInfo::analyze(DominanceInfo *di,
12911281
if (isVaried(cai->getSrc(), i))
12921282
recursivelySetVariedIfDifferentiable(cai->getDest(), i);
12931283
}
1284+
// Handle `struct_extract`.
12941285
else if (auto *sei = dyn_cast<StructExtractInst>(&inst)) {
12951286
if (isVaried(sei->getOperand(), i)) {
12961287
auto hasNoDeriv = sei->getField()->getAttrs()
@@ -2091,46 +2082,30 @@ class PrimalGenCloner final : public SILClonerWithScopes<PrimalGenCloner> {
20912082
}
20922083

20932084
void visitStructExtractInst(StructExtractInst *sei) {
2094-
auto &astCtx = getContext().getASTContext();
2095-
auto &structExtractDifferentiationStrategies =
2085+
auto &strategies =
20962086
getDifferentiationTask()->getStructExtractDifferentiationStrategies();
2097-
20982087
// Special handling logic only applies when the `struct_extract` is active.
20992088
// If not, just do standard cloning.
21002089
if (!activityInfo.isActive(sei, synthesis.indices)) {
21012090
LLVM_DEBUG(getADDebugStream() << "Not active:\n" << *sei << '\n');
2102-
structExtractDifferentiationStrategies.insert(
2091+
strategies.insert(
21032092
{sei, StructExtractDifferentiationStrategy::Inactive});
21042093
SILClonerWithScopes::visitStructExtractInst(sei);
21052094
return;
21062095
}
2107-
21082096
// This instruction is active. Determine the appropriate differentiation
21092097
// strategy, and use it.
2110-
2111-
// Use the FieldwiseProductSpace strategy, if appropriate.
21122098
auto *structDecl = sei->getStructDecl();
2113-
auto cotangentDeclLookup =
2114-
structDecl->lookupDirect(astCtx.Id_CotangentVector);
2115-
if (cotangentDeclLookup.size() >= 1) {
2116-
assert(cotangentDeclLookup.size() == 1);
2117-
auto cotangentTypeDecl = cotangentDeclLookup.front();
2118-
assert(isa<TypeAliasDecl>(cotangentTypeDecl) ||
2119-
isa<StructDecl>(cotangentTypeDecl));
2120-
if (cotangentTypeDecl->getAttrs()
2121-
.hasAttribute<FieldwiseProductSpaceAttr>()) {
2122-
structExtractDifferentiationStrategies.insert(
2123-
{sei, StructExtractDifferentiationStrategy::FieldwiseProductSpace});
2124-
SILClonerWithScopes::visitStructExtractInst(sei);
2125-
return;
2126-
}
2099+
if (structDecl->getAttrs().hasAttribute<FieldwiseDifferentiableAttr>()) {
2100+
strategies.insert(
2101+
{sei, StructExtractDifferentiationStrategy::Fieldwise});
2102+
SILClonerWithScopes::visitStructExtractInst(sei);
2103+
return;
21272104
}
2128-
21292105
// The FieldwiseProductSpace strategy is not appropriate, so use the Getter
21302106
// strategy.
2131-
structExtractDifferentiationStrategies.insert(
2107+
strategies.insert(
21322108
{sei, StructExtractDifferentiationStrategy::Getter});
2133-
21342109
// Find the corresponding getter and its VJP.
21352110
auto *getterDecl = sei->getField()->getGetter();
21362111
assert(getterDecl);
@@ -2142,42 +2117,29 @@ class PrimalGenCloner final : public SILClonerWithScopes<PrimalGenCloner> {
21422117
errorOccurred = true;
21432118
return;
21442119
}
2145-
auto getterDiffAttrs = getterFn->getDifferentiableAttrs();
2146-
if (getterDiffAttrs.size() < 1) {
2147-
getContext().emitNondifferentiabilityError(
2148-
sei, synthesis.task, diag::autodiff_property_not_differentiable);
2149-
errorOccurred = true;
2150-
return;
2151-
}
2152-
auto *getterDiffAttr = getterDiffAttrs[0];
2153-
if (!getterDiffAttr->hasVJP()) {
2120+
SILAutoDiffIndices indices(/*source*/ 0, /*parameters*/ {0});
2121+
auto *task = getContext().lookUpDifferentiationTask(getterFn, indices);
2122+
if (!task) {
21542123
getContext().emitNondifferentiabilityError(
21552124
sei, synthesis.task, diag::autodiff_property_not_differentiable);
21562125
errorOccurred = true;
21572126
return;
21582127
}
2159-
assert(getterDiffAttr->getIndices() ==
2160-
SILAutoDiffIndices(/*source*/ 0, /*parameters*/{0}));
2161-
auto *getterVJP = lookUpOrLinkFunction(getterDiffAttr->getVJPName(),
2162-
getContext().getModule());
2163-
21642128
// Reference and apply the VJP.
21652129
auto loc = sei->getLoc();
2166-
auto *getterVJPRef = getBuilder().createFunctionRef(loc, getterVJP);
2130+
auto *getterVJPRef = getBuilder().createFunctionRef(loc, task->getVJP());
21672131
auto *getterVJPApply = getBuilder().createApply(
21682132
loc, getterVJPRef, /*substitutionMap*/ {},
21692133
/*args*/ {getMappedValue(sei->getOperand())}, /*isNonThrowing*/ false);
21702134
SmallVector<SILValue, 8> vjpDirectResults;
21712135
extractAllElements(getterVJPApply, getBuilder(), vjpDirectResults);
2172-
ArrayRef<SILValue> originalDirectResults =
2173-
ArrayRef<SILValue>(vjpDirectResults).drop_back(1);
2174-
21752136
// Map original results.
2137+
auto originalDirectResults =
2138+
ArrayRef<SILValue>(vjpDirectResults).drop_back(1);
21762139
SILValue originalDirectResult = joinElements(originalDirectResults,
21772140
getBuilder(),
21782141
getterVJPApply->getLoc());
21792142
mapValue(sei, originalDirectResult);
2180-
21812143
// Checkpoint the pullback.
21822144
SILValue pullback = vjpDirectResults.back();
21832145
getPrimalInfo().addPullbackDecl(sei, pullback->getType().getASTType());
@@ -3079,60 +3041,41 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
30793041
auto loc = remapLocation(sei->getLoc());
30803042
auto &differentiationStrategies =
30813043
getDifferentiationTask()->getStructExtractDifferentiationStrategies();
3082-
auto differentiationStrategyLookUp = differentiationStrategies.find(sei);
3083-
assert(differentiationStrategyLookUp != differentiationStrategies.end());
3084-
auto differentiationStrategy = differentiationStrategyLookUp->second;
3085-
3086-
if (differentiationStrategy ==
3087-
StructExtractDifferentiationStrategy::Inactive) {
3044+
auto strategy = differentiationStrategies.lookup(sei);
3045+
switch (strategy) {
3046+
case StructExtractDifferentiationStrategy::Inactive:
30883047
assert(!activityInfo.isActive(sei, synthesis.indices));
30893048
return;
3090-
}
3091-
3092-
if (differentiationStrategy ==
3093-
StructExtractDifferentiationStrategy::FieldwiseProductSpace) {
3049+
case StructExtractDifferentiationStrategy::Fieldwise: {
30943050
// Compute adjoint as follows:
30953051
// y = struct_extract <key>, x
30963052
// adj[x] = struct (0, ..., key': adj[y], ..., 0)
30973053
// where `key'` is the field in the cotangent space corresponding to
30983054
// `key`.
3099-
3100-
// Find the decl of the cotangent space type.
31013055
auto structTy = sei->getOperand()->getType().getASTType();
31023056
auto cotangentVectorTy = structTy->getAutoDiffAssociatedVectorSpace(
31033057
AutoDiffAssociatedVectorSpaceKind::Cotangent,
31043058
LookUpConformanceInModule(getModule().getSwiftModule()))
3105-
->getType()->getCanonicalType();
3106-
assert(!getModule()
3107-
.Types.getTypeLowering(cotangentVectorTy)
3108-
.isAddressOnly());
3059+
->getType()->getCanonicalType();
3060+
assert(!getModule().Types.getTypeLowering(cotangentVectorTy)
3061+
.isAddressOnly());
31093062
auto cotangentVectorSILTy =
31103063
SILType::getPrimitiveObjectType(cotangentVectorTy);
31113064
auto *cotangentVectorDecl =
31123065
cotangentVectorTy->getStructOrBoundGenericStruct();
31133066
assert(cotangentVectorDecl);
3114-
31153067
// Find the corresponding field in the cotangent space.
31163068
VarDecl *correspondingField = nullptr;
3069+
// If the cotangent space is the original sapce, then it's the same field.
31173070
if (cotangentVectorDecl == sei->getStructDecl())
31183071
correspondingField = sei->getField();
3072+
// Otherwise we just look it up by name.
31193073
else {
31203074
auto correspondingFieldLookup =
31213075
cotangentVectorDecl->lookupDirect(sei->getField()->getName());
31223076
assert(correspondingFieldLookup.size() == 1);
3123-
assert(isa<VarDecl>(correspondingFieldLookup[0]));
3124-
correspondingField = cast<VarDecl>(correspondingFieldLookup[0]);
3077+
correspondingField = cast<VarDecl>(correspondingFieldLookup.front());
31253078
}
3126-
assert(correspondingField);
3127-
3128-
#ifndef NDEBUG
3129-
unsigned numMatchingStoredProperties = 0;
3130-
for (auto *storedProperty : cotangentVectorDecl->getStoredProperties())
3131-
if (storedProperty == correspondingField)
3132-
numMatchingStoredProperties += 1;
3133-
assert(numMatchingStoredProperties == 1);
3134-
#endif
3135-
31363079
// Compute adjoint.
31373080
auto av = getAdjointValue(sei);
31383081
switch (av.getKind()) {
@@ -3148,44 +3091,41 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
31483091
eltVals.push_back(av);
31493092
else
31503093
eltVals.push_back(AdjointValue::getZero(
3151-
SILType::getPrimitiveObjectType(field->getType()
3152-
->getCanonicalType())));
3094+
SILType::getPrimitiveObjectType(
3095+
field->getType()->getCanonicalType())));
31533096
}
31543097
addAdjointValue(sei->getOperand(),
31553098
AdjointValue::getAggregate(cotangentVectorSILTy,
31563099
eltVals, allocator));
31573100
}
31583101
}
3159-
31603102
return;
31613103
}
3162-
3163-
// The only remaining strategy is the getter strategy.
3164-
// Replace the `struct_extract` with a call to its pullback.
3165-
assert(differentiationStrategy ==
3166-
StructExtractDifferentiationStrategy::Getter);
3167-
3168-
// Get the pullback.
3169-
auto *pullbackField = getPrimalInfo().lookUpPullbackDecl(sei);
3170-
assert(pullbackField);
3171-
SILValue pullback = builder.createStructExtract(loc,
3172-
primalValueAggregateInAdj,
3173-
pullbackField);
3174-
3175-
// Construct the pullback arguments.
3176-
SmallVector<SILValue, 8> args;
3177-
auto seed = getAdjointValue(sei);
3178-
assert(seed.getType().isObject());
3179-
args.push_back(materializeAdjointDirect(seed, loc));
3180-
3181-
// Call the pullback.
3182-
auto *pullbackCall = builder.createApply(loc, pullback, SubstitutionMap(),
3183-
args, /*isNonThrowing*/ false);
3184-
assert(!pullbackCall->hasIndirectResults());
3185-
3186-
// Set adjoint for the `struct_extract` operand.
3187-
addAdjointValue(sei->getOperand(),
3188-
AdjointValue::getMaterialized(pullbackCall));
3104+
case StructExtractDifferentiationStrategy::Getter: {
3105+
// Get the pullback.
3106+
auto *pullbackField = getPrimalInfo().lookUpPullbackDecl(sei);
3107+
assert(pullbackField);
3108+
SILValue pullback = builder.createStructExtract(loc,
3109+
primalValueAggregateInAdj,
3110+
pullbackField);
3111+
3112+
// Construct the pullback arguments.
3113+
SmallVector<SILValue, 8> args;
3114+
auto seed = getAdjointValue(sei);
3115+
assert(seed.getType().isObject());
3116+
args.push_back(materializeAdjointDirect(seed, loc));
3117+
3118+
// Call the pullback.
3119+
auto *pullbackCall = builder.createApply(loc, pullback, SubstitutionMap(),
3120+
args, /*isNonThrowing*/ false);
3121+
assert(!pullbackCall->hasIndirectResults());
3122+
3123+
// Set adjoint for the `struct_extract` operand.
3124+
addAdjointValue(sei->getOperand(),
3125+
AdjointValue::getMaterialized(pullbackCall));
3126+
break;
3127+
}
3128+
}
31893129
}
31903130

31913131
/// Handle `tuple` instruction.
@@ -4236,25 +4176,22 @@ void DifferentiationTask::createVJP() {
42364176
loc, adjointRef, vjpSubstMap, partialAdjointArgs,
42374177
ParameterConvention::Direct_Guaranteed);
42384178

4239-
// === Clean up the stack allocations. ===
4179+
// Clean up the stack allocations.
42404180
for (auto alloc : reversed(stackAllocsToCleanUp))
42414181
builder.createDeallocStack(loc, alloc);
42424182

4243-
// === Return the direct results. ===
4244-
// (Note that indirect results have already been filled in by the application
4245-
// of the primal).
4183+
// Return the direct results. Note that indirect results have already been
4184+
// filled in by the application of the primal.
42464185
SmallVector<SILValue, 8> directResults;
42474186
auto originalDirectResults = ArrayRef<SILValue>(primalDirectResults)
42484187
.take_back(originalConv.getNumDirectSILResults());
42494188
for (auto originalDirectResult : originalDirectResults)
42504189
directResults.push_back(originalDirectResult);
42514190
directResults.push_back(adjointPartialApply);
4252-
if (directResults.size() > 1) {
4253-
auto tupleRet = builder.createTuple(loc, directResults);
4254-
builder.createReturn(loc, tupleRet);
4255-
} else {
4256-
builder.createReturn(loc, directResults[0]);
4257-
}
4191+
if (directResults.size() > 1)
4192+
builder.createReturn(loc, builder.createTuple(loc, directResults));
4193+
else
4194+
builder.createReturn(loc, directResults.front());
42584195
}
42594196

42604197
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)