Skip to content

Commit 5036c95

Browse files
authored
[AutoDiff] Require same access level for original/derivative functions. (#31527)
Require `@derivative` functions and their original functions to have the same access level. Public original functions may also have internal `@usableFromInline` derivatives, as a special case. Diagnose access level mismatches. Produce a fix-it to change the derivative function's access level. This simplifies derivative registration rules, and may enable simplification of AutoDiff symbol linkage rules. Resolves TF-1099 and TF-1160.
1 parent 92493dd commit 5036c95

File tree

4 files changed

+169
-1
lines changed

4 files changed

+169
-1
lines changed

include/swift/AST/DiagnosticsSema.def

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3054,6 +3054,18 @@ ERROR(derivative_attr_original_already_has_derivative,none,
30543054
"a derivative already exists for %0", (DeclName))
30553055
NOTE(derivative_attr_duplicate_note,none,
30563056
"other attribute declared here", ())
3057+
ERROR(derivative_attr_access_level_mismatch,none,
3058+
"derivative function must have same access level as original function; "
3059+
"derivative function %2 is "
3060+
"%select{private|fileprivate|internal|public|open}3, "
3061+
"but original function %0 is "
3062+
"%select{private|fileprivate|internal|public|open}1",
3063+
(/*original*/ DeclName, /*original*/ AccessLevel,
3064+
/*derivative*/ DeclName, /*derivative*/ AccessLevel))
3065+
NOTE(derivative_attr_fix_access,none,
3066+
"mark the derivative function as "
3067+
"'%select{private|fileprivate|internal|@usableFromInline|@usableFromInline}0' "
3068+
"to match the original function", (AccessLevel))
30573069

30583070
// @transpose
30593071
ERROR(transpose_attr_invalid_linearity_parameter_or_result,none,

lib/Sema/TypeCheckAttr.cpp

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4540,6 +4540,42 @@ static bool typeCheckDerivativeAttr(ASTContext &Ctx, Decl *D,
45404540
}
45414541
attr->setOriginalFunction(originalAFD);
45424542

4543+
// Returns true if:
4544+
// - Original function and derivative function have the same access level.
4545+
// - Original function is public and derivative function is internal
4546+
// `@usableFromInline`. This is the only special case.
4547+
auto compatibleAccessLevels = [&]() {
4548+
if (originalAFD->getFormalAccess() == derivative->getFormalAccess())
4549+
return true;
4550+
return originalAFD->getFormalAccess() == AccessLevel::Public &&
4551+
derivative->getEffectiveAccess() == AccessLevel::Public;
4552+
};
4553+
4554+
// Check access level compatibility for original and derivative functions.
4555+
if (!compatibleAccessLevels()) {
4556+
auto originalAccess = originalAFD->getFormalAccess();
4557+
auto derivativeAccess =
4558+
derivative->getFormalAccessScope().accessLevelForDiagnostics();
4559+
diags.diagnose(originalName.Loc,
4560+
diag::derivative_attr_access_level_mismatch,
4561+
originalAFD->getName(), originalAccess,
4562+
derivative->getName(), derivativeAccess);
4563+
auto fixItDiag =
4564+
derivative->diagnose(diag::derivative_attr_fix_access, originalAccess);
4565+
// If original access is public, suggest adding `@usableFromInline` to
4566+
// derivative.
4567+
if (originalAccess == AccessLevel::Public) {
4568+
fixItDiag.fixItInsert(
4569+
derivative->getAttributeInsertionLoc(/*forModifier*/ false),
4570+
"@usableFromInline ");
4571+
}
4572+
// Otherwise, suggest changing derivative access level.
4573+
else {
4574+
fixItAccess(fixItDiag, derivative, originalAccess);
4575+
}
4576+
return true;
4577+
}
4578+
45434579
// Get the resolved differentiability parameter indices.
45444580
auto *resolvedDiffParamIndices = attr->getParameterIndices();
45454581

test/AutoDiff/SILGen/sil_differentiability_witness.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ public func foo_vjp(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
4444
func bar<T>(_ x: Float, _ y: T) -> Float { x }
4545

4646
@derivative(of: bar)
47-
public func bar_jvp<T>(_ x: Float, _ y: T) -> (value: Float, differential: (Float) -> Float) {
47+
func bar_jvp<T>(_ x: Float, _ y: T) -> (value: Float, differential: (Float) -> Float) {
4848
(x, { $0 })
4949
}
5050

test/AutoDiff/Sema/derivative_attr_type_checking.swift

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -746,6 +746,7 @@ extension InoutParameters {
746746
// Test cross-file derivative registration.
747747

748748
extension FloatingPoint where Self: Differentiable {
749+
@usableFromInline
749750
@derivative(of: rounded)
750751
func vjpRounded() -> (
751752
value: Self,
@@ -802,3 +803,122 @@ extension HasADefaultDerivative {
802803
(x, { 10 * $0 })
803804
}
804805
}
806+
807+
// MARK: - Original function visibility = derivative function visibility
808+
809+
public func public_original_public_derivative(_ x: Float) -> Float { x }
810+
@derivative(of: public_original_public_derivative)
811+
public func _public_original_public_derivative(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
812+
fatalError()
813+
}
814+
815+
public func public_original_usablefrominline_derivative(_ x: Float) -> Float { x }
816+
@usableFromInline
817+
@derivative(of: public_original_usablefrominline_derivative)
818+
func _public_original_usablefrominline_derivative(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
819+
fatalError()
820+
}
821+
822+
func internal_original_internal_derivative(_ x: Float) -> Float { x }
823+
@derivative(of: internal_original_internal_derivative)
824+
func _internal_original_internal_derivative(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
825+
fatalError()
826+
}
827+
828+
private func private_original_private_derivative(_ x: Float) -> Float { x }
829+
@derivative(of: private_original_private_derivative)
830+
private func _private_original_private_derivative(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
831+
fatalError()
832+
}
833+
834+
fileprivate func fileprivate_original_fileprivate_derivative(_ x: Float) -> Float { x }
835+
@derivative(of: fileprivate_original_fileprivate_derivative)
836+
fileprivate func _fileprivate_original_fileprivate_derivative(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
837+
fatalError()
838+
}
839+
840+
// MARK: - Original function visibility < derivative function visibility
841+
842+
@usableFromInline
843+
func usablefrominline_original_public_derivative(_ x: Float) -> Float { x }
844+
// expected-error @+1 {{derivative function must have same access level as original function; derivative function '_usablefrominline_original_public_derivative' is public, but original function 'usablefrominline_original_public_derivative' is internal}}
845+
@derivative(of: usablefrominline_original_public_derivative)
846+
// expected-note @+1 {{mark the derivative function as 'internal' to match the original function}} {{1-7=internal}}
847+
public func _usablefrominline_original_public_derivative(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
848+
fatalError()
849+
}
850+
851+
func internal_original_public_derivative(_ x: Float) -> Float { x }
852+
// expected-error @+1 {{derivative function must have same access level as original function; derivative function '_internal_original_public_derivative' is public, but original function 'internal_original_public_derivative' is internal}}
853+
@derivative(of: internal_original_public_derivative)
854+
// expected-note @+1 {{mark the derivative function as 'internal' to match the original function}} {{1-7=internal}}
855+
public func _internal_original_public_derivative(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
856+
fatalError()
857+
}
858+
859+
private func private_original_usablefrominline_derivative(_ x: Float) -> Float { x }
860+
// expected-error @+1 {{derivative function must have same access level as original function; derivative function '_private_original_usablefrominline_derivative' is internal, but original function 'private_original_usablefrominline_derivative' is private}}
861+
@derivative(of: private_original_usablefrominline_derivative)
862+
@usableFromInline
863+
// expected-note @+1 {{mark the derivative function as 'private' to match the original function}} {{1-1=private }}
864+
func _private_original_usablefrominline_derivative(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
865+
fatalError()
866+
}
867+
868+
private func private_original_public_derivative(_ x: Float) -> Float { x }
869+
// expected-error @+1 {{derivative function must have same access level as original function; derivative function '_private_original_public_derivative' is public, but original function 'private_original_public_derivative' is private}}
870+
@derivative(of: private_original_public_derivative)
871+
// expected-note @+1 {{mark the derivative function as 'private' to match the original function}} {{1-7=private}}
872+
public func _private_original_public_derivative(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
873+
fatalError()
874+
}
875+
876+
private func private_original_internal_derivative(_ x: Float) -> Float { x }
877+
// expected-error @+1 {{derivative function must have same access level as original function; derivative function '_private_original_internal_derivative' is internal, but original function 'private_original_internal_derivative' is private}}
878+
@derivative(of: private_original_internal_derivative)
879+
// expected-note @+1 {{mark the derivative function as 'private' to match the original function}}
880+
func _private_original_internal_derivative(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
881+
fatalError()
882+
}
883+
884+
fileprivate func fileprivate_original_private_derivative(_ x: Float) -> Float { x }
885+
// expected-error @+1 {{derivative function must have same access level as original function; derivative function '_fileprivate_original_private_derivative' is private, but original function 'fileprivate_original_private_derivative' is fileprivate}}
886+
@derivative(of: fileprivate_original_private_derivative)
887+
// expected-note @+1 {{mark the derivative function as 'fileprivate' to match the original function}} {{1-8=fileprivate}}
888+
private func _fileprivate_original_private_derivative(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
889+
fatalError()
890+
}
891+
892+
private func private_original_fileprivate_derivative(_ x: Float) -> Float { x }
893+
// expected-error @+1 {{derivative function must have same access level as original function; derivative function '_private_original_fileprivate_derivative' is fileprivate, but original function 'private_original_fileprivate_derivative' is private}}
894+
@derivative(of: private_original_fileprivate_derivative)
895+
// expected-note @+1 {{mark the derivative function as 'private' to match the original function}} {{1-12=private}}
896+
fileprivate func _private_original_fileprivate_derivative(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
897+
fatalError()
898+
}
899+
900+
// MARK: - Original function visibility > derivative function visibility
901+
902+
public func public_original_private_derivative(_ x: Float) -> Float { x }
903+
// expected-error @+1 {{derivative function must have same access level as original function; derivative function '_public_original_private_derivative' is fileprivate, but original function 'public_original_private_derivative' is public}}
904+
@derivative(of: public_original_private_derivative)
905+
// expected-note @+1 {{mark the derivative function as '@usableFromInline' to match the original function}} {{1-1=@usableFromInline }}
906+
fileprivate func _public_original_private_derivative(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
907+
fatalError()
908+
}
909+
910+
public func public_original_internal_derivative(_ x: Float) -> Float { x }
911+
// expected-error @+1 {{derivative function must have same access level as original function; derivative function '_public_original_internal_derivative' is internal, but original function 'public_original_internal_derivative' is public}}
912+
@derivative(of: public_original_internal_derivative)
913+
// expected-note @+1 {{mark the derivative function as '@usableFromInline' to match the original function}} {{1-1=@usableFromInline }}
914+
func _public_original_internal_derivative(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
915+
fatalError()
916+
}
917+
918+
func internal_original_fileprivate_derivative(_ x: Float) -> Float { x }
919+
// expected-error @+1 {{derivative function must have same access level as original function; derivative function '_internal_original_fileprivate_derivative' is fileprivate, but original function 'internal_original_fileprivate_derivative' is internal}}
920+
@derivative(of: internal_original_fileprivate_derivative)
921+
// expected-note @+1 {{mark the derivative function as 'internal' to match the original function}} {{1-12=internal}}
922+
fileprivate func _internal_original_fileprivate_derivative(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
923+
fatalError()
924+
}

0 commit comments

Comments
 (0)