@@ -99,16 +99,6 @@ static void createEntryArguments(SILFunction *f) {
99
99
}
100
100
}
101
101
102
- // / Looks up a function in the current module. If it exists, returns it.
103
- // / Otherwise, attempt to link it from imported modules. Returns null if such
104
- // / function name does not exist.
105
- static SILFunction *lookUpOrLinkFunction (StringRef name, SILModule &module ) {
106
- assert (!name.empty ());
107
- if (auto *localFn = module .lookUpFunction (name))
108
- return localFn;
109
- return module .findFunction (name, SILLinkage::PublicExternal);
110
- }
111
-
112
102
// / Computes the correct linkage for functions generated by the AD pass
113
103
// / associated with a function with linkage `originalLinkage`.
114
104
static SILLinkage getAutoDiffFunctionLinkage (SILLinkage originalLinkage) {
@@ -528,7 +518,7 @@ enum class StructExtractDifferentiationStrategy {
528
518
// that is zero except along the direction of the corresponding field.
529
519
//
530
520
// Fields correspond by matching name.
531
- FieldwiseProductSpace ,
521
+ Fieldwise ,
532
522
533
523
// Differentiate the `struct_extract` by looking up the corresponding getter
534
524
// and using its VJP.
@@ -1291,6 +1281,7 @@ void DifferentiableActivityInfo::analyze(DominanceInfo *di,
1291
1281
if (isVaried (cai->getSrc (), i))
1292
1282
recursivelySetVariedIfDifferentiable (cai->getDest (), i);
1293
1283
}
1284
+ // Handle `struct_extract`.
1294
1285
else if (auto *sei = dyn_cast<StructExtractInst>(&inst)) {
1295
1286
if (isVaried (sei->getOperand (), i)) {
1296
1287
auto hasNoDeriv = sei->getField ()->getAttrs ()
@@ -2091,46 +2082,30 @@ class PrimalGenCloner final : public SILClonerWithScopes<PrimalGenCloner> {
2091
2082
}
2092
2083
2093
2084
void visitStructExtractInst (StructExtractInst *sei) {
2094
- auto &astCtx = getContext ().getASTContext ();
2095
- auto &structExtractDifferentiationStrategies =
2085
+ auto &strategies =
2096
2086
getDifferentiationTask ()->getStructExtractDifferentiationStrategies ();
2097
-
2098
2087
// Special handling logic only applies when the `struct_extract` is active.
2099
2088
// If not, just do standard cloning.
2100
2089
if (!activityInfo.isActive (sei, synthesis.indices )) {
2101
2090
LLVM_DEBUG (getADDebugStream () << " Not active:\n " << *sei << ' \n ' );
2102
- structExtractDifferentiationStrategies .insert (
2091
+ strategies .insert (
2103
2092
{sei, StructExtractDifferentiationStrategy::Inactive});
2104
2093
SILClonerWithScopes::visitStructExtractInst (sei);
2105
2094
return ;
2106
2095
}
2107
-
2108
2096
// This instruction is active. Determine the appropriate differentiation
2109
2097
// strategy, and use it.
2110
-
2111
- // Use the FieldwiseProductSpace strategy, if appropriate.
2112
2098
auto *structDecl = sei->getStructDecl ();
2113
- auto cotangentDeclLookup =
2114
- structDecl->lookupDirect (astCtx.Id_CotangentVector );
2115
- if (cotangentDeclLookup.size () >= 1 ) {
2116
- assert (cotangentDeclLookup.size () == 1 );
2117
- auto cotangentTypeDecl = cotangentDeclLookup.front ();
2118
- assert (isa<TypeAliasDecl>(cotangentTypeDecl) ||
2119
- isa<StructDecl>(cotangentTypeDecl));
2120
- if (cotangentTypeDecl->getAttrs ()
2121
- .hasAttribute <FieldwiseProductSpaceAttr>()) {
2122
- structExtractDifferentiationStrategies.insert (
2123
- {sei, StructExtractDifferentiationStrategy::FieldwiseProductSpace});
2124
- SILClonerWithScopes::visitStructExtractInst (sei);
2125
- return ;
2126
- }
2099
+ if (structDecl->getAttrs ().hasAttribute <FieldwiseDifferentiableAttr>()) {
2100
+ strategies.insert (
2101
+ {sei, StructExtractDifferentiationStrategy::Fieldwise});
2102
+ SILClonerWithScopes::visitStructExtractInst (sei);
2103
+ return ;
2127
2104
}
2128
-
2129
2105
// The FieldwiseProductSpace strategy is not appropriate, so use the Getter
2130
2106
// strategy.
2131
- structExtractDifferentiationStrategies .insert (
2107
+ strategies .insert (
2132
2108
{sei, StructExtractDifferentiationStrategy::Getter});
2133
-
2134
2109
// Find the corresponding getter and its VJP.
2135
2110
auto *getterDecl = sei->getField ()->getGetter ();
2136
2111
assert (getterDecl);
@@ -2142,42 +2117,29 @@ class PrimalGenCloner final : public SILClonerWithScopes<PrimalGenCloner> {
2142
2117
errorOccurred = true ;
2143
2118
return ;
2144
2119
}
2145
- auto getterDiffAttrs = getterFn->getDifferentiableAttrs ();
2146
- if (getterDiffAttrs.size () < 1 ) {
2147
- getContext ().emitNondifferentiabilityError (
2148
- sei, synthesis.task , diag::autodiff_property_not_differentiable);
2149
- errorOccurred = true ;
2150
- return ;
2151
- }
2152
- auto *getterDiffAttr = getterDiffAttrs[0 ];
2153
- if (!getterDiffAttr->hasVJP ()) {
2120
+ SILAutoDiffIndices indices (/* source*/ 0 , /* parameters*/ {0 });
2121
+ auto *task = getContext ().lookUpDifferentiationTask (getterFn, indices);
2122
+ if (!task) {
2154
2123
getContext ().emitNondifferentiabilityError (
2155
2124
sei, synthesis.task , diag::autodiff_property_not_differentiable);
2156
2125
errorOccurred = true ;
2157
2126
return ;
2158
2127
}
2159
- assert (getterDiffAttr->getIndices () ==
2160
- SILAutoDiffIndices (/* source*/ 0 , /* parameters*/ {0 }));
2161
- auto *getterVJP = lookUpOrLinkFunction (getterDiffAttr->getVJPName (),
2162
- getContext ().getModule ());
2163
-
2164
2128
// Reference and apply the VJP.
2165
2129
auto loc = sei->getLoc ();
2166
- auto *getterVJPRef = getBuilder ().createFunctionRef (loc, getterVJP );
2130
+ auto *getterVJPRef = getBuilder ().createFunctionRef (loc, task-> getVJP () );
2167
2131
auto *getterVJPApply = getBuilder ().createApply (
2168
2132
loc, getterVJPRef, /* substitutionMap*/ {},
2169
2133
/* args*/ {getMappedValue (sei->getOperand ())}, /* isNonThrowing*/ false );
2170
2134
SmallVector<SILValue, 8 > vjpDirectResults;
2171
2135
extractAllElements (getterVJPApply, getBuilder (), vjpDirectResults);
2172
- ArrayRef<SILValue> originalDirectResults =
2173
- ArrayRef<SILValue>(vjpDirectResults).drop_back (1 );
2174
-
2175
2136
// Map original results.
2137
+ auto originalDirectResults =
2138
+ ArrayRef<SILValue>(vjpDirectResults).drop_back (1 );
2176
2139
SILValue originalDirectResult = joinElements (originalDirectResults,
2177
2140
getBuilder (),
2178
2141
getterVJPApply->getLoc ());
2179
2142
mapValue (sei, originalDirectResult);
2180
-
2181
2143
// Checkpoint the pullback.
2182
2144
SILValue pullback = vjpDirectResults.back ();
2183
2145
getPrimalInfo ().addPullbackDecl (sei, pullback->getType ().getASTType ());
@@ -3079,60 +3041,41 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
3079
3041
auto loc = remapLocation (sei->getLoc ());
3080
3042
auto &differentiationStrategies =
3081
3043
getDifferentiationTask ()->getStructExtractDifferentiationStrategies ();
3082
- auto differentiationStrategyLookUp = differentiationStrategies.find (sei);
3083
- assert (differentiationStrategyLookUp != differentiationStrategies.end ());
3084
- auto differentiationStrategy = differentiationStrategyLookUp->second ;
3085
-
3086
- if (differentiationStrategy ==
3087
- StructExtractDifferentiationStrategy::Inactive) {
3044
+ auto strategy = differentiationStrategies.lookup (sei);
3045
+ switch (strategy) {
3046
+ case StructExtractDifferentiationStrategy::Inactive:
3088
3047
assert (!activityInfo.isActive (sei, synthesis.indices ));
3089
3048
return ;
3090
- }
3091
-
3092
- if (differentiationStrategy ==
3093
- StructExtractDifferentiationStrategy::FieldwiseProductSpace) {
3049
+ case StructExtractDifferentiationStrategy::Fieldwise: {
3094
3050
// Compute adjoint as follows:
3095
3051
// y = struct_extract <key>, x
3096
3052
// adj[x] = struct (0, ..., key': adj[y], ..., 0)
3097
3053
// where `key'` is the field in the cotangent space corresponding to
3098
3054
// `key`.
3099
-
3100
- // Find the decl of the cotangent space type.
3101
3055
auto structTy = sei->getOperand ()->getType ().getASTType ();
3102
3056
auto cotangentVectorTy = structTy->getAutoDiffAssociatedVectorSpace (
3103
3057
AutoDiffAssociatedVectorSpaceKind::Cotangent,
3104
3058
LookUpConformanceInModule (getModule ().getSwiftModule ()))
3105
- ->getType ()->getCanonicalType ();
3106
- assert (!getModule ()
3107
- .Types .getTypeLowering (cotangentVectorTy)
3108
- .isAddressOnly ());
3059
+ ->getType ()->getCanonicalType ();
3060
+ assert (!getModule ().Types .getTypeLowering (cotangentVectorTy)
3061
+ .isAddressOnly ());
3109
3062
auto cotangentVectorSILTy =
3110
3063
SILType::getPrimitiveObjectType (cotangentVectorTy);
3111
3064
auto *cotangentVectorDecl =
3112
3065
cotangentVectorTy->getStructOrBoundGenericStruct ();
3113
3066
assert (cotangentVectorDecl);
3114
-
3115
3067
// Find the corresponding field in the cotangent space.
3116
3068
VarDecl *correspondingField = nullptr ;
3069
+ // If the cotangent space is the original sapce, then it's the same field.
3117
3070
if (cotangentVectorDecl == sei->getStructDecl ())
3118
3071
correspondingField = sei->getField ();
3072
+ // Otherwise we just look it up by name.
3119
3073
else {
3120
3074
auto correspondingFieldLookup =
3121
3075
cotangentVectorDecl->lookupDirect (sei->getField ()->getName ());
3122
3076
assert (correspondingFieldLookup.size () == 1 );
3123
- assert (isa<VarDecl>(correspondingFieldLookup[0 ]));
3124
- correspondingField = cast<VarDecl>(correspondingFieldLookup[0 ]);
3077
+ correspondingField = cast<VarDecl>(correspondingFieldLookup.front ());
3125
3078
}
3126
- assert (correspondingField);
3127
-
3128
- #ifndef NDEBUG
3129
- unsigned numMatchingStoredProperties = 0 ;
3130
- for (auto *storedProperty : cotangentVectorDecl->getStoredProperties ())
3131
- if (storedProperty == correspondingField)
3132
- numMatchingStoredProperties += 1 ;
3133
- assert (numMatchingStoredProperties == 1 );
3134
- #endif
3135
-
3136
3079
// Compute adjoint.
3137
3080
auto av = getAdjointValue (sei);
3138
3081
switch (av.getKind ()) {
@@ -3148,44 +3091,41 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
3148
3091
eltVals.push_back (av);
3149
3092
else
3150
3093
eltVals.push_back (AdjointValue::getZero (
3151
- SILType::getPrimitiveObjectType (field-> getType ()
3152
- ->getCanonicalType ())));
3094
+ SILType::getPrimitiveObjectType (
3095
+ field-> getType () ->getCanonicalType ())));
3153
3096
}
3154
3097
addAdjointValue (sei->getOperand (),
3155
3098
AdjointValue::getAggregate (cotangentVectorSILTy,
3156
3099
eltVals, allocator));
3157
3100
}
3158
3101
}
3159
-
3160
3102
return ;
3161
3103
}
3162
-
3163
- // The only remaining strategy is the getter strategy.
3164
- // Replace the `struct_extract` with a call to its pullback.
3165
- assert (differentiationStrategy ==
3166
- StructExtractDifferentiationStrategy::Getter);
3167
-
3168
- // Get the pullback.
3169
- auto *pullbackField = getPrimalInfo ().lookUpPullbackDecl (sei);
3170
- assert (pullbackField);
3171
- SILValue pullback = builder.createStructExtract (loc,
3172
- primalValueAggregateInAdj,
3173
- pullbackField);
3174
-
3175
- // Construct the pullback arguments.
3176
- SmallVector<SILValue, 8 > args;
3177
- auto seed = getAdjointValue (sei);
3178
- assert (seed.getType ().isObject ());
3179
- args.push_back (materializeAdjointDirect (seed, loc));
3180
-
3181
- // Call the pullback.
3182
- auto *pullbackCall = builder.createApply (loc, pullback, SubstitutionMap (),
3183
- args, /* isNonThrowing*/ false );
3184
- assert (!pullbackCall->hasIndirectResults ());
3185
-
3186
- // Set adjoint for the `struct_extract` operand.
3187
- addAdjointValue (sei->getOperand (),
3188
- AdjointValue::getMaterialized (pullbackCall));
3104
+ case StructExtractDifferentiationStrategy::Getter: {
3105
+ // Get the pullback.
3106
+ auto *pullbackField = getPrimalInfo ().lookUpPullbackDecl (sei);
3107
+ assert (pullbackField);
3108
+ SILValue pullback = builder.createStructExtract (loc,
3109
+ primalValueAggregateInAdj,
3110
+ pullbackField);
3111
+
3112
+ // Construct the pullback arguments.
3113
+ SmallVector<SILValue, 8 > args;
3114
+ auto seed = getAdjointValue (sei);
3115
+ assert (seed.getType ().isObject ());
3116
+ args.push_back (materializeAdjointDirect (seed, loc));
3117
+
3118
+ // Call the pullback.
3119
+ auto *pullbackCall = builder.createApply (loc, pullback, SubstitutionMap (),
3120
+ args, /* isNonThrowing*/ false );
3121
+ assert (!pullbackCall->hasIndirectResults ());
3122
+
3123
+ // Set adjoint for the `struct_extract` operand.
3124
+ addAdjointValue (sei->getOperand (),
3125
+ AdjointValue::getMaterialized (pullbackCall));
3126
+ break ;
3127
+ }
3128
+ }
3189
3129
}
3190
3130
3191
3131
// / Handle `tuple` instruction.
@@ -4236,25 +4176,22 @@ void DifferentiationTask::createVJP() {
4236
4176
loc, adjointRef, vjpSubstMap, partialAdjointArgs,
4237
4177
ParameterConvention::Direct_Guaranteed);
4238
4178
4239
- // === Clean up the stack allocations. ===
4179
+ // Clean up the stack allocations.
4240
4180
for (auto alloc : reversed (stackAllocsToCleanUp))
4241
4181
builder.createDeallocStack (loc, alloc);
4242
4182
4243
- // === Return the direct results. ===
4244
- // (Note that indirect results have already been filled in by the application
4245
- // of the primal).
4183
+ // Return the direct results. Note that indirect results have already been
4184
+ // filled in by the application of the primal.
4246
4185
SmallVector<SILValue, 8 > directResults;
4247
4186
auto originalDirectResults = ArrayRef<SILValue>(primalDirectResults)
4248
4187
.take_back (originalConv.getNumDirectSILResults ());
4249
4188
for (auto originalDirectResult : originalDirectResults)
4250
4189
directResults.push_back (originalDirectResult);
4251
4190
directResults.push_back (adjointPartialApply);
4252
- if (directResults.size () > 1 ) {
4253
- auto tupleRet = builder.createTuple (loc, directResults);
4254
- builder.createReturn (loc, tupleRet);
4255
- } else {
4256
- builder.createReturn (loc, directResults[0 ]);
4257
- }
4191
+ if (directResults.size () > 1 )
4192
+ builder.createReturn (loc, builder.createTuple (loc, directResults));
4193
+ else
4194
+ builder.createReturn (loc, directResults.front ());
4258
4195
}
4259
4196
4260
4197
// ===----------------------------------------------------------------------===//
0 commit comments