Skip to content

Commit c11f5ce

Browse files
committed
Migrate changes from PR #41422
1 parent 17f43eb commit c11f5ce

10 files changed

+56
-42
lines changed

include/swift/AST/AutoDiff.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -364,7 +364,9 @@ class TangentSpace {
364364
static TangentSpace getTangentVector(Type tangentVectorType) {
365365
return {Kind::TangentVector, tangentVectorType};
366366
}
367-
static TangentSpace getTuple(TupleType *tupleTy);
367+
static TangentSpace getTuple(TupleType *tupleTy) {
368+
return {Kind::Tuple, tupleTy};
369+
}
368370

369371
bool isTangentVector() const { return kind == Kind::TangentVector; }
370372
bool isTuple() const { return kind == Kind::Tuple; }

include/swift/AST/DiagnosticsSema.def

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3463,7 +3463,7 @@ ERROR(autodiff_attr_original_void_result,none,
34633463
"cannot differentiate void function %0", (DeclName))
34643464
ERROR(autodiff_attr_original_multiple_semantic_results,none,
34653465
"cannot differentiate functions with both a differentiable 'inout' "
3466-
"parameter and a result", ())
3466+
"parameter and a differentiable result", ())
34673467
ERROR(autodiff_attr_result_not_differentiable,none,
34683468
"can only differentiate functions with results that conform to "
34693469
"'Differentiable', but %0 does not conform to 'Differentiable'", (Type))
@@ -5042,13 +5042,17 @@ ERROR(differentiable_function_type_invalid_result,none,
50425042
(StringRef, bool))
50435043
ERROR(differentiable_function_type_multiple_semantic_results,none,
50445044
"'@differentiable' function type cannot have both a differentiable "
5045-
"'inout' parameter and a result", ())
5046-
ERROR(differentiable_function_type_no_differentiability_parameters,
5047-
none,
5045+
"'inout' parameter and a differentiable result", ())
5046+
ERROR(differentiable_function_type_no_differentiability_parameters,none,
50485047
"'@differentiable' function type requires at least one differentiability "
50495048
"parameter, i.e. a non-'@noDerivative' parameter whose type conforms to "
50505049
"'Differentiable'%select{| with its 'TangentVector' equal to itself}0",
50515050
(/*isLinear*/ bool))
5051+
ERROR(differentiable_function_type_no_differentiable_result,none,
5052+
"'@differentiable' function type requires a differentiable result, i.e. "
5053+
"a non-'Void' type that conforms to 'Differentiable'%select{| with its "
5054+
"'TangentVector' equal to itself}0",
5055+
(/*isLinear*/ bool))
50525056

