Skip to content

Commit 3a5e486

Browse files
dan-zhengbgogul
authored andcommitted
[AutoDiff] Revamp usefulness propagation in activity analysis. (#28225)
Useful values are those that contribute to (specific) dependent variables, i.e. function results. For addresses: all projections of a useful address should be useful. This has special support: `DifferentiableActivityInfo::propagateUsefulThroughAddress`. Previously: - Usefulness was propagated by iterating through all instructions in post-dominance order. This is not efficient because irrelevant instructions may be visited. - For useful addresses, `propagateUsefulThroughAddress` propagated usefulness one step to projections, but not recursively to users of the projections. This caused some values to incorrectly not be marked useful. Now: - Usefulness is propagated by following use-def chains, starting from dependent variables (function results). This is handled by the following helpers: - `setUsefulAndPropagateToOperands(SILValue, unsigned)`: marks a value as useful and recursively propagates usefulness through defining instruction operands and basic block argument incoming values. - `propagateUseful(SILInstruction *inst, unsigned)`: propagates usefulness to the operands of the given instruction. - `DifferentiableActivityInfo::propagateUsefulThroughAddress` now calls `propagateUseful` to propagate usefulness recursively through users' operands. Effects: - More values are now (correctly) marked as useful, affecting non-differentiability diagnostics for active enum values (TF-956) and for-in loops (TF-957). Both have room for improvement. Resolves control flow differentiation correctness issue: TF-954.
1 parent d0fb9fc commit 3a5e486

File tree

5 files changed

+128
-133
lines changed

5 files changed

+128
-133
lines changed

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 108 additions & 110 deletions
Original file line numberDiff line numberDiff line change
@@ -1433,7 +1433,7 @@ class DifferentiableActivityInfo {
14331433
/// Marks the given value as varied and propagates variedness to users.
14341434
void setVariedAndPropagateToUsers(SILValue value,
14351435
unsigned independentVariableIndex);
1436-
/// Propagates variedness for the given operand to its user's results.
1436+
/// Propagates variedness from the given operand to its user's results.
14371437
void propagateVaried(Operand *operand, unsigned independentVariableIndex);
14381438
/// Marks the given value as varied and recursively propagates variedness
14391439
/// inwards (to operands) through projections. Skips `@noDerivative` struct
@@ -1444,8 +1444,18 @@ class DifferentiableActivityInfo {
14441444
void setUseful(SILValue value, unsigned dependentVariableIndex);
14451445
void setUsefulAcrossArrayInitialization(SILValue value,
14461446
unsigned dependentVariableIndex);
1447-
void propagateUsefulThroughBuffer(SILValue value,
1448-
unsigned dependentVariableIndex);
1447+
/// Marks the given value as useful and recursively propagates usefulness to:
1448+
/// - Defining instruction operands, if the value has a defining instruction.
1449+
/// - Incoming values, if the value is a basic block argument.
1450+
void setUsefulAndPropagateToOperands(SILValue value,
1451+
unsigned dependentVariableIndex);
1452+
/// Propagates usefulnesss to the operands of the given instruction.
1453+
void propagateUseful(SILInstruction *inst, unsigned dependentVariableIndex);
1454+
/// Marks the given address as useful and recursively propagates usefulness
1455+
/// inwards (to operands) through projections. Skips `@noDerivative` struct
1456+
/// field projections.
1457+
void propagateUsefulThroughAddress(SILValue value,
1458+
unsigned dependentVariableIndex);
14491459

14501460
public:
14511461
explicit DifferentiableActivityInfo(
@@ -1975,6 +1985,71 @@ void DifferentiableActivityInfo::propagateVaried(
19751985
}
19761986
}
19771987

1988+
void DifferentiableActivityInfo::setUsefulAndPropagateToOperands(
1989+
SILValue value, unsigned dependentVariableIndex) {
1990+
// Skip already-useful values to prevent infinite recursion.
1991+
if (isUseful(value, dependentVariableIndex))
1992+
return;
1993+
if (value->getType().isAddress()) {
1994+
propagateUsefulThroughAddress(value, dependentVariableIndex);
1995+
return;
1996+
}
1997+
setUseful(value, dependentVariableIndex);
1998+
// If the given value is a basic block argument, propagate usefulness to
1999+
// incoming values.
2000+
if (auto *bbArg = dyn_cast<SILPhiArgument>(value)) {
2001+
SmallVector<SILValue, 4> incomingValues;
2002+
bbArg->getSingleTerminatorOperands(incomingValues);
2003+
for (auto incomingValue : incomingValues)
2004+
setUsefulAndPropagateToOperands(incomingValue, dependentVariableIndex);
2005+
return;
2006+
}
2007+
auto *inst = value->getDefiningInstruction();
2008+
if (!inst)
2009+
return;
2010+
propagateUseful(inst, dependentVariableIndex);
2011+
}
2012+
2013+
void DifferentiableActivityInfo::propagateUseful(
2014+
SILInstruction *inst, unsigned dependentVariableIndex) {
2015+
// Propagate usefulness for the given instruction: mark operands as useful and
2016+
// recursively propagate usefulness to defining instructions of operands.
2017+
auto i = dependentVariableIndex;
2018+
// Handle indirect results in `apply`.
2019+
if (auto *ai = dyn_cast<ApplyInst>(inst)) {
2020+
if (isWithoutDerivative(ai->getCallee()))
2021+
return;
2022+
for (auto arg : ai->getArgumentsWithoutIndirectResults())
2023+
setUsefulAndPropagateToOperands(arg, i);
2024+
}
2025+
// Handle store-like instructions:
2026+
// `store`, `store_borrow`, `copy_addr`, `unconditional_checked_cast`
2027+
#define PROPAGATE_USEFUL_THROUGH_STORE(INST) \
2028+
else if (auto *si = dyn_cast<INST##Inst>(inst)) { \
2029+
setUsefulAndPropagateToOperands(si->getSrc(), i); \
2030+
}
2031+
PROPAGATE_USEFUL_THROUGH_STORE(Store)
2032+
PROPAGATE_USEFUL_THROUGH_STORE(StoreBorrow)
2033+
PROPAGATE_USEFUL_THROUGH_STORE(CopyAddr)
2034+
PROPAGATE_USEFUL_THROUGH_STORE(UnconditionalCheckedCastAddr)
2035+
#undef PROPAGATE_USEFUL_THROUGH_STORE
2036+
// Handle struct element extraction, skipping `@noDerivative` fields:
2037+
// `struct_extract`, `struct_element_addr`.
2038+
#define PROPAGATE_USEFUL_THROUGH_STRUCT_EXTRACTION(INST) \
2039+
else if (auto *sei = dyn_cast<INST##Inst>(inst)) { \
2040+
if (!sei->getField()->getAttrs().hasAttribute<NoDerivativeAttr>()) \
2041+
setUsefulAndPropagateToOperands(sei->getOperand(), i); \
2042+
}
2043+
PROPAGATE_USEFUL_THROUGH_STRUCT_EXTRACTION(StructExtract)
2044+
PROPAGATE_USEFUL_THROUGH_STRUCT_EXTRACTION(StructElementAddr)
2045+
#undef PROPAGATE_USEFUL_THROUGH_STRUCT_EXTRACTION
2046+
// Handle everything else.
2047+
else {
2048+
for (auto &op : inst->getAllOperands())
2049+
setUsefulAndPropagateToOperands(op.get(), i);
2050+
}
2051+
}
2052+
19782053
void DifferentiableActivityInfo::analyze(DominanceInfo *di,
19792054
PostDominanceInfo *pdi) {
19802055
auto &function = getFunction();
@@ -2010,117 +2085,40 @@ void DifferentiableActivityInfo::analyze(DominanceInfo *di,
20102085

20112086
// Mark differentiable outputs as useful.
20122087
assert(usefulValueSets.empty());
2013-
for (auto output : outputValues) {
2088+
for (auto outputAndIdx : enumerate(outputValues)) {
2089+
auto output = outputAndIdx.value();
2090+
unsigned i = outputAndIdx.index();
20142091
usefulValueSets.push_back({});
2015-
// If the output has an address or class type, propagate usefulness
2016-
// recursively.
2017-
if (output->getType().isAddress() ||
2018-
output->getType().isClassOrClassMetatype())
2019-
propagateUsefulThroughBuffer(output, usefulValueSets.size() - 1);
2020-
// Otherwise, just mark the output as useful.
2021-
else
2022-
setUseful(output, usefulValueSets.size() - 1);
2023-
}
2024-
// Propagate usefulness through the function in post-dominance order.
2025-
PostDominanceOrder postDomOrder(&*function.findReturnBB(), pdi);
2026-
while (auto *bb = postDomOrder.getNext()) {
2027-
for (auto &inst : llvm::reverse(*bb)) {
2028-
for (auto i : indices(outputValues)) {
2029-
// Handle indirect results in `apply`.
2030-
if (auto *ai = dyn_cast<ApplyInst>(&inst)) {
2031-
if (isWithoutDerivative(ai->getCallee()))
2032-
continue;
2033-
auto checkAndSetUseful = [&](SILValue res) {
2034-
if (isUseful(res, i))
2035-
for (auto arg : ai->getArgumentsWithoutIndirectResults())
2036-
setUseful(arg, i);
2037-
};
2038-
for (auto dirRes : ai->getResults())
2039-
checkAndSetUseful(dirRes);
2040-
for (auto indRes : ai->getIndirectSILResults())
2041-
checkAndSetUseful(indRes);
2042-
auto paramInfos = ai->getSubstCalleeConv().getParameters();
2043-
for (auto i : indices(paramInfos))
2044-
if (paramInfos[i].isIndirectInOut())
2045-
checkAndSetUseful(ai->getArgumentsWithoutIndirectResults()[i]);
2046-
}
2047-
// Handle store-like instructions:
2048-
// `store`, `store_borrow`, `copy_addr`, `unconditional_checked_cast`
2049-
#define PROPAGATE_USEFUL_THROUGH_STORE(INST, PROPAGATE) \
2050-
else if (auto *si = dyn_cast<INST##Inst>(&inst)) { \
2051-
if (isUseful(si->getDest(), i)) \
2052-
PROPAGATE(si->getSrc(), i); \
2053-
}
2054-
PROPAGATE_USEFUL_THROUGH_STORE(Store, setUseful)
2055-
PROPAGATE_USEFUL_THROUGH_STORE(StoreBorrow, setUseful)
2056-
PROPAGATE_USEFUL_THROUGH_STORE(CopyAddr, propagateUsefulThroughBuffer)
2057-
PROPAGATE_USEFUL_THROUGH_STORE(UnconditionalCheckedCastAddr,
2058-
propagateUsefulThroughBuffer)
2059-
#undef PROPAGATE_USEFUL_THROUGH_STORE
2060-
// Handle struct element extraction, skipping `@noDerivative` fields:
2061-
// `struct_extract`, `struct_element_addr`.
2062-
#define PROPAGATE_USEFUL_THROUGH_STRUCT_EXTRACTION(INST, PROPAGATE) \
2063-
else if (auto *sei = dyn_cast<INST##Inst>(&inst)) { \
2064-
if (isUseful(sei, i)) { \
2065-
auto hasNoDeriv = sei->getField()->getAttrs() \
2066-
.hasAttribute<NoDerivativeAttr>(); \
2067-
if (!hasNoDeriv) \
2068-
PROPAGATE(sei->getOperand(), i); \
2069-
} \
2070-
}
2071-
PROPAGATE_USEFUL_THROUGH_STRUCT_EXTRACTION(StructExtract, setUseful)
2072-
PROPAGATE_USEFUL_THROUGH_STRUCT_EXTRACTION(StructElementAddr,
2073-
propagateUsefulThroughBuffer)
2074-
#undef PROPAGATE_USEFUL_THROUGH_STRUCT_EXTRACTION
2075-
// Handle everything else.
2076-
else if (llvm::any_of(inst.getResults(),
2077-
[&](SILValue res) { return isUseful(res, i); })) {
2078-
for (auto &op : inst.getAllOperands()) {
2079-
auto value = op.get();
2080-
if (value->getType().isAddress())
2081-
propagateUsefulThroughBuffer(value, i);
2082-
setUseful(value, i);
2083-
}
2084-
}
2085-
}
2086-
}
2087-
// Propagate usefulness from basic block arguments to incoming phi values.
2088-
for (auto i : indices(outputValues)) {
2089-
for (auto *arg : bb->getArguments()) {
2090-
if (isUseful(arg, i)) {
2091-
SmallVector<SILValue, 4> incomingValues;
2092-
arg->getSingleTerminatorOperands(incomingValues);
2093-
for (auto incomingValue : incomingValues)
2094-
setUseful(incomingValue, i);
2095-
}
2096-
}
2097-
}
2098-
postDomOrder.pushChildren(bb);
2092+
setUsefulAndPropagateToOperands(output, i);
20992093
}
21002094
}
21012095

21022096
void DifferentiableActivityInfo::setUsefulAcrossArrayInitialization(
21032097
SILValue value, unsigned dependentVariableIndex) {
21042098
// Array initializer syntax is lowered to an intrinsic and one or more
21052099
// stores to a `RawPointer` returned by the intrinsic.
2106-
auto uai = getAllocateUninitializedArrayIntrinsic(value);
2100+
auto *uai = getAllocateUninitializedArrayIntrinsic(value);
21072101
if (!uai) return;
21082102
for (auto use : value->getUses()) {
2109-
auto dti = dyn_cast<DestructureTupleInst>(use->getUser());
2103+
auto *dti = dyn_cast<DestructureTupleInst>(use->getUser());
21102104
if (!dti) continue;
21112105
// The second tuple field of the return value is the `RawPointer`.
21122106
for (auto use : dti->getResult(1)->getUses()) {
21132107
// The `RawPointer` passes through a `pointer_to_address`. That
2114-
// instruction's first use is a `store` whose src is useful; its
2108+
// instruction's first use is a `store` whose source is useful; its
21152109
// subsequent uses are `index_addr`s whose only use is a useful `store`.
2116-
for (auto use : use->getUser()->getResult(0)->getUses()) {
2117-
auto inst = use->getUser();
2118-
if (auto si = dyn_cast<StoreInst>(inst)) {
2119-
setUseful(si->getSrc(), dependentVariableIndex);
2120-
} else if (auto iai = dyn_cast<IndexAddrInst>(inst)) {
2110+
auto *ptai = dyn_cast<PointerToAddressInst>(use->getUser());
2111+
assert(ptai && "Expected `pointer_to_address` user for uninitialized "
2112+
"array intrinsic");
2113+
for (auto use : ptai->getUses()) {
2114+
auto *inst = use->getUser();
2115+
if (auto *si = dyn_cast<StoreInst>(inst)) {
2116+
setUsefulAndPropagateToOperands(si->getSrc(), dependentVariableIndex);
2117+
} else if (auto *iai = dyn_cast<IndexAddrInst>(inst)) {
21212118
for (auto use : iai->getUses())
21222119
if (auto si = dyn_cast<StoreInst>(use->getUser()))
2123-
setUseful(si->getSrc(), dependentVariableIndex);
2120+
setUsefulAndPropagateToOperands(si->getSrc(),
2121+
dependentVariableIndex);
21242122
}
21252123
}
21262124
}
@@ -2154,21 +2152,20 @@ void DifferentiableActivityInfo::propagateVariedInwardsThroughProjections(
21542152
op.get(), independentVariableIndex);
21552153
}
21562154

2157-
void DifferentiableActivityInfo::propagateUsefulThroughBuffer(
2155+
void DifferentiableActivityInfo::propagateUsefulThroughAddress(
21582156
SILValue value, unsigned dependentVariableIndex) {
2159-
assert(value->getType().isAddress() ||
2160-
value->getType().isClassOrClassMetatype());
2157+
assert(value->getType().isAddress());
21612158
// Check whether value is already useful to prevent infinite recursion.
21622159
if (isUseful(value, dependentVariableIndex))
21632160
return;
21642161
setUseful(value, dependentVariableIndex);
21652162
if (auto *inst = value->getDefiningInstruction())
2166-
for (auto &operand : inst->getAllOperands())
2167-
if (operand.get()->getType().isAddress())
2168-
propagateUsefulThroughBuffer(operand.get(), dependentVariableIndex);
2163+
propagateUseful(inst, dependentVariableIndex);
21692164
// Recursively propagate usefulness through users that are projections or
21702165
// `begin_access` instructions.
21712166
for (auto use : value->getUses()) {
2167+
// Propagate usefulness through user's operands.
2168+
propagateUseful(use->getUser(), dependentVariableIndex);
21722169
for (auto res : use->getUser()->getResults()) {
21732170
#define SKIP_NODERIVATIVE(INST) \
21742171
if (auto *sei = dyn_cast<INST##Inst>(res)) \
@@ -2178,7 +2175,7 @@ void DifferentiableActivityInfo::propagateUsefulThroughBuffer(
21782175
SKIP_NODERIVATIVE(StructElementAddr)
21792176
#undef SKIP_NODERIVATIVE
21802177
if (Projection::isAddressProjection(res) || isa<BeginAccessInst>(res))
2181-
propagateUsefulThroughBuffer(res, dependentVariableIndex);
2178+
propagateUsefulThroughAddress(res, dependentVariableIndex);
21822179
}
21832180
}
21842181
}
@@ -6219,15 +6216,19 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
62196216
SmallPtrSet<SILValue, 8> visited(bbActiveValues.begin(),
62206217
bbActiveValues.end());
62216218
// Register a value as active if it has not yet been visited.
6219+
bool diagnosedActiveEnumValue = false;
62226220
auto addActiveValue = [&](SILValue v) {
62236221
if (visited.count(v))
62246222
return;
6225-
// Diagnose active enum values. Differentiation of enum values is not
6226-
// yet supported; requires special adjoint value handling.
6227-
if (v->getType().getEnumOrBoundGenericEnum()) {
6223+
// Diagnose active enum values. Differentiation of enum values requires
6224+
// special adjoint value handling and is not yet supported. Diagnose
6225+
// only the first active enum value to prevent too many diagnostics.
6226+
if (!diagnosedActiveEnumValue &&
6227+
v->getType().getEnumOrBoundGenericEnum()) {
62286228
getContext().emitNondifferentiabilityError(
62296229
v, getInvoker(), diag::autodiff_enums_unsupported);
62306230
errorOccurred = true;
6231+
diagnosedActiveEnumValue = true;
62316232
}
62326233
// Skip address projections.
62336234
// Address projections do not need their own adjoint buffers; they
@@ -6238,9 +6239,6 @@ class PullbackEmitter final : public SILInstructionVisitor<PullbackEmitter> {
62386239
bbActiveValues.push_back(v);
62396240
};
62406241
// Register bb arguments and all instruction operands/results.
6241-
for (auto *arg : bb->getArguments())
6242-
if (getActivityInfo().isActive(arg, getIndices()))
6243-
addActiveValue(arg);
62446242
for (auto &inst : *bb) {
62456243
for (auto op : inst.getOperandValues())
62466244
if (getActivityInfo().isActive(op, getIndices()))

test/AutoDiff/activity_analysis.swift

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -140,12 +140,12 @@ func TF_954(_ x: Float) -> Float {
140140
// CHECK: bb1:
141141
// CHECK: [ACTIVE] %10 = alloc_stack $Float, var, name "inner"
142142
// CHECK: [ACTIVE] %11 = begin_access [read] [static] %2 : $*Float
143-
// CHECK: [NONE] %14 = metatype $@thin Float.Type
143+
// CHECK: [USEFUL] %14 = metatype $@thin Float.Type
144144
// CHECK: [ACTIVE] %15 = begin_access [read] [static] %10 : $*Float
145-
// CHECK: [VARIED] %16 = load [trivial] %15 : $*Float
145+
// CHECK: [ACTIVE] %16 = load [trivial] %15 : $*Float
146146
// CHECK: [NONE] // function_ref static Float.* infix(_:_:)
147147
// CHECK: %18 = function_ref @$sSf1moiyS2f_SftFZ : $@convention(method) (Float, Float, @thin Float.Type) -> Float
148-
// CHECK: [VARIED] %19 = apply %18(%16, %0, %14) : $@convention(method) (Float, Float, @thin Float.Type) -> Float
148+
// CHECK: [ACTIVE] %19 = apply %18(%16, %0, %14) : $@convention(method) (Float, Float, @thin Float.Type) -> Float
149149
// CHECK: [ACTIVE] %20 = begin_access [modify] [static] %10 : $*Float
150150
// CHECK: bb3:
151151
// CHECK: [ACTIVE] %31 = begin_access [read] [static] %10 : $*Float

test/AutoDiff/control_flow.swift

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -68,10 +68,7 @@ ControlFlowTests.test("Conditionals") {
6868
}
6969
return outer
7070
}
71-
// FIXME(TF-954): Investigate incorrect derivative related to addresses and
72-
// nested control flow.
73-
// expectEqual((9, 6), valueWithGradient(at: 3, in: cond4_var))
74-
expectEqual((9, 0), valueWithGradient(at: 3, in: cond4_var))
71+
expectEqual((9, 6), valueWithGradient(at: 3, in: cond4_var))
7572

7673
func cond_tuple(_ x: Float) -> Float {
7774
// Convoluted function returning `x + x`.
@@ -707,16 +704,10 @@ ControlFlowTests.test("Loops") {
707704
}
708705
return outer
709706
}
710-
// FIXME(TF-954): Investigate incorrect derivative related to addresses and
711-
// nested control flow.
712-
// expectEqual((6, 5), valueWithGradient(at: 2, in: { x in nested_loop2(x, count: 1) }))
713-
// expectEqual((20, 22), valueWithGradient(at: 2, in: { x in nested_loop2(x, count: 2) }))
714-
// expectEqual((52, 80), valueWithGradient(at: 2, in: { x in nested_loop2(x, count: 3) }))
715-
// expectEqual((24, 28), valueWithGradient(at: 2, in: { x in nested_loop2(x, count: 4) }))
716-
expectEqual((6, 0), valueWithGradient(at: 2, in: { x in nested_loop2(x, count: 1) }))
717-
expectEqual((20, 0), valueWithGradient(at: 2, in: { x in nested_loop2(x, count: 2) }))
718-
expectEqual((52, 26), valueWithGradient(at: 2, in: { x in nested_loop2(x, count: 3) }))
719-
expectEqual((24, 12), valueWithGradient(at: 2, in: { x in nested_loop2(x, count: 4) }))
707+
expectEqual((6, 5), valueWithGradient(at: 2, in: { x in nested_loop2(x, count: 1) }))
708+
expectEqual((20, 22), valueWithGradient(at: 2, in: { x in nested_loop2(x, count: 2) }))
709+
expectEqual((52, 80), valueWithGradient(at: 2, in: { x in nested_loop2(x, count: 3) }))
710+
expectEqual((24, 28), valueWithGradient(at: 2, in: { x in nested_loop2(x, count: 4) }))
720711
}
721712

722713
runAllTests()

test/AutoDiff/control_flow_diagnostics.swift

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -113,10 +113,12 @@ enum Tree : Differentiable & AdditiveArithmetic {
113113

114114
// expected-error @+1 {{function is not differentiable}}
115115
@differentiable
116-
// expected-note @+1 {{when differentiating this function definition}}
116+
// TODO(TF-956): Improve location of active enum non-differentiability errors
117+
// so that they are closer to the source of the non-differentiability.
118+
// expected-note @+2 {{when differentiating this function definition}}
119+
// expected-note @+1 {{differentiating enum values is not yet supported}}
117120
static func +(_ lhs: Self, _ rhs: Self) -> Self {
118121
switch (lhs, rhs) {
119-
// expected-note @+1 {{differentiating enum values is not yet supported}}
120122
case let (.leaf(x), .leaf(y)):
121123
return .leaf(x + y)
122124
case let (.branch(x1, x2), .branch(y1, y2)):
@@ -128,10 +130,12 @@ enum Tree : Differentiable & AdditiveArithmetic {
128130

129131
// expected-error @+1 {{function is not differentiable}}
130132
@differentiable
131-
// expected-note @+1 {{when differentiating this function definition}}
133+
// TODO(TF-956): Improve location of active enum non-differentiability errors
134+
// so that they are closer to the source of the non-differentiability.
135+
// expected-note @+2 {{when differentiating this function definition}}
136+
// expected-note @+1 {{differentiating enum values is not yet supported}}
132137
static func -(_ lhs: Self, _ rhs: Self) -> Self {
133138
switch (lhs, rhs) {
134-
// expected-note @+1 {{differentiating enum values is not yet supported}}
135139
case let (.leaf(x), .leaf(y)):
136140
return .leaf(x - y)
137141
case let (.branch(x1, x2), .branch(y1, y2)):
@@ -147,7 +151,9 @@ enum Tree : Differentiable & AdditiveArithmetic {
147151
// expected-note @+1 {{when differentiating this function definition}}
148152
func loop_array(_ array: [Float]) -> Float {
149153
var result: Float = 1
150-
// expected-note @+1 {{differentiating enum values is not yet supported}}
154+
// TODO(TF-957): Improve non-differentiability errors for for-in loops
155+
// (`Collection.makeIterator` and `IteratorProtocol.next`).
156+
// expected-note @+1 {{cannot differentiate through a non-differentiable result; do you want to use 'withoutDerivative(at:)'?}}
151157
for x in array {
152158
result = result * x
153159
}

0 commit comments

Comments
 (0)