@@ -1331,21 +1331,28 @@ void DifferentiableActivityInfo::analyze(DominanceInfo *di,
1331
1331
if (isVaried (cai->getSrc (), i))
1332
1332
recursivelySetVaried (cai->getDest (), i);
1333
1333
}
1334
- // Handle `struct_extract`.
1335
- else if (auto *sei = dyn_cast<StructExtractInst>(&inst)) {
1336
- if (isVaried (sei->getOperand (), i)) {
1337
- // If `@noDerivative` exists on the field while the struct is
1338
- // `@_fieldwiseDifferentiable`, this field is not in the set of
1339
- // differentiable variables that we want to track the variedness of.
1340
- auto hasNoDeriv = sei->getField ()->getAttrs ()
1341
- .hasAttribute <NoDerivativeAttr>();
1342
- auto structIsFieldwiseDiffable = sei->getStructDecl ()->getAttrs ()
1343
- .hasAttribute <FieldwiseDifferentiableAttr>();
1344
- if (!(hasNoDeriv && structIsFieldwiseDiffable))
1345
- for (auto result : inst.getResults ())
1346
- setVaried (result, i);
1347
- }
1348
- }
1334
+
1335
+ // Handle `struct_extract` and `struct_element_addr` instructions.
1336
+ // - If the field is marked `@noDerivative` and belongs to a
1337
+ // `@_fieldwiseDifferentiable` struct, do not set the result as varied because
1338
+ // it is not in the set of differentiable variables.
1339
+ // - Otherwise, propagate variedness from operand to result as usual.
1340
+ #define PROPAGATE_VARIED_FOR_STRUCT_EXTRACTION (INST ) \
1341
+ else if (auto *sei = dyn_cast<INST##Inst>(&inst)) { \
1342
+ if (isVaried (sei->getOperand (), i)) { \
1343
+ auto hasNoDeriv = sei->getField ()->getAttrs () \
1344
+ .hasAttribute <NoDerivativeAttr>(); \
1345
+ auto structIsFieldwiseDiffable = sei->getStructDecl ()->getAttrs () \
1346
+ .hasAttribute <FieldwiseDifferentiableAttr>(); \
1347
+ if (!(hasNoDeriv && structIsFieldwiseDiffable)) \
1348
+ for (auto result : inst.getResults ()) \
1349
+ setVaried (result, i); \
1350
+ } \
1351
+ }
1352
+ PROPAGATE_VARIED_FOR_STRUCT_EXTRACTION (StructExtract)
1353
+ PROPAGATE_VARIED_FOR_STRUCT_EXTRACTION (StructElementAddr)
1354
+ #undef VISIT_STRUCT_ELEMENT_INNS
1355
+
1349
1356
// Handle everything else.
1350
1357
else {
1351
1358
for (auto &op : inst.getAllOperands ())
@@ -3630,6 +3637,37 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
3630
3637
assert (insertion.second ); (void )insertion;
3631
3638
}
3632
3639
3640
+ SILValue getAdjointProjection (SILValue originalProjection) {
3641
+ // Handle `struct_element_addr`.
3642
+ if (auto *seai = dyn_cast<StructElementAddrInst>(originalProjection)) {
3643
+ auto adjBase = getAdjointBuffer (seai->getOperand ());
3644
+ auto *cotangentVectorDecl =
3645
+ adjBase.getType ().getStructOrBoundGenericStruct ();
3646
+ auto cotanFieldLookup =
3647
+ cotangentVectorDecl->lookupDirect (seai->getField ()->getName ());
3648
+ assert (cotanFieldLookup.size () == 1 );
3649
+ auto *cotanField = cast<VarDecl>(cotanFieldLookup.front ());
3650
+ return builder.createStructElementAddr (
3651
+ seai->getLoc (), adjBase.getValue (), cotanField);
3652
+ }
3653
+ // Handle `tuple_element_addr`.
3654
+ if (auto *teai = dyn_cast<TupleElementAddrInst>(originalProjection)) {
3655
+ auto adjBase = getAdjointBuffer (teai->getOperand ());
3656
+ return builder.createTupleElementAddr (
3657
+ teai->getLoc (), adjBase.getValue (), teai->getFieldNo ());
3658
+ }
3659
+ // Handle `begin_access`.
3660
+ if (auto *bai = dyn_cast<BeginAccessInst>(originalProjection)) {
3661
+ auto adjBase = getAdjointBuffer (bai->getOperand ());
3662
+ if (errorOccurred)
3663
+ return (bufferMap[originalProjection] = ValueWithCleanup ());
3664
+ return builder.createBeginAccess (
3665
+ bai->getLoc (), adjBase, bai->getAccessKind (), bai->getEnforcement (),
3666
+ /* noNestedConflict*/ false , /* fromBuiltin*/ false );
3667
+ }
3668
+ return SILValue ();
3669
+ }
3670
+
3633
3671
ValueWithCleanup &getAdjointBuffer (SILValue originalBuffer) {
3634
3672
assert (originalBuffer->getType ().isAddress ());
3635
3673
assert (originalBuffer->getFunction () == &getOriginal ());
@@ -3638,59 +3676,24 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
3638
3676
if (!insertion.second ) // not inserted
3639
3677
return insertion.first ->getSecond ();
3640
3678
3641
- // Diagnose non-differentiable buffers.
3642
- if (!originalBuffer->getType ().isDifferentiable (getModule ())) {
3643
- getContext ().emitNondifferentiabilityError (
3644
- originalBuffer, getDifferentiationTask ());
3645
- errorOccurred = true ;
3646
- return (bufferMap[originalBuffer] = ValueWithCleanup ());
3679
+ // Diagnose `struct_element_addr` instructions to `@noDerivative` fields.
3680
+ if (auto *seai = dyn_cast<StructElementAddrInst>(originalBuffer)) {
3681
+ if (seai->getField ()->getAttrs ().hasAttribute <NoDerivativeAttr>()) {
3682
+ getContext ().emitNondifferentiabilityError (
3683
+ originalBuffer, getDifferentiationTask (),
3684
+ diag::autodiff_noderivative_stored_property);
3685
+ errorOccurred = true ;
3686
+ return (bufferMap[originalBuffer] = ValueWithCleanup ());
3687
+ }
3647
3688
}
3648
3689
3649
- // Check whether the original buffer is an address-to-address projection.
3650
- // If so, recurse until the buffer is such a projection but its operand is
3651
- // not. Then, get the adjoint buffer of the operand and return a
3652
- // corresponding projection into it.
3653
- if (Projection::isAddressProjection (originalBuffer) &&
3654
- !Projection::isObjectToAddressProjection (originalBuffer)) {
3655
- // Get operand of the projection (i.e. the base memory).
3656
- auto *inst = cast<SingleValueInstruction>(originalBuffer);
3657
- Projection proj (inst);
3658
- auto loc = inst->getLoc ();
3659
- auto base = inst->getOperand (0 );
3660
- // Get the corresponding projection into the adjoint buffer.
3661
- SILValue adjProj;
3662
- auto adjBase = getAdjointBuffer (base);
3663
- if (proj.getKind () == ProjectionKind::Struct) {
3664
- auto *origField = proj.getVarDecl (base->getType ());
3665
- auto *cotangentVectorDecl =
3666
- adjBase.getType ().getStructOrBoundGenericStruct ();
3667
- auto cotanFieldLookup =
3668
- cotangentVectorDecl->lookupDirect (origField->getName ());
3669
- assert (cotanFieldLookup.size () == 1 );
3670
- auto *cotanField = cast<VarDecl>(cotanFieldLookup.front ());
3671
- adjProj = builder.createStructElementAddr (loc, adjBase.getValue (),
3672
- cotanField);
3673
- } else {
3674
- adjProj = proj.createAddressProjection (builder, loc, adjBase.getValue ())
3675
- .get ();
3676
- }
3690
+ // If the original buffer is a projection, return a corresponding projection
3691
+ // into the adjoint buffer.
3692
+ if (auto adjProj = getAdjointProjection (originalBuffer)) {
3677
3693
ValueWithCleanup projWithCleanup (
3678
- adjProj, makeCleanupFromChildren ({adjBase. getCleanup ()} ));
3694
+ adjProj, makeCleanup (adjProj, /* cleanup */ nullptr ));
3679
3695
return (bufferMap[originalBuffer] = projWithCleanup);
3680
3696
}
3681
- // If the original buffer is a `begin_access` instruction, get the adjoint
3682
- // buffer of its operand and return a corresponding `begin_access` into it.
3683
- if (auto *bai = dyn_cast<BeginAccessInst>(originalBuffer)) {
3684
- auto adjBase = getAdjointBuffer (bai->getOperand ());
3685
- if (errorOccurred)
3686
- return (bufferMap[originalBuffer] = ValueWithCleanup ());
3687
- auto *adjAccess = builder.createBeginAccess (
3688
- bai->getLoc (), adjBase, bai->getAccessKind (), bai->getEnforcement (),
3689
- /* noNestedConflict*/ false , /* fromBuiltin*/ false );
3690
- ValueWithCleanup accessWithCleanup (
3691
- adjAccess, makeCleanupFromChildren ({adjBase.getCleanup ()}));
3692
- return (bufferMap[originalBuffer] = accessWithCleanup);
3693
- }
3694
3697
3695
3698
// Set insertion point for local allocation builder: before the last local
3696
3699
// allocation, or at the start of the adjoint entry BB if no local
@@ -3803,6 +3806,17 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
3803
3806
SmallVector<SILValue, 8 > origFormalResults;
3804
3807
collectAllFormalResultsInTypeOrder (original, origFormalResults);
3805
3808
auto origResult = origFormalResults[task->getIndices ().source ];
3809
+ // Emit warning if original result is not varied, because it will always have
3810
+ // a zero derivative.
3811
+ if (!activityInfo.isVaried (origResult, task->getIndices ().source )) {
3812
+ // Emit fixit if original result has a valid source location.
3813
+ auto sourceLoc = origResult.getLoc ().getSourceLoc ();
3814
+ if (sourceLoc.isValid ()) {
3815
+ getContext ()
3816
+ .diagnose (sourceLoc, diag::autodiff_nonvaried_result_fixit)
3817
+ .fixItInsertAfter (sourceLoc, " .withoutDerivative()" );
3818
+ }
3819
+ }
3806
3820
3807
3821
builder.setInsertionPoint (adjointEntry);
3808
3822
if (seed->getType ().isAddress ()) {
@@ -4229,6 +4243,9 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
4229
4243
}
4230
4244
4231
4245
void visitStructExtractInst (StructExtractInst *sei) {
4246
+ assert (!sei->getField ()->getAttrs ().hasAttribute <NoDerivativeAttr>() &&
4247
+ " `struct_extract` with `@noDerivative` field should not be "
4248
+ " differentiated; activity analysis should not marked as varied" );
4232
4249
auto loc = sei->getLoc ();
4233
4250
auto &differentiationStrategies =
4234
4251
getDifferentiationTask ()->getStructExtractDifferentiationStrategies ();
@@ -4562,6 +4579,17 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
4562
4579
ValueWithCleanup (adjAccess, makeCleanupFromChildren ({})));
4563
4580
}
4564
4581
4582
+ #define PROPAGATE_BUFFER_CLEANUP (INST ) \
4583
+ void visit##INST##Inst(INST##Inst *inst) { \
4584
+ auto &adjBase = getAdjointBuffer (inst->getOperand ()); \
4585
+ auto &adjProj = getAdjointBuffer (inst); \
4586
+ adjProj.setCleanup (makeCleanupFromChildren ( \
4587
+ {adjProj.getCleanup (), adjBase.getCleanup ()})); \
4588
+ }
4589
+ PROPAGATE_BUFFER_CLEANUP (StructElementAddr)
4590
+ PROPAGATE_BUFFER_CLEANUP (TupleElementAddr)
4591
+ #undef PROPAGATE_CLEANUP
4592
+
4565
4593
#define NOT_DIFFERENTIABLE (INST, DIAG ) \
4566
4594
void visit##INST##Inst(INST##Inst *inst) { \
4567
4595
getContext ().emitNondifferentiabilityError ( \
@@ -4587,10 +4615,6 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
4587
4615
NO_ADJOINT (StrongRetainUnowned)
4588
4616
NO_ADJOINT (DestroyValue)
4589
4617
NO_ADJOINT (DestroyAddr)
4590
- // Projection operations have no adjoint visitor.
4591
- // Corresponding adjoint projections are created in `getAdjointBuffer`.
4592
- NO_ADJOINT (StructElementAddr)
4593
- NO_ADJOINT (TupleElementAddr)
4594
4618
#undef NO_DERIVATIVE
4595
4619
};
4596
4620
} // end anonymous namespace
0 commit comments