50535057
// SIL
50545058
ERROR(opened_non_protocol,none,

lib/AST/AutoDiff.cpp

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -366,13 +366,6 @@ GenericSignature autodiff::getDifferentiabilityWitnessGenericSignature(
366366
return derivativeGenSig;
367367
}
368368

369-
TangentSpace TangentSpace::getTuple(TupleType *tupleTy) {
370-
assert(!tupleTy->isVoid() &&
371-
"Attempted to get tangent space of 'Void', which cannot be "
372-
"differentiated.");
373-
return {Kind::Tuple, tupleTy};
374-
}
375-
376369
Type TangentSpace::getType() const {
377370
switch (kind) {
378371
case Kind::TangentVector:

lib/AST/Type.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6189,7 +6189,8 @@ TypeBase::getAutoDiffTangentSpace(LookupConformanceFn lookupConformance) {
61896189
newElts.push_back(elt.getWithType(eltSpace->getType()));
61906190
}
61916191
if (newElts.empty())
6192-
return cache(None);
6192+
return cache(
6193+
TangentSpace::getTuple(ctx.TheEmptyTupleType->castTo<TupleType>()));
61936194
if (newElts.size() == 1)
61946195
return cache(TangentSpace::getTangentVector(newElts.front().getType()));
61956196
auto *tupleType = TupleType::get(newElts, ctx)->castTo<TupleType>();

lib/Sema/TypeChecker.cpp

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -619,15 +619,15 @@ bool TypeChecker::diagnoseInvalidFunctionType(FunctionType *fnTy, SourceLoc loc,
619619
dc, stage);
620620
}) != params.end();
621621
bool alreadyDiagnosedOneParam = false;
622-
bool hasInoutDiffParameter = false;
622+
bool hasInoutDifferentiableParameter = false;
623623
for (unsigned i = 0, end = fnTy->getNumParams(); i != end; ++i) {
624624
auto param = params[i];
625625
if (param.isNoDerivative())
626626
continue;
627627
auto paramType = param.getPlainType();
628628
if (TypeChecker::isDifferentiable(paramType, isLinear, dc, stage)) {
629629
if (param.isInOut())
630-
hasInoutDiffParameter = true;
630+
hasInoutDifferentiableParameter = true;
631631
continue;
632632
}
633633
auto diagLoc =
@@ -657,12 +657,13 @@ bool TypeChecker::diagnoseInvalidFunctionType(FunctionType *fnTy, SourceLoc loc,
657657
}
658658

659659
// Check the result.
660-
bool resultIsDifferentiable =
661-
isDifferentiable(result, /*tangentVectorEqualsSelf*/ isLinear, dc,
662-
stage);
660+
bool resultExists = !(result->isVoid());
661+
bool resultIsDifferentiable = TypeChecker::isDifferentiable(
662+
result, /*tangentVectorEqualsSelf*/ isLinear, dc, stage);
663+
bool differentiableResultExists = resultExists && resultIsDifferentiable;
663664

664665
// Reject the case where there are multiple semantic results.
665-
if (resultIsDifferentiable && hasInoutDiffParameter) {
666+
if (differentiableResultExists && hasInoutDifferentiableParameter) {
666667
auto diagLoc = repr ? (*repr)->getArgsTypeRepr()->getLoc() : loc;
667668
auto diag = ctx.Diags.diagnose(
668669
diagLoc,
@@ -674,8 +675,8 @@ bool TypeChecker::diagnoseInvalidFunctionType(FunctionType *fnTy, SourceLoc loc,
674675
}
675676
}
676677

677-
// Reject the case where there are no semantic results.
678-
if (!resultIsDifferentiable && !hasInoutDiffParameter) {
678+
// Reject the case where the semantic result is not differentiable.
679+
if (!resultIsDifferentiable && !hasInoutDifferentiableParameter) {
679680
auto diagLoc = repr ? (*repr)->getResultTypeRepr()->getLoc() : loc;
680681
auto resultStr = fnTy->getResult()->getString();
681682
auto diag = ctx.Diags.diagnose(
@@ -687,6 +688,19 @@ bool TypeChecker::diagnoseInvalidFunctionType(FunctionType *fnTy, SourceLoc loc,
687688
diag.highlight((*repr)->getResultTypeRepr()->getSourceRange());
688689
}
689690
}
691+
692+
// Reject the case where there are no semantic results.
693+
if (!resultExists && !hasInoutDifferentiableParameter) {
694+
auto diagLoc = repr ? (*repr)->getResultTypeRepr()->getLoc() : loc;
695+
auto diag = ctx.Diags.diagnose(
696+
diagLoc, diag::differentiable_function_type_no_differentiable_result,
697+
isLinear);
698+
hadAnyError = true;
699+
700+
if (repr) {
701+
diag.highlight((*repr)->getResultTypeRepr()->getSourceRange());
702+
}
703+
}
690704
}
691705

692706
return hadAnyError;

test/AutoDiff/Sema/derivative_attr_type_checking.swift

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -746,7 +746,7 @@ extension ProtocolRequirementDerivative {
746746
func multipleSemanticResults(_ x: inout Float) -> Float {
747747
return x
748748
}
749-
// expected-error @+1 {{cannot differentiate functions with both a differentiable 'inout' parameter and a result}}
749+
// expected-error @+1 {{cannot differentiate functions with both a differentiable 'inout' parameter and a differentiable result}}
750750
@derivative(of: multipleSemanticResults)
751751
func vjpMultipleSemanticResults(x: inout Float) -> (
752752
value: Float, pullback: (Float) -> Float
@@ -885,14 +885,14 @@ func vjpNoSemanticResults(_ x: Float) -> (value: Void, pullback: Void) {}
885885

886886
extension InoutParameters {
887887
func multipleSemanticResults(_ x: inout Float) -> Float { x }
888-
// expected-error @+1 {{cannot differentiate functions with both a differentiable 'inout' parameter and a result}}
888+
// expected-error @+1 {{cannot differentiate functions with both a differentiable 'inout' parameter and a differentiable result}}
889889
@derivative(of: multipleSemanticResults)
890890
func vjpMultipleSemanticResults(_ x: inout Float) -> (
891891
value: Float, pullback: (inout Float) -> Void
892892
) { fatalError() }
893893

894894
func inoutVoid(_ x: Float, _ void: inout Void) -> Float {}
895-
// expected-error @+1 {{cannot differentiate functions with both a differentiable 'inout' parameter and a result}}
895+
// expected-error @+1 {{cannot differentiate functions with both a differentiable 'inout' parameter and a differentiable result}}
896896
@derivative(of: inoutVoid)
897897
func vjpInoutVoidParameter(_ x: Float, _ void: inout Void) -> (
898898
value: Float, pullback: (inout Float) -> Void

test/AutoDiff/Sema/differentiable_attr_type_checking.swift

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -528,7 +528,7 @@ func two9(x: Float, y: Float) -> Float {
528528
func inout1(x: Float, y: inout Float) -> Void {
529529
let _ = x + y
530530
}
531-
// expected-error @+1 {{cannot differentiate functions with both a differentiable 'inout' parameter and a result}}
531+
// expected-error @+1 {{cannot differentiate functions with both a differentiable 'inout' parameter and a differentiable result}}
532532
@differentiable(reverse, wrt: y)
533533
func inout2(x: Float, y: inout Float) -> Float {
534534
let _ = x + y
@@ -670,11 +670,11 @@ final class FinalClass: Differentiable {
670670
@differentiable(reverse, wrt: y)
671671
func inoutVoid(x: Float, y: inout Float) {}
672672

673-
// expected-error @+1 {{cannot differentiate functions with both a differentiable 'inout' parameter and a result}}
673+
// expected-error @+1 {{cannot differentiate functions with both a differentiable 'inout' parameter and a differentiable result}}
674674
@differentiable(reverse)
675675
func multipleSemanticResults(_ x: inout Float) -> Float { x }
676676

677-
// expected-error @+1 {{cannot differentiate functions with both a differentiable 'inout' parameter and a result}}
677+
// expected-error @+1 {{cannot differentiate functions with both a differentiable 'inout' parameter and a differentiable result}}
678678
@differentiable(reverse, wrt: y)
679679
func swap(x: inout Float, y: inout Float) {}
680680

@@ -687,7 +687,7 @@ extension InoutParameters {
687687
@differentiable(reverse)
688688
static func staticMethod(_ lhs: inout Self, rhs: Self) {}
689689

690-
// expected-error @+1 {{cannot differentiate functions with both a differentiable 'inout' parameter and a result}}
690+
// expected-error @+1 {{cannot differentiate functions with both a differentiable 'inout' parameter and a differentiable result}}
691691
@differentiable(reverse)
692692
static func multipleSemanticResults(_ lhs: inout Self, rhs: Self) -> Self {}
693693
}
@@ -696,7 +696,7 @@ extension InoutParameters {
696696
@differentiable(reverse)
697697
mutating func mutatingMethod(_ other: Self) {}
698698

699-
// expected-error @+1 {{cannot differentiate functions with both a differentiable 'inout' parameter and a result}}
699+
// expected-error @+1 {{cannot differentiate functions with both a differentiable 'inout' parameter and a differentiable result}}
700700
@differentiable(reverse)
701701
mutating func mutatingMethod(_ other: Self) -> Self {}
702702
}

test/AutoDiff/Sema/differentiable_func_type.swift

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,12 @@ let _: @differentiable(reverse) (Float, NonDiffType) -> Float
2828
// expected-error @+1 {{result type 'NonDiffType' does not conform to 'Differentiable' and satisfy 'NonDiffType == NonDiffType.TangentVector', but the enclosing function type is '@differentiable(_linear)'}}
2929
let _: @differentiable(_linear) (Float) -> NonDiffType
3030

31+
// expected-error @+1 {{'@differentiable' function type cannot have both a differentiable 'inout' parameter and a differentiable result}}
32+
let _: @differentiable(reverse) (inout Float) -> Float
33+
34+
// expected-error @+1 {{'@differentiable' function type cannot have both a differentiable 'inout' parameter and a differentiable result}}
35+
let _: @differentiable(_linear) (inout Float) -> Float
36+
3137
// Emit `@noDerivative` fixit iff there is at least one valid linearity parameter.
3238
// expected-error @+1 {{parameter type 'NonDiffType' does not conform to 'Differentiable' and satisfy 'NonDiffType == NonDiffType.TangentVector', but the enclosing function type is '@differentiable(_linear)'; did you want to add '@noDerivative' to this parameter?}} {{41-41=@noDerivative }}
3339
let _: @differentiable(_linear) (Float, NonDiffType) -> Float
@@ -40,10 +46,10 @@ let _: @differentiable(_linear) (Float) -> NonDiffType
4046

4147
let _: @differentiable(_linear) (Float) -> Float
4248

43-
// expected-error @+1 {{'@differentiable' function type cannot have both a differentiable 'inout' parameter and a result}}
49+
// expected-error @+1 {{'@differentiable' function type cannot have both a differentiable 'inout' parameter and a differentiable result}}
4450
let _: @differentiable(reverse) (inout Float) -> Float
4551

46-
// expected-error @+1 {{'@differentiable' function type cannot have both a differentiable 'inout' parameter and a result}}
52+
// expected-error @+1 {{'@differentiable' function type cannot have both a differentiable 'inout' parameter and a differentiable result}}
4753
let _: @differentiable(_linear) (inout Float) -> Float
4854

4955
// expected-error @+1 {{result type '@differentiable(reverse) (U) -> Float' does not conform to 'Differentiable', but the enclosing function type is '@differentiable'}}

test/AutoDiff/compiler_crashers_fixed/sr15808-non-differentiable-closure-parameters.swift

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,18 @@
11
// RUN: %target-swift-frontend -emit-sil -verify %s
2-
32
// SR-15808: In AST, type checking skips a closure with non-differentiable input
43
// where `Void` is included as a parameter without being marked `@noDerivative`.
54
// It also crashes when the output is `Void` and no input is `inout`. As a
65
// result, the compiler crashes during Sema.
7-
86
import _Differentiation
97

10-
// expected-error @+2 {{parameter type '()' does not conform to 'Differentiable', but the enclosing function type is '@differentiable'}}
11-
// expected-error @+1 {{result type 'Void' does not conform to 'Differentiable', but the enclosing function type is '@differentiable'}}
8+
// expected-error @+1 {{'@differentiable' function type requires a differentiable result, i.e. a non-'Void' type that conforms to 'Differentiable'}}
129
func helloWorld(_ x: @differentiable(reverse) (()) -> Void) {}
1310

14-
// expected-error @+1 {{parameter type '()' does not conform to 'Differentiable', but the enclosing function type is '@differentiable'}}
1511
func helloWorld(_ x: @differentiable(reverse) (()) -> Float) {}
1612

17-
// expected-error @+1 {{result type 'Void' does not conform to 'Differentiable', but the enclosing function type is '@differentiable'}}
13+
// expected-error @+1 {{'@differentiable' function type requires a differentiable result, i.e. a non-'Void' type that conforms to 'Differentiable'}}
1814
func helloWorld(_ x: @differentiable(reverse) (Float) -> Void) {}
1915

20-
// expected-error @+1 {{parameter type 'Void' does not conform to 'Differentiable', but the enclosing function type is '@differentiable'}}
2116
func helloWorld(_ x: @differentiable(reverse) (@noDerivative Float, Void) -> Float) {}
2217

2318
// Original crash:

test/AutoDiff/compiler_crashers_fixed/sr15818-inout-noderivative-closure-parameter.swift

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
11
// RUN: %target-swift-frontend -emit-sil -verify %s
2-
32
import _Differentiation
43

5-
// expected-error @+1 {{result type 'Void' does not conform to 'Differentiable', but the enclosing function type is '@differentiable'}}
4+
// expected-error @+1 {{@differentiable' function type requires a differentiable result, i.e. a non-'Void' type that conforms to 'Differentiable'}}
65
typealias MyType = @differentiable(reverse) (inout @noDerivative Float, Float) -> Void
76

87
@differentiable(reverse)
9-
func myFunc(_ x: inout @noDerivative Float, _ q: Float) -> Void {}
8+
func myFunc(_ x: inout @noDerivative Float, _ q: Float) -> Void {}
109

11-
print(myFunc as MyType)
10+
let castedFunc = myFunc as MyType
1211

1312
// Original crash:
1413
// Assertion failed: (Index < Length && "Invalid index!"), function operator[], file ArrayRef.h, line 257.

0 commit comments

Comments
 (0)