Skip to content

Commit da1e398

Browse files
committed
[AutoDiff] Require same access level for original/derivative functions.
Require `@derivative` functions and their original functions to have the same access level. Public original functions may also have internal `@usableFromInline` derivatives, as a special case. Diagnose access level mismatches. This simplifies derivative registration rules, and may enable simplification of AutoDiff symbol linkage rules. Resolves TF-1099 and TF-1160.
1 parent fae995a commit da1e398

File tree

4 files changed

+158
-1
lines changed

4 files changed

+158
-1
lines changed

include/swift/AST/DiagnosticsSema.def

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3054,6 +3054,22 @@ ERROR(derivative_attr_original_already_has_derivative,none,
30543054
"a derivative already exists for %0", (DeclName))
30553055
NOTE(derivative_attr_duplicate_note,none,
30563056
"other attribute declared here", ())
3057+
ERROR(derivative_attr_access_level_lower_than_original,none,
3058+
"derivative functions for "
3059+
"%select{private|fileprivate|internal|public or '@usableFromInline'|open}1 "
3060+
"original function %0 must also be "
3061+
"%select{private|fileprivate|internal|public or '@usableFromInline'|open}1, "
3062+
"but %2 is %select{private|fileprivate|internal|public|open}3",
3063+
(/*original*/ DeclName, /*original*/ AccessLevel,
3064+
/*derivative*/ DeclName, /*derivative*/ AccessLevel))
3065+
ERROR(derivative_attr_access_level_higher_than_original,none,
3066+
"the original function of "
3067+
"%select{a private|a fileprivate|an internal|a public or '@usableFromInline'|an open}3 "
3068+
"derivative function must also be "
3069+
"%select{private|fileprivate|internal|public or '@usableFromInline'|open}3, "
3070+
"but %0 is %select{private|fileprivate|internal|public|open}1",
3071+
(/*original*/ DeclName, /*original*/ AccessLevel,
3072+
/*derivative*/ DeclName, /*derivative*/ AccessLevel))
30573073

