Skip to content

Commit d21ea44

Browse files
dan-zhengrxwei
authored andcommitted
[AutoDiff] Enable differentiation wrt "subset indices". (#23887)
Remove this error: `function is differentiable only with respect to a smaller subset of arguments`. The differentiation pass (and "minimal indices" logic) actually supports such differentiation wrt "subset indices". All that is needed is to remove the deprecated error.
1 parent da317a5 commit d21ea44

File tree

3 files changed

+8
-10
lines changed

3 files changed

+8
-10
lines changed

include/swift/AST/DiagnosticsSIL.def

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -396,9 +396,6 @@ NOTE(autodiff_protocol_member_not_differentiable,none,
396396
NOTE(autodiff_protocol_member_subset_indices_not_differentiable,none,
397397
"member is differentiable only with respect to a smaller subset of "
398398
"arguments", ())
399-
NOTE(autodiff_function_subset_indices_not_differentiable,none,
400-
"function is differentiable only with respect to a smaller subset of "
401-
"arguments", ())
402399
NOTE(autodiff_function_assoc_func_requirements_unmet,none,
403400
"function call is not differentiable because generic requirements are not "
404401
"met", ())

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1793,13 +1793,6 @@ emitAssociatedFunctionReference(ADContext &context, SILBuilder &builder,
17931793
if (autodiffFnType->isDifferentiable()) {
17941794
SILValue assocFn = builder.createAutoDiffFunctionExtract(
17951795
original.getLoc(), kind, /*differentiationOrder*/ 1, functionSource);
1796-
if (autodiffFnType->getDifferentiationParameterIndices().test(
1797-
desiredIndices.parameters)) {
1798-
context.emitNondifferentiabilityError(
1799-
original, parentTask,
1800-
diag::autodiff_function_subset_indices_not_differentiable);
1801-
return None;
1802-
}
18031796
SILAutoDiffIndices indices(0, desiredIndices.parameters);
18041797
return std::make_pair(assocFn, indices);
18051798
}

test/AutoDiff/simple_math.swift

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,4 +308,12 @@ SimpleMathTests.test("StructGeneric") {
308308
expectEqual(405, gradient(at: 3, in: fifthPower))
309309
}
310310

311+
SimpleMathTests.test("SubsetIndices") {
312+
func train(_ lossFunction: @differentiable (Float, Float) -> Float) {
313+
let y = Float(0)
314+
_ = gradient(at: 0) { x in lossFunction(x, y) }
315+
}
316+
train { x, y in x + y }
317+
}
318+
311319
runAllTests()

0 commit comments

Comments
 (0)