@@ -472,9 +472,9 @@ class PrimalInfo {
472
472
// / corresponding tape of its type.
473
473
DenseMap<ApplyInst *, VarDecl *> nestedStaticPrimalValueMap;
474
474
475
- // / Mapping from `apply` instructions in the original function to the
476
- // / corresponding pullback decl in the primal struct.
477
- DenseMap<ApplyInst *, VarDecl *> pullbackValueMap;
475
+ // / Mapping from `apply` and `struct_extract` instructions in the original
476
+ // / function to the corresponding pullback decl in the primal struct.
477
+ DenseMap<SILInstruction *, VarDecl *> pullbackValueMap;
478
478
479
479
// / Mapping from types of control-dependent nested primal values to district
480
480
// / tapes.
@@ -573,7 +573,7 @@ class PrimalInfo {
573
573
}
574
574
575
575
// / Add a pullback to the primal value struct.
576
- VarDecl *addPullbackDecl (ApplyInst *inst, Type pullbackType) {
576
+ VarDecl *addPullbackDecl (SILInstruction *inst, Type pullbackType) {
577
577
// Decls must have AST types (not `SILFunctionType`), so we convert the
578
578
// `SILFunctionType` of the pullback to a `FunctionType` with the same
579
579
// parameters and results.
@@ -605,9 +605,9 @@ class PrimalInfo {
605
605
: lookup->getSecond ();
606
606
}
607
607
608
- // / Finds the pullback decl in the primal value struct for an `apply` in the
609
- // / original function.
610
- VarDecl *lookUpPullbackDecl (ApplyInst *inst) {
608
+ // / Finds the pullback decl in the primal value struct for an `apply` or
609
+ // / `struct_extract` in the original function.
610
+ VarDecl *lookUpPullbackDecl (SILInstruction *inst) {
611
611
auto lookup = pullbackValueMap.find (inst);
612
612
return lookup == pullbackValueMap.end () ? nullptr
613
613
: lookup->getSecond ();
@@ -2227,6 +2227,79 @@ class PrimalGenCloner final : public SILClonerWithScopes<PrimalGenCloner> {
2227
2227
SILClonerWithScopes::visitReleaseValueInst (rvi);
2228
2228
}
2229
2229
2230
+ void visitStructExtractInst (StructExtractInst *sei) {
2231
+ // Special handling logic only applies when the `struct_extract` is active.
2232
+ // If not, just do standard cloning.
2233
+ if (!activityInfo.isActive (sei, synthesis.indices )) {
2234
+ LLVM_DEBUG (getADDebugStream () << " Not active:\n " << *sei << ' \n ' );
2235
+ SILClonerWithScopes::visitStructExtractInst (sei);
2236
+ return ;
2237
+ }
2238
+
2239
+ // This instruction is active. Replace it with a call to the corresponding
2240
+ // getter's VJP.
2241
+
2242
+ // Find the corresponding getter and its VJP.
2243
+ auto *getterDecl = sei->getField ()->getGetter ();
2244
+ assert (getterDecl);
2245
+ auto *getterFn = getContext ().getModule ().lookUpFunction (
2246
+ SILDeclRef (getterDecl, SILDeclRef::Kind::Func));
2247
+ if (!getterFn) {
2248
+ getContext ().emitNondifferentiabilityError (
2249
+ sei, synthesis.task , diag::autodiff_property_not_differentiable);
2250
+ errorOccurred = true ;
2251
+ return ;
2252
+ }
2253
+ auto getterDiffAttrs = getterFn->getDifferentiableAttrs ();
2254
+ if (getterDiffAttrs.size () < 1 ) {
2255
+ getContext ().emitNondifferentiabilityError (
2256
+ sei, synthesis.task , diag::autodiff_property_not_differentiable);
2257
+ errorOccurred = true ;
2258
+ return ;
2259
+ }
2260
+ auto *getterDiffAttr = getterDiffAttrs[0 ];
2261
+ if (!getterDiffAttr->hasVJP ()) {
2262
+ getContext ().emitNondifferentiabilityError (
2263
+ sei, synthesis.task , diag::autodiff_property_not_differentiable);
2264
+ errorOccurred = true ;
2265
+ return ;
2266
+ }
2267
+ assert (getterDiffAttr->getIndices () ==
2268
+ SILAutoDiffIndices (/* source*/ 0 , /* parameters*/ {0 }));
2269
+ auto *getterVJP = lookUpOrLinkFunction (getterDiffAttr->getVJPName (),
2270
+ getContext ().getModule ());
2271
+
2272
+ // Reference and apply the VJP.
2273
+ auto loc = sei->getLoc ();
2274
+ auto *getterVJPRef = getBuilder ().createFunctionRef (loc, getterVJP);
2275
+ auto *getterVJPApply = getBuilder ().createApply (
2276
+ loc, getterVJPRef, /* substitutionMap*/ {},
2277
+ /* args*/ {getMappedValue (sei->getOperand ())}, /* isNonThrowing*/ false );
2278
+
2279
+ // Get the VJP results (original results and pullback).
2280
+ SmallVector<SILValue, 8 > vjpDirectResults;
2281
+ extractAllElements (getterVJPApply, getBuilder (), vjpDirectResults);
2282
+ ArrayRef<SILValue> originalDirectResults =
2283
+ ArrayRef<SILValue>(vjpDirectResults).drop_back (1 );
2284
+ SILValue originalDirectResult = joinElements (originalDirectResults,
2285
+ getBuilder (),
2286
+ getterVJPApply->getLoc ());
2287
+ SILValue pullback = vjpDirectResults.back ();
2288
+
2289
+ // Store the original result to the value map.
2290
+ mapValue (sei, originalDirectResult);
2291
+
2292
+ // Checkpoint the original results.
2293
+ getPrimalInfo ().addStaticPrimalValueDecl (sei);
2294
+ getBuilder ().createRetainValue (loc, originalDirectResult,
2295
+ getBuilder ().getDefaultAtomicity ());
2296
+ staticPrimalValues.push_back (originalDirectResult);
2297
+
2298
+ // Checkpoint the pullback.
2299
+ getPrimalInfo ().addPullbackDecl (sei, pullback->getType ().getASTType ());
2300
+ staticPrimalValues.push_back (pullback);
2301
+ }
2302
+
2230
2303
void visitApplyInst (ApplyInst *ai) {
2231
2304
if (DifferentiationUseVJP)
2232
2305
visitApplyInstWithVJP (ai);
@@ -3522,33 +3595,36 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
3522
3595
}
3523
3596
}
3524
3597
3525
- // / Handle `struct_extract` instruction.
3526
- // / y = struct_extract <key>, x
3527
- // / adj[x] = struct (0, ..., key: adj[y], ..., 0)
3528
3598
void visitStructExtractInst (StructExtractInst *sei) {
3529
- auto *structDecl = sei->getStructDecl ();
3530
- auto av = getAdjointValue (sei);
3531
- switch (av.getKind ()) {
3532
- case AdjointValue::Kind::Zero:
3533
- addAdjointValue (sei->getOperand (),
3534
- AdjointValue::getZero (sei->getOperand ()->getType ()));
3535
- break ;
3536
- case AdjointValue::Kind::Materialized:
3537
- case AdjointValue::Kind::Aggregate: {
3538
- SmallVector<AdjointValue, 8 > eltVals;
3539
- for (auto *field : structDecl->getStoredProperties ()) {
3540
- if (field == sei->getField ())
3541
- eltVals.push_back (av);
3542
- else
3543
- eltVals.push_back (AdjointValue::getZero (
3544
- SILType::getPrimitiveObjectType (
3545
- field->getType ()->getCanonicalType ())));
3546
- }
3547
- addAdjointValue (sei->getOperand (),
3548
- AdjointValue::getAggregate (sei->getOperand ()->getType (),
3549
- eltVals, allocator));
3550
- }
3599
+ // Replace a `struct_extract` with a call to its pullback.
3600
+ auto loc = remapLocation (sei->getLoc ());
3601
+
3602
+ // Get the pullback.
3603
+ auto *pullbackField = getPrimalInfo ().lookUpPullbackDecl (sei);
3604
+ if (!pullbackField) {
3605
+ // Inactive `struct_extract` instructions don't need to be cloned into the
3606
+ // adjoint.
3607
+ assert (!activityInfo.isActive (sei, synthesis.indices ));
3608
+ return ;
3551
3609
}
3610
+ SILValue pullback = builder.createStructExtract (loc,
3611
+ primalValueAggregateInAdj,
3612
+ pullbackField);
3613
+
3614
+ // Construct the pullback arguments.
3615
+ SmallVector<SILValue, 8 > args;
3616
+ auto seed = getAdjointValue (sei);
3617
+ assert (seed.getType ().isObject ());
3618
+ args.push_back (materializeAdjointDirect (seed, loc));
3619
+
3620
+ // Call the pullback.
3621
+ auto *pullbackCall = builder.createApply (loc, pullback, SubstitutionMap (),
3622
+ args, /* isNonThrowing*/ false );
3623
+ assert (!pullbackCall->hasIndirectResults ());
3624
+
3625
+ // Set adjoint for the `struct_extract` operand.
3626
+ addAdjointValue (sei->getOperand (),
3627
+ AdjointValue::getMaterialized (pullbackCall));
3552
3628
}
3553
3629
3554
3630
// / Handle `tuple` instruction.
0 commit comments