Skip to content

Commit 1052d3c

Browse files
committed
[AutoDiff] Simplify basic block active value collection.
Make `recordValueIfActive` short-circuit, returning true on error. Remove booleans tracking whether diagnostics have been emitted.
1 parent 1902ff2 commit 1052d3c

File tree

1 file changed

+39
-34
lines changed

1 file changed

+39
-34
lines changed

lib/SILOptimizer/Differentiation/PullbackEmitter.cpp

Lines changed: 39 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -797,6 +797,7 @@ bool PullbackEmitter::run() {
797797
"Functions without returns must have been diagnosed");
798798
auto *origExit = &*origExitIt;
799799

800+
// Collect original formal results.
800801
SmallVector<SILValue, 8> origFormalResults;
801802
collectAllFormalResultsInTypeOrder(original, origFormalResults);
802803
for (auto resultIndex : getIndices().results->getIndices()) {
@@ -815,7 +816,7 @@ bool PullbackEmitter::run() {
815816
}
816817
}
817818

818-
// Get dominated active values in original blocks.
819+
// Collect dominated active values in original basic blocks.
819820
// Adjoint values of dominated active values are passed as pullback block
820821
// arguments.
821822
DominanceOrder domOrder(original.getEntryBlock(), domInfo);
@@ -829,15 +830,21 @@ bool PullbackEmitter::run() {
829830
auto &domBBActiveValues = activeValues[domNode->getBlock()];
830831
bbActiveValues.append(domBBActiveValues.begin(), domBBActiveValues.end());
831832
}
832-
// Booleans tracking whether active-value-related errors have been emitted.
833-
// This prevents duplicate diagnostics for the same active values.
834-
bool diagnosedActiveEnumValue = false;
835-
bool diagnosedActiveValueTangentValueCategoryIncompatible = false;
836-
// Mark the activity of a value if it has not yet been visited.
837-
auto markValueActivity = [&](SILValue v) {
833+
// If `v` is active and has not been visited, records it as an active value
834+
// in the original basic block.
835+
// For active values unsupported by differentiation, emits a diagnostic and
836+
// returns true. Otherwise, returns false.
837+
auto recordValueIfActive = [&](SILValue v) -> bool {
838+
// If value is not active, skip.
839+
if (!getActivityInfo().isActive(v, getIndices()))
840+
return false;
841+
// If active value has already been visited, skip.
838842
if (visited.count(v))
839-
return;
843+
return false;
844+
// Mark active value as visited.
840845
visited.insert(v);
846+
847+
// Diagnose unsupported active values.
841848
auto type = v->getType();
842849
// Diagnose active values whose value category is incompatible with their
843850
// tangent types's value category.
@@ -851,56 +858,54 @@ bool PullbackEmitter::run() {
851858
// $*A | $L | Yes (can create $*L adjoint buffer)
852859
// $L | $*A | No (cannot create $A adjoint value)
853860
// $*A | $*A | Yes (no mismatch)
854-
if (!diagnosedActiveValueTangentValueCategoryIncompatible) {
855-
if (auto tanSpace = getTangentSpace(remapType(type).getASTType())) {
856-
auto tanASTType = tanSpace->getCanonicalType();
857-
auto &origTL = getTypeLowering(type.getASTType());
858-
auto &tanTL = getTypeLowering(tanASTType);
859-
if (!origTL.isAddressOnly() && tanTL.isAddressOnly()) {
860-
getContext().emitNondifferentiabilityError(
861-
v, getInvoker(),
862-
diag::autodiff_loadable_value_addressonly_tangent_unsupported,
863-
type.getASTType(), tanASTType);
864-
diagnosedActiveValueTangentValueCategoryIncompatible = true;
865-
errorOccurred = true;
866-
}
861+
if (auto tanSpace = getTangentSpace(remapType(type).getASTType())) {
862+
auto tanASTType = tanSpace->getCanonicalType();
863+
auto &origTL = getTypeLowering(type.getASTType());
864+
auto &tanTL = getTypeLowering(tanASTType);
865+
if (!origTL.isAddressOnly() && tanTL.isAddressOnly()) {
866+
getContext().emitNondifferentiabilityError(
867+
v, getInvoker(),
868+
diag::autodiff_loadable_value_addressonly_tangent_unsupported,
869+
type.getASTType(), tanASTType);
870+
errorOccurred = true;
871+
return true;
867872
}
868873
}
869874
// Do not emit remaining activity-related diagnostics for semantic member
870875
// accessors, which have special-case pullback generation.
871876
if (isSemanticMemberAccessor(&original))
872-
return;
877+
return false;
873878
// Diagnose active enum values. Differentiation of enum values requires
874879
// special adjoint value handling and is not yet supported. Diagnose
875880
// only the first active enum value to prevent too many diagnostics.
876-
if (!diagnosedActiveEnumValue && type.getEnumOrBoundGenericEnum()) {
881+
if (type.getEnumOrBoundGenericEnum()) {
877882
getContext().emitNondifferentiabilityError(
878883
v, getInvoker(), diag::autodiff_enums_unsupported);
879884
errorOccurred = true;
880-
diagnosedActiveEnumValue = true;
885+
return true;
881886
}
882887
// Skip address projections.
883888
// Address projections do not need their own adjoint buffers; they
884889
// become projections into their adjoint base buffer.
885890
if (Projection::isAddressProjection(v))
886-
return;
891+
return false;
892+
// Record active value.
887893
bbActiveValues.push_back(v);
894+
return false;
888895
};
889-
// Visit bb arguments and all instruction operands/results.
896+
// Record all active values in the basic block.
890897
for (auto *arg : bb->getArguments())
891-
if (getActivityInfo().isActive(arg, getIndices()))
892-
markValueActivity(arg);
898+
if (recordValueIfActive(arg))
899+
return true;
893900
for (auto &inst : *bb) {
894901
for (auto op : inst.getOperandValues())
895-
if (getActivityInfo().isActive(op, getIndices()))
896-
markValueActivity(op);
902+
if (recordValueIfActive(op))
903+
return true;
897904
for (auto result : inst.getResults())
898-
if (getActivityInfo().isActive(result, getIndices()))
899-
markValueActivity(result);
905+
if (recordValueIfActive(result))
906+
return true;
900907
}
901908
domOrder.pushChildren(bb);
902-
if (errorOccurred)
903-
return true;
904909
}
905910

906911
// Create pullback blocks and arguments, visiting original blocks in

0 commit comments

Comments
 (0)