Skip to content

Commit c5efb80

Browse files
committed
Migrate changes from PR #41422
1 parent d2f5188 commit c5efb80

10 files changed

+56
-42
lines changed

include/swift/AST/AutoDiff.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -364,7 +364,9 @@ class TangentSpace {
364364
static TangentSpace getTangentVector(Type tangentVectorType) {
365365
return {Kind::TangentVector, tangentVectorType};
366366
}
367-
static TangentSpace getTuple(TupleType *tupleTy);
367+
static TangentSpace getTuple(TupleType *tupleTy) {
368+
return {Kind::Tuple, tupleTy};
369+
}
368370

369371
bool isTangentVector() const { return kind == Kind::TangentVector; }
370372
bool isTuple() const { return kind == Kind::Tuple; }

include/swift/AST/DiagnosticsSema.def

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3500,7 +3500,7 @@ ERROR(autodiff_attr_original_void_result,none,
35003500
"cannot differentiate void function %0", (DeclName))
35013501
ERROR(autodiff_attr_original_multiple_semantic_results,none,
35023502
"cannot differentiate functions with both a differentiable 'inout' "
3503-
"parameter and a result", ())
3503+
"parameter and a differentiable result", ())
35043504
ERROR(autodiff_attr_result_not_differentiable,none,
35053505
"can only differentiate functions with results that conform to "
35063506
"'Differentiable', but %0 does not conform to 'Differentiable'", (Type))
@@ -5059,13 +5059,17 @@ ERROR(differentiable_function_type_invalid_result,none,
50595059
(StringRef, bool))
50605060
ERROR(differentiable_function_type_multiple_semantic_results,none,
50615061
"'@differentiable' function type cannot have both a differentiable "
5062-
"'inout' parameter and a result", ())
5063-
ERROR(differentiable_function_type_no_differentiability_parameters,
5064-
none,
5062+
"'inout' parameter and a differentiable result", ())
5063+
ERROR(differentiable_function_type_no_differentiability_parameters,none,
50655064
"'@differentiable' function type requires at least one differentiability "
50665065
"parameter, i.e. a non-'@noDerivative' parameter whose type conforms to "
50675066
"'Differentiable'%select{| with its 'TangentVector' equal to itself}0",
50685067
(/*isLinear*/ bool))
5068+
ERROR(differentiable_function_type_no_differentiable_result,none,
5069+
"'@differentiable' function type requires a differentiable result, i.e. "
5070+
"a non-'Void' type that conforms to 'Differentiable'%select{| with its "
5071+
"'TangentVector' equal to itself}0",
5072+
(/*isLinear*/ bool))
50695073

