Skip to content

[AutoDiff] Improve diagnostic for multiple active results. #28429

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions include/swift/AST/DiagnosticsSIL.def
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,8 @@ NOTE(autodiff_when_differentiating_function_definition,none,
"when differentiating this function definition", ())
NOTE(autodiff_cannot_differentiate_through_inout_arguments,none,
"cannot differentiate through 'inout' arguments", ())
NOTE(autodiff_cannot_differentiate_through_multiple_results,none,
"cannot differentiate through multiple results", ())
NOTE(autodiff_class_member_not_supported,none,
"differentiating class members is not yet supported", ())
// TODO(TF-642): Remove when `partial_apply` works with `@differentiable`
Expand Down
36 changes: 12 additions & 24 deletions lib/SILOptimizer/Mandatory/Differentiation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1528,20 +1528,6 @@ class DifferentiableActivityInfo {
const SILAutoDiffIndices &indices) const;
};

/// Given a parameter argument (not indirect result) and some differentiation
/// indices, figure out whether the parent function is being differentiated with
/// respect to this parameter, according to the indices.
static bool isDifferentiationParameter(SILArgument *argument,
IndexSubset *indices) {
if (!argument) return false;
auto *function = argument->getFunction();
auto paramArgs = function->getArgumentsWithoutIndirectResults();
for (unsigned i : indices->getIndices())
if (paramArgs[i] == argument)
return true;
return false;
}

/// For an `apply` instruction with active results, compute:
/// - The results of the `apply` instruction, in type order.
/// - The set of minimal parameter and result indices for differentiating the
Expand All @@ -1558,9 +1544,7 @@ static void collectMinimalIndicesForFunctionCall(
// Record all parameter indices in type order.
unsigned currentParamIdx = 0;
for (auto applyArg : ai->getArgumentsWithoutIndirectResults()) {
if (activityInfo.isVaried(applyArg, parentIndices.parameters) ||
isDifferentiationParameter(dyn_cast<SILArgument>(applyArg),
parentIndices.parameters))
if (activityInfo.isActive(applyArg, parentIndices))
paramIndices.push_back(currentParamIdx);
++currentParamIdx;
}
Expand All @@ -1581,13 +1565,12 @@ static void collectMinimalIndicesForFunctionCall(
if (res.isFormalDirect()) {
results.push_back(directResults[dirResIdx]);
if (auto dirRes = directResults[dirResIdx])
if (dirRes && activityInfo.isUseful(dirRes, parentIndices.source))
if (dirRes && activityInfo.isActive(dirRes, parentIndices))
resultIndices.push_back(idx);
++dirResIdx;
} else {
results.push_back(indirectResults[indResIdx]);
if (activityInfo.isUseful(indirectResults[indResIdx],
parentIndices.source))
if (activityInfo.isActive(indirectResults[indResIdx], parentIndices))
resultIndices.push_back(idx);
++indResIdx;
}
Expand Down Expand Up @@ -3814,10 +3797,12 @@ class VJPEmitter final
activeResultIndices.begin(), activeResultIndices.end(),
[&s](unsigned i) { s << i; }, [&s] { s << ", "; });
s << "}\n";);
// FIXME: We don't support multiple active results yet.
// Diagnose multiple active results.
// TODO(TF-983): Support multiple active results.
if (activeResultIndices.size() > 1) {
context.emitNondifferentiabilityError(
ai, invoker, diag::autodiff_expression_not_differentiable_note);
ai, invoker,
diag::autodiff_cannot_differentiate_through_multiple_results);
errorOccurred = true;
return;
}
Expand Down Expand Up @@ -5493,13 +5478,16 @@ class JVPEmitter final
activeResultIndices.begin(), activeResultIndices.end(),
[&s](unsigned i) { s << i; }, [&s] { s << ", "; });
s << "}\n";);
// FIXME: We don't support multiple active results yet.
// Diagnose multiple active results.
// TODO(TF-983): Support multiple active results.
if (activeResultIndices.size() > 1) {
context.emitNondifferentiabilityError(
ai, invoker, diag::autodiff_expression_not_differentiable_note);
ai, invoker,
diag::autodiff_cannot_differentiate_through_multiple_results);
errorOccurred = true;
return;
}

