@@ -472,9 +472,9 @@ class PrimalInfo {
472
472
// / corresponding tape of its type.
473
473
DenseMap<ApplyInst *, VarDecl *> nestedStaticPrimalValueMap;
474
474
475
- // / Mapping from `apply` instructions in the original function to the
476
- // / corresponding pullback decl in the primal struct.
477
- DenseMap<ApplyInst *, VarDecl *> pullbackValueMap;
475
+ // / Mapping from `apply` and `struct_extract` instructions in the original
476
+ // / function to the corresponding pullback decl in the primal struct.
477
+ DenseMap<SILInstruction *, VarDecl *> pullbackValueMap;
478
478
479
479
// / Mapping from types of control-dependent nested primal values to district
480
480
// / tapes.
@@ -573,7 +573,7 @@ class PrimalInfo {
573
573
}
574
574
575
575
// / Add a pullback to the primal value struct.
576
- VarDecl *addPullbackDecl (ApplyInst *inst, Type pullbackType) {
576
+ VarDecl *addPullbackDecl (SILInstruction *inst, Type pullbackType) {
577
577
// Decls must have AST types (not `SILFunctionType`), so we convert the
578
578
// `SILFunctionType` of the pullback to a `FunctionType` with the same
579
579
// parameters and results.
@@ -605,9 +605,9 @@ class PrimalInfo {
605
605
: lookup->getSecond ();
606
606
}
607
607
608
- // / Finds the pullback decl in the primal value struct for an `apply` in the
609
- // / original function.
610
- VarDecl *lookUpPullbackDecl (ApplyInst *inst) {
608
+ // / Finds the pullback decl in the primal value struct for an `apply` or
609
+ // / `struct_extract` in the original function.
610
+ VarDecl *lookUpPullbackDecl (SILInstruction *inst) {
611
611
auto lookup = pullbackValueMap.find (inst);
612
612
return lookup == pullbackValueMap.end () ? nullptr
613
613
: lookup->getSecond ();
@@ -714,6 +714,9 @@ class DifferentiationTask {
714
714
// / Note: This is only used when `DifferentiationUseVJP`.
715
715
DenseMap<ApplyInst *, NestedApplyActivity> nestedApplyActivities;
716
716
717
+ DenseMap<StructExtractInst *, NestedStructExtractStrategy>
718
+ nestedStructExtractStrategies;
719
+
717
720
// / Cache for associated functions.
718
721
SILFunction *primal = nullptr ;
719
722
SILFunction *adjoint = nullptr ;
@@ -2227,6 +2230,77 @@ class PrimalGenCloner final : public SILClonerWithScopes<PrimalGenCloner> {
2227
2230
SILClonerWithScopes::visitReleaseValueInst (rvi);
2228
2231
}
2229
2232
2233
+ void visitStructExtractInst (StructExtractInst *sei) {
2234
+ // Special handling logic only applies when the `struct_extract` is active.
2235
+ // If not, just do standard cloning.
2236
+ if (!activityInfo.isActive (sei, synthesis.indices )) {
2237
+ LLVM_DEBUG (getADDebugStream () << " Not active:\n " << *sei << ' \n ' );
2238
+ SILClonerWithScopes::visitStructExtractInst (sei);
2239
+ return ;
2240
+ }
2241
+
2242
+ // This instruction is active. Replace it with a call to the corresponding
2243
+ // getter's VJP.
2244
+
2245
+ // Find the corresponding getter and its VJP.
2246
+ auto *getterDecl = sei->getField ()->getGetter ();
2247
+ assert (getterDecl);
2248
+ auto *getterFn = getContext ().getModule ().lookUpFunction (
2249
+ SILDeclRef (getterDecl, SILDeclRef::Kind::Func));
2250
+ if (!getterFn) {
2251
+ getContext ().emitNondifferentiabilityError (
2252
+ sei, synthesis.task , diag::autodiff_property_not_differentiable);
2253
+ errorOccurred = true ;
2254
+ return ;
2255
+ }
2256
+ auto getterDiffAttrs = getterFn->getDifferentiableAttrs ();
2257
+ if (getterDiffAttrs.size () < 1 ) {
2258
+ getContext ().emitNondifferentiabilityError (
2259
+ sei, synthesis.task , diag::autodiff_property_not_differentiable);
2260
+ errorOccurred = true ;
2261
+ return ;
2262
+ }
2263
+ auto *getterDiffAttr = getterDiffAttrs[0 ];
2264
+ if (!getterDiffAttr->hasVJP ()) {
2265
+ getContext ().emitNondifferentiabilityError (
2266
+ sei, synthesis.task , diag::autodiff_property_not_differentiable);
2267
+ errorOccurred = true ;
2268
+ return ;
2269
+ }
2270
+ assert (getterDiffAttr->getIndices () ==
2271
+ SILAutoDiffIndices (/* source*/ 0 , /* parameters*/ {0 }));
2272
+ auto *getterVJP = lookUpOrLinkFunction (getterDiffAttr->getVJPName (),
2273
+ getContext ().getModule ());
2274
+
2275
+ // Reference and apply the VJP.
2276
+ auto loc = sei->getLoc ();
2277
+ auto *getterVJPRef = getBuilder ().createFunctionRef (loc, getterVJP);
2278
+ auto *getterVJPApply = getBuilder ().createApply (
2279
+ loc, getterVJPRef, /* substitutionMap*/ {},
2280
+ /* args*/ {getMappedValue (sei->getOperand ())}, /* isNonThrowing*/ false );
2281
+
2282
+ // Get the VJP results (original results and pullback)
2283
+ SmallVector<SILValue, 8 > vjpDirectResults;
2284
+ extractAllElements (getterVJPApply, getBuilder (), vjpDirectResults);
2285
+ ArrayRef<SILValue> originalDirectResults =
2286
+ ArrayRef<SILValue>(vjpDirectResults).drop_back (1 );
2287
+ SILValue originalDirectResult = joinElements (originalDirectResults,
2288
+ getBuilder (),
2289
+ getterVJPApply->getLoc ());
2290
+ SILValue pullback = vjpDirectResults.back ();
2291
+
2292
+ // Store the original result to the value map.
2293
+ mapValue (sei, originalDirectResult);
2294
+
2295
+ // Checkpoint the original results.
2296
+ getPrimalInfo ().addStaticPrimalValueDecl (sei);
2297
+ staticPrimalValues.push_back (originalDirectResult);
2298
+
2299
+ // Checkpoint the pullback.
2300
+ getPrimalInfo ().addPullbackDecl (sei, pullback->getType ().getASTType ());
2301
+ staticPrimalValues.push_back (pullback);
2302
+ }
2303
+
2230
2304
void visitApplyInst (ApplyInst *ai) {
2231
2305
if (DifferentiationUseVJP)
2232
2306
visitApplyInstWithVJP (ai);
@@ -3522,33 +3596,50 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
3522
3596
}
3523
3597
}
3524
3598
3525
- // / Handle `struct_extract` instruction.
3526
- // / y = struct_extract <key>, x
3527
- // / adj[x] = struct (0, ..., key: adj[y], ..., 0)
3528
3599
void visitStructExtractInst (StructExtractInst *sei) {
3529
- auto *structDecl = sei->getStructDecl ();
3530
- auto av = getAdjointValue (sei);
3531
- switch (av.getKind ()) {
3532
- case AdjointValue::Kind::Zero:
3533
- addAdjointValue (sei->getOperand (),
3534
- AdjointValue::getZero (sei->getOperand ()->getType ()));
3535
- break ;
3536
- case AdjointValue::Kind::Materialized:
3537
- case AdjointValue::Kind::Aggregate: {
3538
- SmallVector<AdjointValue, 8 > eltVals;
3539
- for (auto *field : structDecl->getStoredProperties ()) {
3540
- if (field == sei->getField ())
3541
- eltVals.push_back (av);
3542
- else
3543
- eltVals.push_back (AdjointValue::getZero (
3544
- SILType::getPrimitiveObjectType (
3545
- field->getType ()->getCanonicalType ())));
3546
- }
3547
- addAdjointValue (sei->getOperand (),
3548
- AdjointValue::getAggregate (sei->getOperand ()->getType (),
3549
- eltVals, allocator));
3600
+ // Replace a `struct_extract` with a call to its pullback.
3601
+ auto loc = remapLocation (sei->getLoc ());
3602
+
3603
+ // Get the pullback.
3604
+ auto *pullbackField = getPrimalInfo ().lookUpPullbackDecl (sei);
3605
+ if (!pullbackField) {
3606
+ // Inactive `struct_extract` instructions don't need to be cloned into the
3607
+ // adjoint.
3608
+ assert (!activityInfo.isActive (sei, synthesis.indices ));
3609
+ return ;
3550
3610
}
3611
+ SILValue pullback = builder.createStructExtract (loc,
3612
+ primalValueAggregateInAdj,
3613
+ pullbackField);
3614
+
3615
+ // Construct the pullback arguments.
3616
+ SmallVector<SILValue, 8 > args;
3617
+ auto seed = getAdjointValue (sei);
3618
+ auto *seedBuf = builder.createAllocStack (loc, seed.getType ());
3619
+ materializeAdjointIndirectHelper (seed, seedBuf);
3620
+ if (seed.getType ().isAddressOnly (getModule ()))
3621
+ args.push_back (seedBuf);
3622
+ else {
3623
+ auto access = builder.createBeginAccess (
3624
+ loc, seedBuf, SILAccessKind::Read, SILAccessEnforcement::Static,
3625
+ /* noNestedConflict*/ true ,
3626
+ /* fromBuiltin*/ false );
3627
+ args.push_back (builder.createLoad (
3628
+ loc, access, getBufferLOQ (seed.getSwiftType (), getAdjoint ())));
3629
+ builder.createEndAccess (loc, access, /* aborted*/ false );
3551
3630
}
3631
+
3632
+ // Call the pullback.
3633
+ auto *pullbackCall = builder.createApply (loc, pullback, SubstitutionMap (),
3634
+ args, /* isNonThrowing*/ false );
3635
+ assert (!pullbackCall->hasIndirectResults ());
3636
+
3637
+ // Clean up seed allocation.
3638
+ builder.createDeallocStack (loc, seedBuf);
3639
+
3640
+ // Set adjoint for the `struct_extract` operand.
3641
+ addAdjointValue (sei->getOperand (),
3642
+ AdjointValue::getMaterialized (pullbackCall));
3552
3643
}
3553
3644
3554
3645
// / Handle `tuple` instruction.
0 commit comments