30583074
// @transpose
30593075
ERROR(transpose_attr_invalid_linearity_parameter_or_result,none,

lib/Sema/TypeCheckAttr.cpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4540,6 +4540,37 @@ static bool typeCheckDerivativeAttr(ASTContext &Ctx, Decl *D,
45404540
}
45414541
attr->setOriginalFunction(originalAFD);
45424542

4543+
// Returns true if:
4544+
// - Original function and derivative function has the same access level.
4545+
// - Original function is public and derivative function is internal
4546+
// `@usableFromInline`. This is the only special case.
4547+
auto compatibleAccessLevels = [&]() {
4548+
if (originalAFD->getFormalAccess() == derivative->getFormalAccess())
4549+
return true;
4550+
return originalAFD->getFormalAccess() == AccessLevel::Public &&
4551+
derivative->getEffectiveAccess() == AccessLevel::Public;
4552+
};
4553+
4554+
// Check access level compatibility for original and derivative functions.
4555+
if (!compatibleAccessLevels()) {
4556+
auto diagID = diag::derivative_attr_access_level_higher_than_original;
4557+
AccessLevel originalAccess;
4558+
AccessLevel derivativeAccess;
4559+
if (originalAFD->getEffectiveAccess() < derivative->getEffectiveAccess()) {
4560+
originalAccess =
4561+
originalAFD->getFormalAccessScope().accessLevelForDiagnostics();
4562+
derivativeAccess = derivative->getEffectiveAccess();
4563+
} else {
4564+
diagID = diag::derivative_attr_access_level_lower_than_original;
4565+
originalAccess = originalAFD->getEffectiveAccess();
4566+
derivativeAccess =
4567+
derivative->getFormalAccessScope().accessLevelForDiagnostics();
4568+
}
4569+
diags.diagnose(originalName.Loc, diagID, originalAFD->getName(),
4570+
originalAccess, derivative->getName(), derivativeAccess);
4571+
return true;
4572+
}
4573+
45434574
// Get the resolved differentiability parameter indices.
45444575
auto *resolvedDiffParamIndices = attr->getParameterIndices();
45454576

test/AutoDiff/SILGen/sil_differentiability_witness.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ public func foo_vjp(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
4444
func bar<T>(_ x: Float, _ y: T) -> Float { x }
4545

4646
@derivative(of: bar)
47-
public func bar_jvp<T>(_ x: Float, _ y: T) -> (value: Float, differential: (Float) -> Float) {
47+
func bar_jvp<T>(_ x: Float, _ y: T) -> (value: Float, differential: (Float) -> Float) {
4848
(x, { $0 })
4949
}
5050

test/AutoDiff/Sema/derivative_attr_type_checking.swift

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -746,6 +746,7 @@ extension InoutParameters {
746746
// Test cross-file derivative registration.
747747

748748
extension FloatingPoint where Self: Differentiable {
749+
@usableFromInline
749750
@derivative(of: rounded)
750751
func vjpRounded() -> (
751752
value: Self,
@@ -802,3 +803,112 @@ extension HasADefaultDerivative {
802803
(x, { 10 * $0 })
803804
}
804805
}
806+
807+
// MARK: - Original function visibility = derivative function visibility
808+
809+
public func public_original_public_derivative(_ x: Float) -> Float { x }
810+
@derivative(of: public_original_public_derivative)
811+
public func _public_original_public_derivative(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
812+
fatalError()
813+
}
814+
815+
public func public_original_usablefrominline_derivative(_ x: Float) -> Float { x }
816+
@usableFromInline
817+
@derivative(of: public_original_usablefrominline_derivative)
818+
func _public_original_usablefrominline_derivative(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
819+
fatalError()
820+
}
821+
822+
@usableFromInline
823+
func usablefrominline_original_public_derivative(_ x: Float) -> Float { x }
824+
// expected-error @+1 {{derivative functions for public or '@usableFromInline' original function 'usablefrominline_original_public_derivative' must also be public or '@usableFromInline', but '_usablefrominline_original_public_derivative' is public}}
825+
@derivative(of: usablefrominline_original_public_derivative)
826+
public func _usablefrominline_original_public_derivative(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
827+
fatalError()
828+
}
829+
830+
func internal_original_internal_derivative(_ x: Float) -> Float { x }
831+
@derivative(of: internal_original_internal_derivative)
832+
func _internal_original_internal_derivative(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
833+
fatalError()
834+
}
835+
836+
private func private_original_private_derivative(_ x: Float) -> Float { x }
837+
@derivative(of: private_original_private_derivative)
838+
private func _private_original_private_derivative(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
839+
fatalError()
840+
}
841+
842+
fileprivate func fileprivate_original_fileprivate_derivative(_ x: Float) -> Float { x }
843+
@derivative(of: fileprivate_original_fileprivate_derivative)
844+
fileprivate func _fileprivate_original_fileprivate_derivative(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
845+
fatalError()
846+
}
847+
848+
fileprivate func fileprivate_original_private_derivative(_ x: Float) -> Float { x }
849+
// expected-error @+1 {{derivative functions for fileprivate original function 'fileprivate_original_private_derivative' must also be fileprivate, but '_fileprivate_original_private_derivative' is private}}
850+
@derivative(of: fileprivate_original_private_derivative)
851+
private func _fileprivate_original_private_derivative(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
852+
fatalError()
853+
}
854+
855+
private func private_original_fileprivate_derivative(_ x: Float) -> Float { x }
856+
// expected-error @+1 {{derivative functions for fileprivate original function 'private_original_fileprivate_derivative' must also be fileprivate, but '_private_original_fileprivate_derivative' is fileprivate}}
857+
@derivative(of: private_original_fileprivate_derivative)
858+
fileprivate func _private_original_fileprivate_derivative(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
859+
fatalError()
860+
}
861+
862+
// MARK: - Original function visibility < derivative function visibility
863+
864+
func internal_original_public_derivative(_ x: Float) -> Float { x }
865+
// expected-error @+1 {{the original function of a public or '@usableFromInline' derivative function must also be public or '@usableFromInline', but 'internal_original_public_derivative' is internal}}
866+
@derivative(of: internal_original_public_derivative)
867+
public func _internal_original_public_derivative(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
868+
fatalError()
869+
}
870+
871+
private func private_original_usablefrominline_derivative(_ x: Float) -> Float { x }
872+
// expected-error @+1 {{the original function of a public or '@usableFromInline' derivative function must also be public or '@usableFromInline', but 'private_original_usablefrominline_derivative' is private}}
873+
@derivative(of: private_original_usablefrominline_derivative)
874+
@usableFromInline
875+
func _private_original_usablefrominline_derivative(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
876+
fatalError()
877+
}
878+
879+
private func private_original_public_derivative(_ x: Float) -> Float { x }
880+
// expected-error @+1 {{the original function of a public or '@usableFromInline' derivative function must also be public or '@usableFromInline', but 'private_original_public_derivative' is private}}
881+
@derivative(of: private_original_public_derivative)
882+
public func _private_original_public_derivative(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
883+
fatalError()
884+
}
885+
886+
private func private_original_internal_derivative(_ x: Float) -> Float { x }
887+
// expected-error @+1 {{the original function of an internal derivative function must also be internal, but 'private_original_internal_derivative' is private}}
888+
@derivative(of: private_original_internal_derivative)
889+
func _private_original_internal_derivative(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
890+
fatalError()
891+
}
892+
893+
// MARK: - Original function visibility > derivative function visibility
894+
895+
public func public_original_private_derivative(_ x: Float) -> Float { x }
896+
// expected-error @+1 {{derivative functions for public or '@usableFromInline' original function 'public_original_private_derivative' must also be public or '@usableFromInline', but '_public_original_private_derivative' is fileprivate}}
897+
@derivative(of: public_original_private_derivative)
898+
fileprivate func _public_original_private_derivative(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
899+
fatalError()
900+
}
901+
902+
public func public_original_internal_derivative(_ x: Float) -> Float { x }
903+
// expected-error @+1 {{derivative functions for public or '@usableFromInline' original function 'public_original_internal_derivative' must also be public or '@usableFromInline', but '_public_original_internal_derivative' is internal}}
904+
@derivative(of: public_original_internal_derivative)
905+
func _public_original_internal_derivative(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
906+
fatalError()
907+
}
908+
909+
func internal_original_fileprivate_derivative(_ x: Float) -> Float { x }
910+
// expected-error @+1 {{derivative functions for internal original function 'internal_original_fileprivate_derivative' must also be internal, but '_internal_original_fileprivate_derivative' is fileprivate}}
911+
@derivative(of: internal_original_fileprivate_derivative)
912+
fileprivate func _internal_original_fileprivate_derivative(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
913+
fatalError()
914+
}

0 commit comments

Comments
 (0)