@@ -737,11 +737,7 @@ enum class StructExtractDifferentiationStrategy {
737
737
// that is zero except along the direction of the corresponding field.
738
738
//
739
739
// Fields correspond by matching name.
740
- Fieldwise,
741
-
742
- // Differentiate the `struct_extract` by looking up the corresponding getter
743
- // and using its VJP.
744
- Getter
740
+ Fieldwise
745
741
};
746
742
747
743
static inline llvm::raw_ostream &operator <<(llvm::raw_ostream &os,
@@ -3232,59 +3228,10 @@ class VJPEmitter final
3232
3228
SILClonerWithScopes::visitStructExtractInst (sei);
3233
3229
return ;
3234
3230
}
3235
- // This instruction is active. Determine the appropriate differentiation
3236
- // strategy, and use it.
3237
- // Find the corresponding getter.
3238
- auto *getterDecl = sei->getField ()->getGetter ();
3239
- assert (getterDecl);
3240
- auto *getterFn = getModule ().lookUpFunction (
3241
- SILDeclRef (getterDecl, SILDeclRef::Kind::Func));
3242
- auto *structDecl = sei->getStructDecl ();
3243
- if (!getterFn ||
3244
- structDecl->getAttrs ().hasAttribute <FieldwiseDifferentiableAttr>()) {
3245
- strategies[sei] = StructExtractDifferentiationStrategy::Fieldwise;
3246
- SILClonerWithScopes::visitStructExtractInst (sei);
3247
- return ;
3248
- }
3249
- // The FieldwiseProductSpace strategy is not appropriate, so use the Getter
3250
- // strategy.
3251
- assert (getterFn);
3252
- strategies[sei] = StructExtractDifferentiationStrategy::Getter;
3253
- SILAutoDiffIndices indices (/* source*/ 0 ,
3254
- AutoDiffIndexSubset::getDefault (getASTContext (), 1 , true ));
3255
- auto *attr = context.lookUpDifferentiableAttr (getterFn, indices);
3256
- if (!attr) {
3257
- context.emitNondifferentiabilityError (
3258
- sei, invoker, diag::autodiff_property_not_differentiable);
3259
- errorOccurred = true ;
3260
- return ;
3261
- }
3262
- // Reference and apply the VJP.
3263
- auto loc = sei->getLoc ();
3264
- auto *getterVJP = getAssociatedFunction (
3265
- context, getterFn, attr, AutoDiffAssociatedFunctionKind::VJP,
3266
- attr->getVJPName ());
3267
- assert (getterVJP && " Expected to find getter VJP" );
3268
- auto *getterVJPRef = getBuilder ().createFunctionRef (loc, getterVJP);
3269
- auto *getterVJPApply = getBuilder ().createApply (
3270
- loc, getterVJPRef,
3271
- getOpSubstitutionMap (getterVJP->getForwardingSubstitutionMap ()),
3272
- /* args*/ {getOpValue (sei->getOperand ())}, /* isNonThrowing*/ false );
3273
- // Extract direct results from `getterVJPApply`.
3274
- SmallVector<SILValue, 8 > vjpDirectResults;
3275
- extractAllElements (getterVJPApply, getBuilder (), vjpDirectResults);
3276
- // Map original result.
3277
- auto originalDirectResults =
3278
- ArrayRef<SILValue>(vjpDirectResults).drop_back (1 );
3279
- auto originalDirectResult = joinElements (originalDirectResults,
3280
- getBuilder (),
3281
- getterVJPApply->getLoc ());
3282
- mapValue (sei, originalDirectResult);
3283
- // Checkpoint the pullback.
3284
- auto pullback = vjpDirectResults.back ();
3285
- // TODO: Check whether it's necessary to reabstract getter pullbacks.
3286
- pullbackInfo.addPullbackDecl (sei, getOpType (pullback->getType ()));
3287
- pullbackValues[sei->getParent ()].push_back (pullback);
3231
+ // This instruction is active. Use the field wise differentiation strategy
3232
+ // to differentiate the struct extract instruction.
3233
+ strategies[sei] = StructExtractDifferentiationStrategy::Fieldwise;
3234
+ SILClonerWithScopes::visitStructExtractInst (sei);
3288
3235
}
3289
3236
3290
3237
void visitStructElementAddrInst (StructElementAddrInst *seai) {
@@ -3297,78 +3244,10 @@ class VJPEmitter final
3297
3244
SILClonerWithScopes::visitStructElementAddrInst (seai);
3298
3245
return ;
3299
3246
}
3300
- // This instruction is active. Determine the appropriate differentiation
3301
- // strategy, and use it.
3302
- // Find the corresponding getter.
3303
- auto *getterDecl = seai->getField ()->getGetter ();
3304
- assert (getterDecl);
3305
- auto *getterFn = getModule ().lookUpFunction (
3306
- SILDeclRef (getterDecl, SILDeclRef::Kind::Func));
3307
- auto *structDecl = seai->getStructDecl ();
3308
- if (!getterFn ||
3309
- structDecl->getAttrs ().hasAttribute <FieldwiseDifferentiableAttr>()) {
3310
- strategies[seai] = StructExtractDifferentiationStrategy::Fieldwise;
3311
- SILClonerWithScopes::visitStructElementAddrInst (seai);
3312
- return ;
3313
- }
3314
- // The FieldwiseProductSpace strategy is not appropriate, so use the Getter
3315
- // strategy.
3316
- assert (getterFn);
3317
- strategies[seai] = StructExtractDifferentiationStrategy::Getter;
3318
- SILAutoDiffIndices indices (/* source*/ 0 ,
3319
- AutoDiffIndexSubset::getDefault (getASTContext (), 1 , true ));
3320
- auto *attr = context.lookUpDifferentiableAttr (getterFn, indices);
3321
- if (!attr) {
3322
- context.emitNondifferentiabilityError (
3323
- seai, invoker, diag::autodiff_property_not_differentiable);
3324
- errorOccurred = true ;
3325
- return ;
3326
- }
3327
- // Set generic context scope before getting VJP function type.
3328
- auto vjpGenSig = SubsMap.getGenericSignature ()
3329
- ? SubsMap.getGenericSignature ()->getCanonicalSignature ()
3330
- : nullptr ;
3331
- Lowering::GenericContextScope genericContextScope (
3332
- context.getTypeConverter (), vjpGenSig);
3333
- // Reference the getter VJP.
3334
- auto loc = seai->getLoc ();
3335
- auto *getterVJP = getModule ().lookUpFunction (attr->getVJPName ());
3336
- assert (getterVJP && " Expected to find getter VJP" );
3337
- auto vjpFnTy = getterVJP->getLoweredFunctionType ();
3338
- auto *getterVJPRef = getBuilder ().createFunctionRef (loc, getterVJP);
3339
- // Store getter VJP arguments and indirect result buffers.
3340
- SmallVector<SILValue, 8 > vjpArgs;
3341
- SmallVector<AllocStackInst *, 8 > vjpIndirectResults;
3342
- for (auto indRes : vjpFnTy->getIndirectFormalResults ()) {
3343
- auto *alloc = getBuilder ().createAllocStack (
3344
- loc, getOpType (indRes.getSILStorageType ()));
3345
- vjpArgs.push_back (alloc);
3346
- vjpIndirectResults.push_back (alloc);
3347
- }
3348
- vjpArgs.push_back (getOpValue (seai->getOperand ()));
3349
- // Apply the getter VJP.
3350
- auto *getterVJPApply = getBuilder ().createApply (
3351
- loc, getterVJPRef,
3352
- getOpSubstitutionMap (getterVJP->getForwardingSubstitutionMap ()),
3353
- vjpArgs, /* isNonThrowing*/ false );
3354
- // Collect all results from `getterVJPApply` in type-defined order.
3355
- SmallVector<SILValue, 8 > vjpDirectResults;
3356
- extractAllElements (getterVJPApply, getBuilder (), vjpDirectResults);
3357
- SmallVector<SILValue, 8 > allResults;
3358
- collectAllActualResultsInTypeOrder (
3359
- getterVJPApply, vjpDirectResults,
3360
- getterVJPApply->getIndirectSILResults (), allResults);
3361
- // Deallocate VJP indirect results.
3362
- for (auto alloc : vjpIndirectResults)
3363
- getBuilder ().createDeallocStack (loc, alloc);
3364
- auto originalDirectResult = allResults[indices.source ];
3365
- // Map original result.
3366
- mapValue (seai, originalDirectResult);
3367
- // Checkpoint the pullback.
3368
- SILValue pullback = vjpDirectResults.back ();
3369
- // TODO: Check whether it's necessary to reabstract getter pullbacks.
3370
- pullbackInfo.addPullbackDecl (seai, getOpType (pullback->getType ()));
3371
- pullbackValues[seai->getParent ()].push_back (pullback);
3247
+ // This instruction is active. Use the field wise differentiation strategy
3248
+ // to differentiate the struct extract instruction.
3249
+ strategies[seai] = StructExtractDifferentiationStrategy::Fieldwise;
3250
+ SILClonerWithScopes::visitStructElementAddrInst (seai);
3372
3251
}
3373
3252
3374
3253
// If an `apply` has active results or active inout parameters, replace it
@@ -4839,29 +4718,6 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
4839
4718
}
4840
4719
return ;
4841
4720
}
4842
- case StructExtractDifferentiationStrategy::Getter: {
4843
- // Get the pullback.
4844
- auto *pullbackField = getPullbackInfo ().lookUpPullbackDecl (sei);
4845
- assert (pullbackField);
4846
- auto pullback = builder.createStructExtract (
4847
- loc, getAdjointBlockPullbackStructArgument (sei->getParent ()),
4848
- pullbackField);
4849
-
4850
- // Construct the pullback arguments.
4851
- auto av = takeAdjointValue (sei);
4852
- auto vector = materializeAdjointDirect (std::move (av), loc);
4853
-
4854
- // Call the pullback.
4855
- auto *pullbackCall = builder.createApply (
4856
- loc, pullback, SubstitutionMap (), {vector}, /* isNonThrowing*/ false );
4857
- assert (!pullbackCall->hasIndirectResults ());
4858
-
4859
- // Accumulate adjoint for the `struct_extract` operand.
4860
- addAdjointValue (sei->getOperand (),
4861
- makeConcreteAdjointValue (
4862
- ValueWithCleanup (pullbackCall, vector.getCleanup ())));
4863
- break ;
4864
- }
4865
4721
}
4866
4722
}
4867
4723
0 commit comments