Skip to content

Commit fae2397

Browse files
authored
Merge pull request #2266 from swiftwasm/main
[pull] swiftwasm from main
2 parents 8969481 + c24e529 commit fae2397

File tree

4 files changed

+34
-3
lines changed

4 files changed

+34
-3
lines changed

docs/DifferentiableProgramming.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
# Differentiable Programming Manifesto
22

33
* Authors: [Richard Wei], [Dan Zheng], [Marc Rasi], [Bart Chrzaszcz]
4-
* Status: Partially implemented on master, feature gated under `import _Differentiation`
4+
* Status:
5+
* Partially implemented on main, feature gated under `import _Differentiation`
6+
* Initial proposal [pitched](https://forums.swift.org/t/differentiable-programming-for-gradient-based-machine-learning/42147) with a significantly scoped-down subset of features. Please refer to the linked pitch thread for the latest design discussions and changes.
57

68
## Table of contents
79

include/swift/AST/DiagnosticsSema.def

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4473,13 +4473,19 @@ ERROR(differentiable_function_type_invalid_parameter,none,
44734473
"%select{| and satisfy '%0 == %0.TangentVector'}1, but the enclosing "
44744474
"function type is '@differentiable%select{|(linear)}1'"
44754475
"%select{|; did you want to add '@noDerivative' to this parameter?}2",
4476-
(StringRef, /*tangentVectorEqualsSelf*/ bool,
4476+
(StringRef, /*isLinear*/ bool,
44774477
/*hasValidDifferentiabilityParameter*/ bool))
44784478
ERROR(differentiable_function_type_invalid_result,none,
44794479
"result type '%0' does not conform to 'Differentiable'"
44804480
"%select{| and satisfy '%0 == %0.TangentVector'}1, but the enclosing "
44814481
"function type is '@differentiable%select{|(linear)}1'",
44824482
(StringRef, bool))
4483+
ERROR(differentiable_function_type_no_differentiability_parameters,
4484+
none,
4485+
"'@differentiable' function type requires at least one differentiability "
4486+
"parameter, i.e. a non-'@noDerivative' parameter whose type conforms to "
4487+
"'Differentiable'%select{| with its 'TangentVector' equal to itself}0",
4488+
(/*isLinear*/ bool))
44834489

44844490
// SIL
44854491
ERROR(opened_non_protocol,none,

lib/Sema/TypeCheckType.cpp

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2610,22 +2610,33 @@ TypeResolver::resolveASTFunctionTypeParams(TupleTypeRepr *inputRepr,
26102610
return isDifferentiable(param.getPlainType(),
26112611
/*tangentVectorEqualsSelf*/ isLinear);
26122612
}) != elements.end();
2613+
bool alreadyDiagnosedOneParam = false;
26132614
for (unsigned i = 0, end = inputRepr->getNumElements(); i != end; ++i) {
26142615
auto *eltTypeRepr = inputRepr->getElementType(i);
26152616
auto param = elements[i];
26162617
if (param.isNoDerivative())
26172618
continue;
26182619
auto paramType = param.getPlainType();
2619-
if (isDifferentiable(paramType, /*tangentVectorEqualsSelf*/ isLinear))
2620+
if (isDifferentiable(paramType, isLinear))
26202621
continue;
26212622
auto paramTypeString = paramType->getString();
26222623
auto diagnostic =
26232624
diagnose(eltTypeRepr->getLoc(),
26242625
diag::differentiable_function_type_invalid_parameter,
26252626
paramTypeString, isLinear, hasValidDifferentiabilityParam);
2627+
alreadyDiagnosedOneParam = true;
26262628
if (hasValidDifferentiabilityParam)
26272629
diagnostic.fixItInsert(eltTypeRepr->getLoc(), "@noDerivative ");
26282630
}
2631+
// Reject the case where all parameters have '@noDerivative'.
2632+
if (!alreadyDiagnosedOneParam && !hasValidDifferentiabilityParam) {
2633+
diagnose(
2634+
inputRepr->getLoc(),
2635+
diag::
2636+
differentiable_function_type_no_differentiability_parameters,
2637+
isLinear)
2638+
.highlight(inputRepr->getSourceRange());
2639+
}
26292640
}
26302641

26312642
return elements;

test/AutoDiff/Sema/differentiable_func_type.swift

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,18 @@ let _: (@noDerivative Float, Float) -> Float
119119

120120
let _: @differentiable (Float, @noDerivative Float) -> Float // okay
121121

122+
// expected-error @+1 {{'@differentiable' function type requires at least one differentiability parameter, i.e. a non-'@noDerivative' parameter whose type conforms to 'Differentiable'}}
123+
let _: @differentiable (@noDerivative Float) -> Float
124+
125+
// expected-error @+1 {{'@differentiable' function type requires at least one differentiability parameter, i.e. a non-'@noDerivative' parameter whose type conforms to 'Differentiable'}}
126+
let _: @differentiable (@noDerivative Float, @noDerivative Int) -> Float
127+
128+
// expected-error @+1 {{'@differentiable' function type requires at least one differentiability parameter, i.e. a non-'@noDerivative' parameter whose type conforms to 'Differentiable'}}
129+
let _: @differentiable (@noDerivative Float, @noDerivative Float) -> Float
130+
131+
// expected-error @+1 {{parameter type 'Int' does not conform to 'Differentiable' and satisfy 'Int == Int.TangentVector', but the enclosing function type is '@differentiable(linear)'}}
132+
let _: @differentiable(linear) (@noDerivative Float, Int) -> Float
133+
122134
// expected-error @+1 {{'@noDerivative' may only be used on parameters of '@differentiable' function types}}
123135
let _: (Float) -> @noDerivative Float
124136

0 commit comments

Comments
 (0)