Skip to content

Commit ad54102

Browse files
authored
[AutoDiff] Improve diagnostic for multiple active results. (#28429)
Improve diagnostic message for multiple active results. Add diagnostic tests for multiple active results and active `inout` arguments. Exposes TF-984: forward-mode crash due to unset tangent buffer. Exposes TF-985: no workaround to avoid marking non-wrt `inout` arguments as active.
1 parent 9775a1a commit ad54102

File tree

5 files changed

+240
-24
lines changed

5 files changed

+240
-24
lines changed

include/swift/AST/DiagnosticsSIL.def

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -524,6 +524,8 @@ NOTE(autodiff_when_differentiating_function_definition,none,
524524
"when differentiating this function definition", ())
525525
NOTE(autodiff_cannot_differentiate_through_inout_arguments,none,
526526
"cannot differentiate through 'inout' arguments", ())
527+
NOTE(autodiff_cannot_differentiate_through_multiple_results,none,
528+
"cannot differentiate through multiple results", ())
527529
NOTE(autodiff_class_member_not_supported,none,
528530
"differentiating class members is not yet supported", ())
529531
// TODO(TF-642): Remove when `partial_apply` works with `@differentiable`

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 12 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1579,20 +1579,6 @@ class DifferentiableActivityInfo {
15791579
const SILAutoDiffIndices &indices) const;
15801580
};
15811581