// Form expected indices, assuming there's only one result.
SILAutoDiffIndices indices(
activeResultIndices.front(),
Expand Down
100 changes: 100 additions & 0 deletions test/AutoDiff/activity_analysis.swift
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,106 @@ func TF_954(_ x: Float) -> Float {
// CHECK: [ACTIVE] %40 = begin_access [read] [static] %2 : $*Float
// CHECK: [ACTIVE] %41 = load [trivial] %40 : $*Float

// Check `inout` argument activity.

struct Mut: Differentiable {}
extension Mut {
@differentiable(wrt: x)
mutating func mutatingMethod(_ x: Mut) -> Mut {
return x
}
}

// CHECK-LABEL: [AD] Activity info for $s17activity_analysis3MutV14mutatingMethodyA2CF at (source=0 parameters=(0))
// CHECK: [ACTIVE] %0 = argument of bb0 : $Mut
// CHECK: [NONE] %1 = argument of bb0 : $*Mut

// TODO(TF-985): Find workaround to avoid marking non-wrt `inout` argument as
// active.
// expected-error @+1 {{function is not differentiable}}
@differentiable(wrt: x)
// expected-note @+1 {{when differentiating this function definition}}
func nonActiveInoutArg(_ nonactive: inout Mut, _ x: Mut) -> Mut {
// expected-note @+1 {{cannot differentiate through 'inout' arguments}}
return nonactive.mutatingMethod(x)
}

// CHECK-LABEL: [AD] Activity info for $s17activity_analysis17nonActiveInoutArgyAA3MutVADz_ADtF at (source=0 parameters=(1))
// CHECK: [ACTIVE] %0 = argument of bb0 : $*Mut
// CHECK: [ACTIVE] %1 = argument of bb0 : $Mut
// CHECK: [ACTIVE] %4 = begin_access [modify] [static] %0 : $*Mut
// CHECK: [NONE] // function_ref Mut.mutatingMethod(_:)
// CHECK: [ACTIVE] %6 = apply %5(%1, %4) : $@convention(method) (Mut, @inout Mut) -> Mut

// expected-error @+1 {{function is not differentiable}}
@differentiable(wrt: x)
// expected-note @+1 {{when differentiating this function definition}}
func activeInoutArgMutatingMethod(_ x: Mut) -> Mut {
var result = x
// expected-note @+1 {{cannot differentiate through 'inout' arguments}}
result = result.mutatingMethod(result)
return result
}

// CHECK-LABEL: [AD] Activity info for $s17activity_analysis28activeInoutArgMutatingMethodyAA3MutVADF at (source=0 parameters=(0))
// CHECK: [ACTIVE] %0 = argument of bb0 : $Mut
// CHECK: [ACTIVE] %2 = alloc_stack $Mut, var, name "result"
// CHECK: [ACTIVE] %4 = begin_access [read] [static] %2 : $*Mut
// CHECK: [ACTIVE] %5 = load [trivial] %4 : $*Mut
// CHECK: [ACTIVE] %7 = begin_access [modify] [static] %2 : $*Mut
// CHECK: [NONE] // function_ref Mut.mutatingMethod(_:)
// CHECK: [ACTIVE] %9 = apply %8(%5, %7) : $@convention(method) (Mut, @inout Mut) -> Mut
// CHECK: [ACTIVE] %11 = begin_access [modify] [static] %2 : $*Mut
// CHECK: [ACTIVE] %14 = begin_access [read] [static] %2 : $*Mut
// CHECK: [ACTIVE] %15 = load [trivial] %14 : $*Mut

// expected-error @+1 {{function is not differentiable}}
@differentiable(wrt: x)
// expected-note @+1 {{when differentiating this function definition}}
func activeInoutArgMutatingMethodVar(_ nonactive: inout Mut, _ x: Mut) -> Mut {
var result = nonactive
// expected-note @+1 {{cannot differentiate through 'inout' arguments}}
result = result.mutatingMethod(x)
return result
}

// CHECK_LABEL: [AD] Activity info for $s17activity_analysis31activeInoutArgMutatingMethodVaryAA3MutVADz_ADtF at (source=0 parameters=(1))
// CHECK: [USEFUL] %0 = argument of bb0 : $*Mut
// CHECK: [ACTIVE] %1 = argument of bb0 : $Mut
// CHECK: [ACTIVE] %4 = alloc_stack $Mut, var, name "result"
// CHECK: [USEFUL] %5 = begin_access [read] [static] %0 : $*Mut
// CHECK: [ACTIVE] %8 = begin_access [modify] [static] %4 : $*Mut
// CHECK: [NONE] // function_ref Mut.mutatingMethod(_:)
// CHECK: [ACTIVE] %10 = apply %9(%1, %8) : $@convention(method) (Mut, @inout Mut) -> Mut
// CHECK: [ACTIVE] %12 = begin_access [modify] [static] %4 : $*Mut
// CHECK: [ACTIVE] %15 = begin_access [read] [static] %4 : $*Mut
// CHECK: [ACTIVE] %16 = load [trivial] %15 : $*Mut

// expected-error @+1 {{function is not differentiable}}
@differentiable(wrt: x)
// expected-note @+1 {{when differentiating this function definition}}
func activeInoutArgMutatingMethodTuple(_ nonactive: inout Mut, _ x: Mut) -> Mut {
var result = (nonactive, x)
// expected-note @+1 {{cannot differentiate through 'inout' arguments}}
let result2 = result.0.mutatingMethod(result.0)
return result2
}

// CHECK-LABEL: [AD] Activity info for $s17activity_analysis33activeInoutArgMutatingMethodTupleyAA3MutVADz_ADtF at (source=0 parameters=(1))
// CHECK: [USEFUL] %0 = argument of bb0 : $*Mut
// CHECK: [ACTIVE] %1 = argument of bb0 : $Mut
// CHECK: [ACTIVE] %4 = alloc_stack $(Mut, Mut), var, name "result"
// CHECK: [ACTIVE] %5 = tuple_element_addr %4 : $*(Mut, Mut), 0
// CHECK: [ACTIVE] %6 = tuple_element_addr %4 : $*(Mut, Mut), 1
// CHECK: [USEFUL] %7 = begin_access [read] [static] %0 : $*Mut
// CHECK: [ACTIVE] %11 = begin_access [read] [static] %4 : $*(Mut, Mut)
// CHECK: [ACTIVE] %12 = tuple_element_addr %11 : $*(Mut, Mut), 0
// CHECK: [ACTIVE] %13 = load [trivial] %12 : $*Mut
// CHECK: [ACTIVE] %15 = begin_access [modify] [static] %4 : $*(Mut, Mut)
// CHECK: [ACTIVE] %16 = tuple_element_addr %15 : $*(Mut, Mut), 0
// CHECK: [NONE] // function_ref Mut.mutatingMethod(_:)
// CHECK: [ACTIVE] %18 = apply %17(%13, %16) : $@convention(method) (Mut, @inout Mut) -> Mut

//===----------------------------------------------------------------------===//
// Non-differentiable functions
//===----------------------------------------------------------------------===//
Expand Down
64 changes: 64 additions & 0 deletions test/AutoDiff/differentiation_transform_diagnostics.swift
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,22 @@ let no_return: @differentiable (Float) -> Float = { x in
// expected-note @+1 {{missing return for differentiation}}
}

//===----------------------------------------------------------------------===//
// Multiple results
//===----------------------------------------------------------------------===//

func multipleResults(_ x: Float) -> (Float, Float) {
return (x, x)
}
// expected-error @+1 {{function is not differentiable}}
@differentiable
// expected-note @+1 {{when differentiating this function definition}}
func usesMultipleResults(_ x: Float) -> Float {
// expected-note @+1 {{cannot differentiate through multiple results}}
let tuple = multipleResults(x)
return tuple.0 + tuple.1
}

//===----------------------------------------------------------------------===//
// Non-differentiable arguments and results
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -333,6 +349,54 @@ func activeInoutArgControlFlowComplex(_ array: [Float], _ bool: Bool) -> Float {
return result
}

struct Mut: Differentiable {}
extension Mut {
@differentiable(wrt: x)
mutating func mutatingMethod(_ x: Mut) -> Mut {
return x
}
}

// TODO(TF-985): Find workaround to avoid marking non-wrt `inout` argument as
// active.
// expected-error @+1 {{function is not differentiable}}
@differentiable(wrt: x)
// expected-note @+1 {{when differentiating this function definition}}
func nonActiveInoutArg(_ nonactive: inout Mut, _ x: Mut) -> Mut {
// expected-note @+1 {{cannot differentiate through 'inout' arguments}}
return nonactive.mutatingMethod(x)
}

// expected-error @+1 {{function is not differentiable}}
@differentiable(wrt: x)
// expected-note @+1 {{when differentiating this function definition}}
func activeInoutArgMutatingMethod(_ x: Mut) -> Mut {
var result = x
// expected-note @+1 {{cannot differentiate through 'inout' arguments}}
result = result.mutatingMethod(result)
return result
}

// expected-error @+1 {{function is not differentiable}}
@differentiable(wrt: x)
// expected-note @+1 {{when differentiating this function definition}}
func activeInoutArgMutatingMethodVar(_ nonactive: inout Mut, _ x: Mut) -> Mut {
var result = nonactive
// expected-note @+1 {{cannot differentiate through 'inout' arguments}}
result = result.mutatingMethod(x)
return result
}

// expected-error @+1 {{function is not differentiable}}
@differentiable(wrt: x)
// expected-note @+1 {{when differentiating this function definition}}
func activeInoutArgMutatingMethodTuple(_ nonactive: inout Mut, _ x: Mut) -> Mut {
var result = (nonactive, x)
// expected-note @+1 {{cannot differentiate through 'inout' arguments}}
let result2 = result.0.mutatingMethod(result.0)
return result2
}

//===----------------------------------------------------------------------===//
// Non-varied results
//===----------------------------------------------------------------------===//
Expand Down
62 changes: 62 additions & 0 deletions test/AutoDiff/forward_mode_diagnostics.swift
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,22 @@ func calls_diff_of_nested(_ x: Float) -> Float {
return derivative(at: x, in: func_to_diff)
}

//===----------------------------------------------------------------------===//
// Multiple results
//===----------------------------------------------------------------------===//

func multipleResults(_ x: Float) -> (Float, Float) {
return (x, x)
}
// expected-error @+1 {{function is not differentiable}}
@differentiable
// expected-note @+1 {{when differentiating this function definition}}
func usesMultipleResults(_ x: Float) -> Float {
// expected-note @+1 {{cannot differentiate through multiple results}}
let tuple = multipleResults(x)
return tuple.0 + tuple.1
}

//===----------------------------------------------------------------------===//
// Inout arguments
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -85,6 +101,52 @@ func activeInoutArgControlFlow(_ array: [Float]) -> Float {
return result
}

struct Mut: Differentiable {}
extension Mut {
@differentiable(wrt: x)
mutating func mutatingMethod(_ x: Mut) -> Mut {
return x
}
}

// FIXME(TF-985): Forward-mode crash due to unset tangent buffer.
/*
@differentiable(wrt: x)
func nonActiveInoutArg(_ nonactive: inout Mut, _ x: Mut) -> Mut {
return nonactive.mutatingMethod(x)
}
*/

// FIXME(TF-985): Forward-mode crash due to unset tangent buffer.
/*
@differentiable(wrt: x)
func activeInoutArgMutatingMethod(_ x: Mut) -> Mut {
var result = x
result = result.mutatingMethod(result)
return result
}
*/

// FIXME(TF-985): Forward-mode crash due to unset tangent buffer.
/*
@differentiable(wrt: x)
func activeInoutArgMutatingMethodVar(_ nonactive: inout Mut, _ x: Mut) -> Mut {
var result = nonactive
result = result.mutatingMethod(x)
return result
}
*/

// FIXME(TF-985): Forward-mode crash due to unset tangent buffer.
/*
@differentiable(wrt: x)
func activeInoutArgMutatingMethodTuple(_ nonactive: inout Mut, _ x: Mut) -> Mut {
var result = (nonactive, x)
let result2 = result.0.mutatingMethod(result.0)
return result2
}
*/

//===----------------------------------------------------------------------===//
// Non-varied results
//===----------------------------------------------------------------------===//
Expand Down