Skip to content

Commit bb67311

Browse files
dan-zhengrxwei
authored andcommitted
[AutoDiff] Diagnose unsupported forward-mode control flow. (#27684)
Diagnose unsupported forward-mode control flow instead of crashing.
1 parent 740b63e commit bb67311

File tree

3 files changed

+26
-0
lines changed

3 files changed

+26
-0
lines changed

include/swift/AST/DiagnosticsSIL.def

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -513,6 +513,8 @@ NOTE(autodiff_class_member_not_supported,none,
513513
NOTE(autodiff_cannot_param_subset_thunk_partially_applied_orig_fn,none,
514514
"cannot convert a direct method reference to a '@differentiable' "
515515
"function; use an explicit closure instead", ())
516+
NOTE(autodiff_jvp_control_flow_not_supported,none,
517+
"forward-mode differentiation does not yet support control flow", ())
516518
NOTE(autodiff_control_flow_not_supported,none,
517519
"cannot differentiate unsupported control flow", ())
518520
// TODO(TF-645): Remove when differentiation supports `ref_element_addr`.

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8030,6 +8030,15 @@ bool ADContext::processDifferentiableAttribute(
80308030
// generation because generated JVP may not match semantics of custom VJP.
80318031
// Instead, create an empty JVP.
80328032
if (RunJVPGeneration && !vjp) {
8033+
// JVP and differential generation do not currently support functions with
8034+
// multiple basic blocks.
8035+
if (original->getBlocks().size() > 1) {
8036+
emitNondifferentiabilityError(
8037+
original->getLocation().getSourceLoc(), invoker,
8038+
diag::autodiff_jvp_control_flow_not_supported);
8039+
return true;
8040+
}
8041+
80338042
JVPEmitter emitter(*this, original, attr, jvp, invoker);
80348043
if (emitter.run())
80358044
return true;

test/AutoDiff/forward_mode_diagnostics.swift

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,3 +94,18 @@ func nondiff(_ f: @differentiable (Float, @nondiff Float) -> Float) -> Float {
9494
// expected-error @+1 {{function is not differentiable}}
9595
return derivative(at: 2, 3) { (x, y) in f(x * x, y) }
9696
}
97+
98+
//===----------------------------------------------------------------------===//
99+
// Control flow
100+
//===----------------------------------------------------------------------===//
101+
102+
// expected-error @+1 {{function is not differentiable}}
103+
@differentiable
104+
// expected-note @+2 {{when differentiating this function definition}}
105+
// expected-note @+1 {{forward-mode differentiation does not yet support control flow}}
106+
func cond(_ x: Float) -> Float {
107+
if x > 0 {
108+
return x * x
109+
}
110+
return x + x
111+
}

0 commit comments

Comments
 (0)