1582-
/// Given a parameter argument (not indirect result) and some differentiation
1583-
/// indices, figure out whether the parent function is being differentiated with
1584-
/// respect to this parameter, according to the indices.
1585-
static bool isDifferentiationParameter(SILArgument *argument,
1586-
IndexSubset *indices) {
1587-
if (!argument) return false;
1588-
auto *function = argument->getFunction();
1589-
auto paramArgs = function->getArgumentsWithoutIndirectResults();
1590-
for (unsigned i : indices->getIndices())
1591-
if (paramArgs[i] == argument)
1592-
return true;
1593-
return false;
1594-
}
1595-
15961582
/// For an `apply` instruction with active results, compute:
15971583
/// - The results of the `apply` instruction, in type order.
15981584
/// - The set of minimal parameter and result indices for differentiating the
@@ -1609,9 +1595,7 @@ static void collectMinimalIndicesForFunctionCall(
16091595
// Record all parameter indices in type order.
16101596
unsigned currentParamIdx = 0;
16111597
for (auto applyArg : ai->getArgumentsWithoutIndirectResults()) {
1612-
if (activityInfo.isVaried(applyArg, parentIndices.parameters) ||
1613-
isDifferentiationParameter(dyn_cast<SILArgument>(applyArg),
1614-
parentIndices.parameters))
1598+
if (activityInfo.isActive(applyArg, parentIndices))
16151599
paramIndices.push_back(currentParamIdx);
16161600
++currentParamIdx;
16171601
}
@@ -1632,13 +1616,12 @@ static void collectMinimalIndicesForFunctionCall(
16321616
if (res.isFormalDirect()) {
16331617
results.push_back(directResults[dirResIdx]);
16341618
if (auto dirRes = directResults[dirResIdx])
1635-
if (dirRes && activityInfo.isUseful(dirRes, parentIndices.source))
1619+
if (dirRes && activityInfo.isActive(dirRes, parentIndices))
16361620
resultIndices.push_back(idx);
16371621
++dirResIdx;
16381622
} else {
16391623
results.push_back(indirectResults[indResIdx]);
1640-
if (activityInfo.isUseful(indirectResults[indResIdx],
1641-
parentIndices.source))
1624+
if (activityInfo.isActive(indirectResults[indResIdx], parentIndices))
16421625
resultIndices.push_back(idx);
16431626
++indResIdx;
16441627
}
@@ -3867,10 +3850,12 @@ class VJPEmitter final
38673850
activeResultIndices.begin(), activeResultIndices.end(),
38683851
[&s](unsigned i) { s << i; }, [&s] { s << ", "; });
38693852
s << "}\n";);
3870-
// FIXME: We don't support multiple active results yet.
3853+
// Diagnose multiple active results.
3854+
// TODO(TF-983): Support multiple active results.
38713855
if (activeResultIndices.size() > 1) {
38723856
context.emitNondifferentiabilityError(
3873-
ai, invoker, diag::autodiff_expression_not_differentiable_note);
3857+
ai, invoker,
3858+
diag::autodiff_cannot_differentiate_through_multiple_results);
38743859
errorOccurred = true;
38753860
return;
38763861
}
@@ -5546,13 +5531,16 @@ class JVPEmitter final
55465531
activeResultIndices.begin(), activeResultIndices.end(),
55475532
[&s](unsigned i) { s << i; }, [&s] { s << ", "; });
55485533
s << "}\n";);
5549-
// FIXME: We don't support multiple active results yet.
5534+
// Diagnose multiple active results.
5535+
// TODO(TF-983): Support multiple active results.
55505536
if (activeResultIndices.size() > 1) {
55515537
context.emitNondifferentiabilityError(
5552-
ai, invoker, diag::autodiff_expression_not_differentiable_note);
5538+
ai, invoker,
5539+
diag::autodiff_cannot_differentiate_through_multiple_results);
55535540
errorOccurred = true;
55545541
return;
55555542
}
5543+
55565544
// Form expected indices, assuming there's only one result.
55575545
SILAutoDiffIndices indices(
55585546
activeResultIndices.front(),

test/AutoDiff/activity_analysis.swift

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,106 @@ func TF_954(_ x: Float) -> Float {
205205
// CHECK: [ACTIVE] %40 = begin_access [read] [static] %2 : $*Float
206206
// CHECK: [ACTIVE] %41 = load [trivial] %40 : $*Float
207207

208+
// Check `inout` argument activity.
209+
210+
struct Mut: Differentiable {}
211+
extension Mut {
212+
@differentiable(wrt: x)
213+
mutating func mutatingMethod(_ x: Mut) -> Mut {
214+
return x
215+
}
216+
}
217+
218+
// CHECK-LABEL: [AD] Activity info for $s17activity_analysis3MutV14mutatingMethodyA2CF at (source=0 parameters=(0))
219+
// CHECK: [ACTIVE] %0 = argument of bb0 : $Mut
220+
// CHECK: [NONE] %1 = argument of bb0 : $*Mut
221+
222+
// TODO(TF-985): Find workaround to avoid marking non-wrt `inout` argument as
223+
// active.
224+
// expected-error @+1 {{function is not differentiable}}
225+
@differentiable(wrt: x)
226+
// expected-note @+1 {{when differentiating this function definition}}
227+
func nonActiveInoutArg(_ nonactive: inout Mut, _ x: Mut) -> Mut {
228+
// expected-note @+1 {{cannot differentiate through 'inout' arguments}}
229+
return nonactive.mutatingMethod(x)
230+
}
231+
232+
// CHECK-LABEL: [AD] Activity info for $s17activity_analysis17nonActiveInoutArgyAA3MutVADz_ADtF at (source=0 parameters=(1))
233+
// CHECK: [ACTIVE] %0 = argument of bb0 : $*Mut
234+
// CHECK: [ACTIVE] %1 = argument of bb0 : $Mut
235+
// CHECK: [ACTIVE] %4 = begin_access [modify] [static] %0 : $*Mut
236+
// CHECK: [NONE] // function_ref Mut.mutatingMethod(_:)
237+
// CHECK: [ACTIVE] %6 = apply %5(%1, %4) : $@convention(method) (Mut, @inout Mut) -> Mut
238+
239+
// expected-error @+1 {{function is not differentiable}}
240+
@differentiable(wrt: x)
241+
// expected-note @+1 {{when differentiating this function definition}}
242+
func activeInoutArgMutatingMethod(_ x: Mut) -> Mut {
243+
var result = x
244+
// expected-note @+1 {{cannot differentiate through 'inout' arguments}}
245+
result = result.mutatingMethod(result)
246+
return result
247+
}
248+
249+
// CHECK-LABEL: [AD] Activity info for $s17activity_analysis28activeInoutArgMutatingMethodyAA3MutVADF at (source=0 parameters=(0))
250+
// CHECK: [ACTIVE] %0 = argument of bb0 : $Mut
251+
// CHECK: [ACTIVE] %2 = alloc_stack $Mut, var, name "result"
252+
// CHECK: [ACTIVE] %4 = begin_access [read] [static] %2 : $*Mut
253+
// CHECK: [ACTIVE] %5 = load [trivial] %4 : $*Mut
254+
// CHECK: [ACTIVE] %7 = begin_access [modify] [static] %2 : $*Mut
255+
// CHECK: [NONE] // function_ref Mut.mutatingMethod(_:)
256+
// CHECK: [ACTIVE] %9 = apply %8(%5, %7) : $@convention(method) (Mut, @inout Mut) -> Mut
257+
// CHECK: [ACTIVE] %11 = begin_access [modify] [static] %2 : $*Mut
258+
// CHECK: [ACTIVE] %14 = begin_access [read] [static] %2 : $*Mut
259+
// CHECK: [ACTIVE] %15 = load [trivial] %14 : $*Mut
260+
261+
// expected-error @+1 {{function is not differentiable}}
262+
@differentiable(wrt: x)
263+
// expected-note @+1 {{when differentiating this function definition}}
264+
func activeInoutArgMutatingMethodVar(_ nonactive: inout Mut, _ x: Mut) -> Mut {
265+
var result = nonactive
266+
// expected-note @+1 {{cannot differentiate through 'inout' arguments}}
267+
result = result.mutatingMethod(x)
268+
return result
269+
}
270+
271+
// CHECK_LABEL: [AD] Activity info for $s17activity_analysis31activeInoutArgMutatingMethodVaryAA3MutVADz_ADtF at (source=0 parameters=(1))
272+
// CHECK: [USEFUL] %0 = argument of bb0 : $*Mut
273+
// CHECK: [ACTIVE] %1 = argument of bb0 : $Mut
274+
// CHECK: [ACTIVE] %4 = alloc_stack $Mut, var, name "result"
275+
// CHECK: [USEFUL] %5 = begin_access [read] [static] %0 : $*Mut
276+
// CHECK: [ACTIVE] %8 = begin_access [modify] [static] %4 : $*Mut
277+
// CHECK: [NONE] // function_ref Mut.mutatingMethod(_:)
278+
// CHECK: [ACTIVE] %10 = apply %9(%1, %8) : $@convention(method) (Mut, @inout Mut) -> Mut
279+
// CHECK: [ACTIVE] %12 = begin_access [modify] [static] %4 : $*Mut
280+
// CHECK: [ACTIVE] %15 = begin_access [read] [static] %4 : $*Mut
281+
// CHECK: [ACTIVE] %16 = load [trivial] %15 : $*Mut
282+
283+
// expected-error @+1 {{function is not differentiable}}
284+
@differentiable(wrt: x)
285+
// expected-note @+1 {{when differentiating this function definition}}
286+
func activeInoutArgMutatingMethodTuple(_ nonactive: inout Mut, _ x: Mut) -> Mut {
287+
var result = (nonactive, x)
288+
// expected-note @+1 {{cannot differentiate through 'inout' arguments}}
289+
let result2 = result.0.mutatingMethod(result.0)
290+
return result2
291+
}
292+
293+
// CHECK-LABEL: [AD] Activity info for $s17activity_analysis33activeInoutArgMutatingMethodTupleyAA3MutVADz_ADtF at (source=0 parameters=(1))
294+
// CHECK: [USEFUL] %0 = argument of bb0 : $*Mut
295+
// CHECK: [ACTIVE] %1 = argument of bb0 : $Mut
296+
// CHECK: [ACTIVE] %4 = alloc_stack $(Mut, Mut), var, name "result"
297+
// CHECK: [ACTIVE] %5 = tuple_element_addr %4 : $*(Mut, Mut), 0
298+
// CHECK: [ACTIVE] %6 = tuple_element_addr %4 : $*(Mut, Mut), 1
299+
// CHECK: [USEFUL] %7 = begin_access [read] [static] %0 : $*Mut
300+
// CHECK: [ACTIVE] %11 = begin_access [read] [static] %4 : $*(Mut, Mut)
301+
// CHECK: [ACTIVE] %12 = tuple_element_addr %11 : $*(Mut, Mut), 0
302+
// CHECK: [ACTIVE] %13 = load [trivial] %12 : $*Mut
303+
// CHECK: [ACTIVE] %15 = begin_access [modify] [static] %4 : $*(Mut, Mut)
304+
// CHECK: [ACTIVE] %16 = tuple_element_addr %15 : $*(Mut, Mut), 0
305+
// CHECK: [NONE] // function_ref Mut.mutatingMethod(_:)
306+
// CHECK: [ACTIVE] %18 = apply %17(%13, %16) : $@convention(method) (Mut, @inout Mut) -> Mut
307+
208308
//===----------------------------------------------------------------------===//
209309
// Non-differentiable functions
210310
//===----------------------------------------------------------------------===//

test/AutoDiff/differentiation_transform_diagnostics.swift

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,22 @@ let no_return: @differentiable (Float) -> Float = { x in
254254
// expected-note @+1 {{missing return for differentiation}}
255255
}
256256

257+
//===----------------------------------------------------------------------===//
258+
// Multiple results
259+
//===----------------------------------------------------------------------===//
260+
261+
func multipleResults(_ x: Float) -> (Float, Float) {
262+
return (x, x)
263+
}
264+
// expected-error @+1 {{function is not differentiable}}
265+
@differentiable
266+
// expected-note @+1 {{when differentiating this function definition}}
267+
func usesMultipleResults(_ x: Float) -> Float {
268+
// expected-note @+1 {{cannot differentiate through multiple results}}
269+
let tuple = multipleResults(x)
270+
return tuple.0 + tuple.1
271+
}
272+
257273
//===----------------------------------------------------------------------===//
258274
// Non-differentiable arguments and results
259275
//===----------------------------------------------------------------------===//
@@ -333,6 +349,54 @@ func activeInoutArgControlFlowComplex(_ array: [Float], _ bool: Bool) -> Float {
333349
return result
334350
}
335351

352+
struct Mut: Differentiable {}
353+
extension Mut {
354+
@differentiable(wrt: x)
355+
mutating func mutatingMethod(_ x: Mut) -> Mut {
356+
return x
357+
}
358+
}
359+
360+
// TODO(TF-985): Find workaround to avoid marking non-wrt `inout` argument as
361+
// active.
362+
// expected-error @+1 {{function is not differentiable}}
363+
@differentiable(wrt: x)
364+
// expected-note @+1 {{when differentiating this function definition}}
365+
func nonActiveInoutArg(_ nonactive: inout Mut, _ x: Mut) -> Mut {
366+
// expected-note @+1 {{cannot differentiate through 'inout' arguments}}
367+
return nonactive.mutatingMethod(x)
368+
}
369+
370+
// expected-error @+1 {{function is not differentiable}}
371+
@differentiable(wrt: x)
372+
// expected-note @+1 {{when differentiating this function definition}}
373+
func activeInoutArgMutatingMethod(_ x: Mut) -> Mut {
374+
var result = x
375+
// expected-note @+1 {{cannot differentiate through 'inout' arguments}}
376+
result = result.mutatingMethod(result)
377+
return result
378+
}
379+
380+
// expected-error @+1 {{function is not differentiable}}
381+
@differentiable(wrt: x)
382+
// expected-note @+1 {{when differentiating this function definition}}
383+
func activeInoutArgMutatingMethodVar(_ nonactive: inout Mut, _ x: Mut) -> Mut {
384+
var result = nonactive
385+
// expected-note @+1 {{cannot differentiate through 'inout' arguments}}
386+
result = result.mutatingMethod(x)
387+
return result
388+
}
389+
390+
// expected-error @+1 {{function is not differentiable}}
391+
@differentiable(wrt: x)
392+
// expected-note @+1 {{when differentiating this function definition}}
393+
func activeInoutArgMutatingMethodTuple(_ nonactive: inout Mut, _ x: Mut) -> Mut {
394+
var result = (nonactive, x)
395+
// expected-note @+1 {{cannot differentiate through 'inout' arguments}}
396+
let result2 = result.0.mutatingMethod(result.0)
397+
return result2
398+
}
399+
336400
//===----------------------------------------------------------------------===//
337401
// Non-varied results
338402
//===----------------------------------------------------------------------===//

test/AutoDiff/forward_mode_diagnostics.swift

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,22 @@ func calls_diff_of_nested(_ x: Float) -> Float {
4949
return derivative(at: x, in: func_to_diff)
5050
}
5151

52+
//===----------------------------------------------------------------------===//
53+
// Multiple results
54+
//===----------------------------------------------------------------------===//
55+
56+
func multipleResults(_ x: Float) -> (Float, Float) {
57+
return (x, x)
58+
}
59+
// expected-error @+1 {{function is not differentiable}}
60+
@differentiable
61+
// expected-note @+1 {{when differentiating this function definition}}
62+
func usesMultipleResults(_ x: Float) -> Float {
63+
// expected-note @+1 {{cannot differentiate through multiple results}}
64+
let tuple = multipleResults(x)
65+
return tuple.0 + tuple.1
66+
}
67+
5268
//===----------------------------------------------------------------------===//
5369
// Inout arguments
5470
//===----------------------------------------------------------------------===//
@@ -85,6 +101,52 @@ func activeInoutArgControlFlow(_ array: [Float]) -> Float {
85101
return result
86102
}
87103

104+
struct Mut: Differentiable {}
105+
extension Mut {
106+
@differentiable(wrt: x)
107+
mutating func mutatingMethod(_ x: Mut) -> Mut {
108+
return x
109+
}
110+
}
111+
112+
// FIXME(TF-985): Forward-mode crash due to unset tangent buffer.
113+
/*
114+
@differentiable(wrt: x)
115+
func nonActiveInoutArg(_ nonactive: inout Mut, _ x: Mut) -> Mut {
116+
return nonactive.mutatingMethod(x)
117+
}
118+
*/
119+
120+
// FIXME(TF-985): Forward-mode crash due to unset tangent buffer.
121+
/*
122+
@differentiable(wrt: x)
123+
func activeInoutArgMutatingMethod(_ x: Mut) -> Mut {
124+
var result = x
125+
result = result.mutatingMethod(result)
126+
return result
127+
}
128+
*/
129+
130+
// FIXME(TF-985): Forward-mode crash due to unset tangent buffer.
131+
/*
132+
@differentiable(wrt: x)
133+
func activeInoutArgMutatingMethodVar(_ nonactive: inout Mut, _ x: Mut) -> Mut {
134+
var result = nonactive
135+
result = result.mutatingMethod(x)
136+
return result
137+
}
138+
*/
139+
140+
// FIXME(TF-985): Forward-mode crash due to unset tangent buffer.
141+
/*
142+
@differentiable(wrt: x)
143+
func activeInoutArgMutatingMethodTuple(_ nonactive: inout Mut, _ x: Mut) -> Mut {
144+
var result = (nonactive, x)
145+
let result2 = result.0.mutatingMethod(result.0)
146+
return result2
147+
}
148+
*/
149+
88150
//===----------------------------------------------------------------------===//
89151
// Non-varied results
90152
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)