50705074
// SIL
50715075
ERROR(opened_non_protocol,none,

lib/AST/AutoDiff.cpp

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -366,13 +366,6 @@ GenericSignature autodiff::getDifferentiabilityWitnessGenericSignature(
366366
return derivativeGenSig;
367367
}
368368

369-
TangentSpace TangentSpace::getTuple(TupleType *tupleTy) {
370-
assert(!tupleTy->isVoid() &&
371-
"Attempted to get tangent space of 'Void', which cannot be "
372-
"differentiated.");
373-
return {Kind::Tuple, tupleTy};
374-
}
375-
376369
Type TangentSpace::getType() const {
377370
switch (kind) {
378371
case Kind::TangentVector:

lib/AST/Type.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6177,7 +6177,8 @@ TypeBase::getAutoDiffTangentSpace(LookupConformanceFn lookupConformance) {
61776177
newElts.push_back(elt.getWithType(eltSpace->getType()));
61786178
}
61796179
if (newElts.empty())
6180-
return cache(None);
6180+
return cache(
6181+
TangentSpace::getTuple(ctx.TheEmptyTupleType->castTo<TupleType>()));
61816182
if (newElts.size() == 1)
61826183
return cache(TangentSpace::getTangentVector(newElts.front().getType()));
61836184
auto *tupleType = TupleType::get(newElts, ctx)->castTo<TupleType>();

lib/Sema/TypeChecker.cpp

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -618,15 +618,15 @@ bool TypeChecker::diagnoseInvalidFunctionType(FunctionType *fnTy, SourceLoc loc,
618618
dc, stage);
619619
}) != params.end();
620620
bool alreadyDiagnosedOneParam = false;
621-
bool hasInoutDiffParameter = false;
621+
bool hasInoutDifferentiableParameter = false;
622622
for (unsigned i = 0, end = fnTy->getNumParams(); i != end; ++i) {
623623
auto param = params[i];
624624
if (param.isNoDerivative())
625625
continue;
626626
auto paramType = param.getPlainType();
627627
if (TypeChecker::isDifferentiable(paramType, isLinear, dc, stage)) {
628628
if (param.isInOut())
629-
hasInoutDiffParameter = true;
629+
hasInoutDifferentiableParameter = true;
630630
continue;
631631
}
632632
auto diagLoc =
@@ -656,12 +656,13 @@ bool TypeChecker::diagnoseInvalidFunctionType(FunctionType *fnTy, SourceLoc loc,
656656
}
657657

658658
// Check the result.
659-
bool resultIsDifferentiable =
660-
isDifferentiable(result, /*tangentVectorEqualsSelf*/ isLinear, dc,
661-
stage);
659+
bool resultExists = !(result->isVoid());
660+
bool resultIsDifferentiable = TypeChecker::isDifferentiable(
661+
result, /*tangentVectorEqualsSelf*/ isLinear, dc, stage);
662+
bool differentiableResultExists = resultExists && resultIsDifferentiable;
662663

663664
// Reject the case where there are multiple semantic results.
664-
if (resultIsDifferentiable && hasInoutDiffParameter) {
665+
if (differentiableResultExists && hasInoutDifferentiableParameter) {
665666
auto diagLoc = repr ? (*repr)->getArgsTypeRepr()->getLoc() : loc;
666667
auto diag = ctx.Diags.diagnose(
667668
diagLoc,
@@ -673,8 +674,8 @@ bool TypeChecker::diagnoseInvalidFunctionType(FunctionType *fnTy, SourceLoc loc,
673674
}
674675
}
675676

676-
// Reject the case where there are no semantic results.
677-
if (!resultIsDifferentiable && !hasInoutDiffParameter) {
677+
// Reject the case where the semantic result is not differentiable.
678+
if (!resultIsDifferentiable && !hasInoutDifferentiableParameter) {
678679
auto diagLoc = repr ? (*repr)->getResultTypeRepr()->getLoc() : loc;
679680
auto resultStr = fnTy->getResult()->getString();
680681
auto diag = ctx.Diags.diagnose(
@@ -686,6 +687,19 @@ bool TypeChecker::diagnoseInvalidFunctionType(FunctionType *fnTy, SourceLoc loc,
686687
diag.highlight((*repr)->getResultTypeRepr()->getSourceRange());
687688
}
688689
}
690+
691+
// Reject the case where there are no semantic results.
692+
if (!resultExists && !hasInoutDifferentiableParameter) {
693+
auto diagLoc = repr ? (*repr)->getResultTypeRepr()->getLoc() : loc;
694+
auto diag = ctx.Diags.diagnose(
695+
diagLoc, diag::differentiable_function_type_no_differentiable_result,
696+
isLinear);
697+
hadAnyError = true;
698+
699+
if (repr) {
700+
diag.highlight((*repr)->getResultTypeRepr()->getSourceRange());
701+
}
702+
}
689703
}
690704

691705
return hadAnyError;

test/AutoDiff/Sema/derivative_attr_type_checking.swift

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -746,7 +746,7 @@ extension ProtocolRequirementDerivative {
746746
func multipleSemanticResults(_ x: inout Float) -> Float {
747747
return x
748748
}
749-
// expected-error @+1 {{cannot differentiate functions with both a differentiable 'inout' parameter and a result}}
749+
// expected-error @+1 {{cannot differentiate functions with both a differentiable 'inout' parameter and a differentiable result}}
750750
@derivative(of: multipleSemanticResults)
751751
func vjpMultipleSemanticResults(x: inout Float) -> (
752752
value: Float, pullback: (Float) -> Float
@@ -885,14 +885,14 @@ func vjpNoSemanticResults(_ x: Float) -> (value: Void, pullback: Void) {}
885885

886886
extension InoutParameters {
887887
func multipleSemanticResults(_ x: inout Float) -> Float { x }
888-
// expected-error @+1 {{cannot differentiate functions with both a differentiable 'inout' parameter and a result}}
888+
// expected-error @+1 {{cannot differentiate functions with both a differentiable 'inout' parameter and a differentiable result}}
889889
@derivative(of: multipleSemanticResults)
890890
func vjpMultipleSemanticResults(_ x: inout Float) -> (
891891
value: Float, pullback: (inout Float) -> Void
892892
) { fatalError() }
893893

894894
func inoutVoid(_ x: Float, _ void: inout Void) -> Float {}
895-
// expected-error @+1 {{cannot differentiate functions with both a differentiable 'inout' parameter and a result}}
895+
// expected-error @+1 {{cannot differentiate functions with both a differentiable 'inout' parameter and a differentiable result}}
896896
@derivative(of: inoutVoid)
897897
func vjpInoutVoidParameter(_ x: Float, _ void: inout Void) -> (
898898
value: Float, pullback: (inout Float) -> Void

test/AutoDiff/Sema/differentiable_attr_type_checking.swift

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -528,7 +528,7 @@ func two9(x: Float, y: Float) -> Float {
528528
func inout1(x: Float, y: inout Float) -> Void {
529529
let _ = x + y
530530
}
531-
// expected-error @+1 {{cannot differentiate functions with both a differentiable 'inout' parameter and a result}}
531+
// expected-error @+1 {{cannot differentiate functions with both a differentiable 'inout' parameter and a differentiable result}}
532532
@differentiable(reverse, wrt: y)
533533
func inout2(x: Float, y: inout Float) -> Float {
534534
let _ = x + y
@@ -670,11 +670,11 @@ final class FinalClass: Differentiable {
670670
@differentiable(reverse, wrt: y)
671671
func inoutVoid(x: Float, y: inout Float) {}
672672

673-
// expected-error @+1 {{cannot differentiate functions with both a differentiable 'inout' parameter and a result}}
673+
// expected-error @+1 {{cannot differentiate functions with both a differentiable 'inout' parameter and a differentiable result}}
674674
@differentiable(reverse)
675675
func multipleSemanticResults(_ x: inout Float) -> Float { x }
676676

677-
// expected-error @+1 {{cannot differentiate functions with both a differentiable 'inout' parameter and a result}}
677+
// expected-error @+1 {{cannot differentiate functions with both a differentiable 'inout' parameter and a differentiable result}}
678678
@differentiable(reverse, wrt: y)
679679
func swap(x: inout Float, y: inout Float) {}
680680

@@ -687,7 +687,7 @@ extension InoutParameters {
687687
@differentiable(reverse)
688688
static func staticMethod(_ lhs: inout Self, rhs: Self) {}
689689

690-
// expected-error @+1 {{cannot differentiate functions with both a differentiable 'inout' parameter and a result}}
690+
// expected-error @+1 {{cannot differentiate functions with both a differentiable 'inout' parameter and a differentiable result}}
691691
@differentiable(reverse)
692692
static func multipleSemanticResults(_ lhs: inout Self, rhs: Self) -> Self {}
693693
}
@@ -696,7 +696,7 @@ extension InoutParameters {
696696
@differentiable(reverse)
697697
mutating func mutatingMethod(_ other: Self) {}
698698

699-
// expected-error @+1 {{cannot differentiate functions with both a differentiable 'inout' parameter and a result}}
699+
// expected-error @+1 {{cannot differentiate functions with both a differentiable 'inout' parameter and a differentiable result}}
700700
@differentiable(reverse)
701701
mutating func mutatingMethod(_ other: Self) -> Self {}
702702
}

test/AutoDiff/Sema/differentiable_func_type.swift

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,12 @@ let _: @differentiable(reverse) (Float, NonDiffType) -> Float
2828
// expected-error @+1 {{result type 'NonDiffType' does not conform to 'Differentiable' and satisfy 'NonDiffType == NonDiffType.TangentVector', but the enclosing function type is '@differentiable(_linear)'}}
2929
let _: @differentiable(_linear) (Float) -> NonDiffType
3030

31+
// expected-error @+1 {{'@differentiable' function type cannot have both a differentiable 'inout' parameter and a differentiable result}}
32+
let _: @differentiable(reverse) (inout Float) -> Float
33+
34+
// expected-error @+1 {{'@differentiable' function type cannot have both a differentiable 'inout' parameter and a differentiable result}}
35+
let _: @differentiable(_linear) (inout Float) -> Float
36+
3137
// Emit `@noDerivative` fixit iff there is at least one valid linearity parameter.
3238
// expected-error @+1 {{parameter type 'NonDiffType' does not conform to 'Differentiable' and satisfy 'NonDiffType == NonDiffType.TangentVector', but the enclosing function type is '@differentiable(_linear)'; did you want to add '@noDerivative' to this parameter?}} {{41-41=@noDerivative }}
3339
let _: @differentiable(_linear) (Float, NonDiffType) -> Float
@@ -40,10 +46,10 @@ let _: @differentiable(_linear) (Float) -> NonDiffType
4046

4147
let _: @differentiable(_linear) (Float) -> Float
4248

43-
// expected-error @+1 {{'@differentiable' function type cannot have both a differentiable 'inout' parameter and a result}}
49+
// expected-error @+1 {{'@differentiable' function type cannot have both a differentiable 'inout' parameter and a differentiable result}}
4450
let _: @differentiable(reverse) (inout Float) -> Float
4551

46-
// expected-error @+1 {{'@differentiable' function type cannot have both a differentiable 'inout' parameter and a result}}
52+
// expected-error @+1 {{'@differentiable' function type cannot have both a differentiable 'inout' parameter and a differentiable result}}
4753
let _: @differentiable(_linear) (inout Float) -> Float
4854

4955
// expected-error @+1 {{result type '@differentiable(reverse) (U) -> Float' does not conform to 'Differentiable', but the enclosing function type is '@differentiable'}}

test/AutoDiff/compiler_crashers_fixed/sr15808-non-differentiable-closure-parameters.swift

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,18 @@
11
// RUN: %target-swift-frontend -emit-sil -verify %s
2-
32
// SR-15808: In AST, type checking skips a closure with non-differentiable input
43
// where `Void` is included as a parameter without being marked `@noDerivative`.
54
// It also crashes when the output is `Void` and no input is `inout`. As a
65
// result, the compiler crashes during Sema.
7-
86
import _Differentiation
97

10-
// expected-error @+2 {{parameter type '()' does not conform to 'Differentiable', but the enclosing function type is '@differentiable'}}
11-
// expected-error @+1 {{result type 'Void' does not conform to 'Differentiable', but the enclosing function type is '@differentiable'}}
8+
// expected-error @+1 {{'@differentiable' function type requires a differentiable result, i.e. a non-'Void' type that conforms to 'Differentiable'}}
129
func helloWorld(_ x: @differentiable(reverse) (()) -> Void) {}
1310

14-
// expected-error @+1 {{parameter type '()' does not conform to 'Differentiable', but the enclosing function type is '@differentiable'}}
1511
func helloWorld(_ x: @differentiable(reverse) (()) -> Float) {}
1612

17-
// expected-error @+1 {{result type 'Void' does not conform to 'Differentiable', but the enclosing function type is '@differentiable'}}
13+
// expected-error @+1 {{'@differentiable' function type requires a differentiable result, i.e. a non-'Void' type that conforms to 'Differentiable'}}
1814
func helloWorld(_ x: @differentiable(reverse) (Float) -> Void) {}
1915

20-
// expected-error @+1 {{parameter type 'Void' does not conform to 'Differentiable', but the enclosing function type is '@differentiable'}}
2116
func helloWorld(_ x: @differentiable(reverse) (@noDerivative Float, Void) -> Float) {}
2217

2318
// Original crash:

test/AutoDiff/compiler_crashers_fixed/sr15818-inout-noderivative-closure-parameter.swift

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
11
// RUN: %target-swift-frontend -emit-sil -verify %s
2-
32
import _Differentiation
43

5-
// expected-error @+1 {{result type 'Void' does not conform to 'Differentiable', but the enclosing function type is '@differentiable'}}
4+
// expected-error @+1 {{@differentiable' function type requires a differentiable result, i.e. a non-'Void' type that conforms to 'Differentiable'}}
65
typealias MyType = @differentiable(reverse) (inout @noDerivative Float, Float) -> Void
76

87
@differentiable(reverse)
9-
func myFunc(_ x: inout @noDerivative Float, _ q: Float) -> Void {}
8+
func myFunc(_ x: inout @noDerivative Float, _ q: Float) -> Void {}
109

11-
print(myFunc as MyType)
10+
let castedFunc = myFunc as MyType
1211

1312
// Original crash:
1413
// Assertion failed: (Index < Length && "Invalid index!"), function operator[], file ArrayRef.h, line 257.

0 commit comments

Comments
 (0)