@@ -797,6 +797,7 @@ bool PullbackEmitter::run() {
797
797
" Functions without returns must have been diagnosed" );
798
798
auto *origExit = &*origExitIt;
799
799
800
+ // Collect original formal results.
800
801
SmallVector<SILValue, 8 > origFormalResults;
801
802
collectAllFormalResultsInTypeOrder (original, origFormalResults);
802
803
for (auto resultIndex : getIndices ().results ->getIndices ()) {
@@ -815,7 +816,7 @@ bool PullbackEmitter::run() {
815
816
}
816
817
}
817
818
818
- // Get dominated active values in original blocks.
819
+ // Collect dominated active values in original basic blocks.
819
820
// Adjoint values of dominated active values are passed as pullback block
820
821
// arguments.
821
822
DominanceOrder domOrder (original.getEntryBlock (), domInfo);
@@ -829,15 +830,21 @@ bool PullbackEmitter::run() {
829
830
auto &domBBActiveValues = activeValues[domNode->getBlock ()];
830
831
bbActiveValues.append (domBBActiveValues.begin (), domBBActiveValues.end ());
831
832
}
832
- // Booleans tracking whether active-value-related errors have been emitted.
833
- // This prevents duplicate diagnostics for the same active values.
834
- bool diagnosedActiveEnumValue = false ;
835
- bool diagnosedActiveValueTangentValueCategoryIncompatible = false ;
836
- // Mark the activity of a value if it has not yet been visited.
837
- auto markValueActivity = [&](SILValue v) {
833
+ // If `v` is active and has not been visited, records it as an active value
834
+ // in the original basic block.
835
+ // For active values unsupported by differentiation, emits a diagnostic and
836
+ // returns true. Otherwise, returns false.
837
+ auto recordValueIfActive = [&](SILValue v) -> bool {
838
+ // If value is not active, skip.
839
+ if (!getActivityInfo ().isActive (v, getIndices ()))
840
+ return false ;
841
+ // If active value has already been visited, skip.
838
842
if (visited.count (v))
839
- return ;
843
+ return false ;
844
+ // Mark active value as visited.
840
845
visited.insert (v);
846
+
847
+ // Diagnose unsupported active values.
841
848
auto type = v->getType ();
842
849
// Diagnose active values whose value category is incompatible with their
843
850
// tangent types's value category.
@@ -851,56 +858,54 @@ bool PullbackEmitter::run() {
851
858
// $*A | $L | Yes (can create $*L adjoint buffer)
852
859
// $L | $*A | No (cannot create $A adjoint value)
853
860
// $*A | $*A | Yes (no mismatch)
854
- if (!diagnosedActiveValueTangentValueCategoryIncompatible) {
855
- if (auto tanSpace = getTangentSpace (remapType (type).getASTType ())) {
856
- auto tanASTType = tanSpace->getCanonicalType ();
857
- auto &origTL = getTypeLowering (type.getASTType ());
858
- auto &tanTL = getTypeLowering (tanASTType);
859
- if (!origTL.isAddressOnly () && tanTL.isAddressOnly ()) {
860
- getContext ().emitNondifferentiabilityError (
861
- v, getInvoker (),
862
- diag::autodiff_loadable_value_addressonly_tangent_unsupported,
863
- type.getASTType (), tanASTType);
864
- diagnosedActiveValueTangentValueCategoryIncompatible = true ;
865
- errorOccurred = true ;
866
- }
861
+ if (auto tanSpace = getTangentSpace (remapType (type).getASTType ())) {
862
+ auto tanASTType = tanSpace->getCanonicalType ();
863
+ auto &origTL = getTypeLowering (type.getASTType ());
864
+ auto &tanTL = getTypeLowering (tanASTType);
865
+ if (!origTL.isAddressOnly () && tanTL.isAddressOnly ()) {
866
+ getContext ().emitNondifferentiabilityError (
867
+ v, getInvoker (),
868
+ diag::autodiff_loadable_value_addressonly_tangent_unsupported,
869
+ type.getASTType (), tanASTType);
870
+ errorOccurred = true ;
871
+ return true ;
867
872
}
868
873
}
869
874
// Do not emit remaining activity-related diagnostics for semantic member
870
875
// accessors, which have special-case pullback generation.
871
876
if (isSemanticMemberAccessor (&original))
872
- return ;
877
+ return false ;
873
878
// Diagnose active enum values. Differentiation of enum values requires
874
879
// special adjoint value handling and is not yet supported. Diagnose
875
880
// only the first active enum value to prevent too many diagnostics.
876
- if (!diagnosedActiveEnumValue && type.getEnumOrBoundGenericEnum ()) {
881
+ if (type.getEnumOrBoundGenericEnum ()) {
877
882
getContext ().emitNondifferentiabilityError (
878
883
v, getInvoker (), diag::autodiff_enums_unsupported);
879
884
errorOccurred = true ;
880
- diagnosedActiveEnumValue = true ;
885
+ return true ;
881
886
}
882
887
// Skip address projections.
883
888
// Address projections do not need their own adjoint buffers; they
884
889
// become projections into their adjoint base buffer.
885
890
if (Projection::isAddressProjection (v))
886
- return ;
891
+ return false ;
892
+ // Record active value.
887
893
bbActiveValues.push_back (v);
894
+ return false ;
888
895
};
889
- // Visit bb arguments and all instruction operands/results .
896
+ // Record all active values in the basic block .
890
897
for (auto *arg : bb->getArguments ())
891
- if (getActivityInfo (). isActive ( arg, getIndices () ))
892
- markValueActivity (arg) ;
898
+ if (recordValueIfActive ( arg))
899
+ return true ;
893
900
for (auto &inst : *bb) {
894
901
for (auto op : inst.getOperandValues ())
895
- if (getActivityInfo (). isActive (op, getIndices () ))
896
- markValueActivity (op) ;
902
+ if (recordValueIfActive (op ))
903
+ return true ;
897
904
for (auto result : inst.getResults ())
898
- if (getActivityInfo (). isActive ( result, getIndices () ))
899
- markValueActivity (result) ;
905
+ if (recordValueIfActive ( result))
906
+ return true ;
900
907
}
901
908
domOrder.pushChildren (bb);
902
- if (errorOccurred)
903
- return true ;
904
909
}
905
910
906
911
// Create pullback blocks and arguments, visiting original blocks in
0 commit comments