Skip to content

Commit 68b96a2

Browse files
authored
[AutoDiff] Type-check @differentiable attributes during validation. (#27613)
Type-check `@differentiable` attributes during `TypeChecker::validateDecl` for all relevant declaration kinds (initializers, subscripts, variables), not just function declarations. Resolves TF-888. TF-789 tracks proper request-based type-checking for `@differentiable` attribute. Exposes TF-892: `ElementaryFunctions` linker error on Linux.
1 parent 9886b63 commit 68b96a2

File tree

3 files changed

+68
-1
lines changed

3 files changed

+68
-1
lines changed

lib/Sema/TypeCheckDecl.cpp

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3989,6 +3989,11 @@ void TypeChecker::validateDecl(ValueDecl *D) {
39893989
assert(VD->hasInterfaceType());
39903990
}
39913991

3992+
// SWIFT_ENABLE_TENSORFLOW
3993+
// TODO(TF-789): Find proper way to type-check `@differentiable` attributes.
3994+
checkDeclDifferentiableAttributes(VD);
3995+
// SWIFT_ENABLE_TENSORFLOW END
3996+
39923997
// We're not really done with processing the signature yet, but
39933998
// @objc checking requires the declaration to call itself validated
39943999
// so that it can be considered as a witness.
@@ -4118,8 +4123,10 @@ void TypeChecker::validateDecl(ValueDecl *D) {
41184123
// FIXME: Roll all of this interface type computation into a request.
41194124
FD->computeType();
41204125

4121-
// TODO(TF-789): Figure out the proper way to typecheck these.
4126+
// SWIFT_ENABLE_TENSORFLOW
4127+
// TODO(TF-789): Find proper way to type-check `@differentiable` attributes.
41224128
checkDeclDifferentiableAttributes(FD);
4129+
// SWIFT_ENABLE_TENSORFLOW END
41234130

41244131
// Member functions need some special validation logic.
41254132
if (FD->getDeclContext()->isTypeContext()) {
@@ -4164,6 +4171,10 @@ void TypeChecker::validateDecl(ValueDecl *D) {
41644171
typeCheckParameterList(CD->getParameters(), res,
41654172
TypeResolverContext::AbstractFunctionDecl);
41664173
CD->computeType();
4174+
// SWIFT_ENABLE_TENSORFLOW
4175+
// TODO(TF-789): Find proper way to type-check `@differentiable` attributes.
4176+
checkDeclDifferentiableAttributes(CD);
4177+
// SWIFT_ENABLE_TENSORFLOW END
41674178
break;
41684179
}
41694180

@@ -4196,6 +4207,10 @@ void TypeChecker::validateDecl(ValueDecl *D) {
41964207
SF->markDeclWithOpaqueResultTypeAsValidated(SD);
41974208
}
41984209
}
4210+
// SWIFT_ENABLE_TENSORFLOW
4211+
// TODO(TF-789): Find proper way to type-check `@differentiable` attributes.
4212+
checkDeclDifferentiableAttributes(SD);
4213+
// SWIFT_ENABLE_TENSORFLOW END
41994214

42004215
break;
42014216
}
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
// Verify that `@differentiable` declarations can be differentiated from other
2+
// modules.
3+
4+
public struct Foo: Differentiable {
5+
public var x: Float
6+
7+
@differentiable
8+
public init(_ x: Float) {
9+
self.x = x
10+
}
11+
12+
@differentiable
13+
public func method() -> Float {
14+
x
15+
}
16+
17+
@differentiable
18+
public var computedProperty: Float {
19+
x
20+
}
21+
22+
@differentiable
23+
public subscript() -> Float {
24+
x
25+
}
26+
}
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
// Verify that `@differentiable` declarations can be differentiated from other
2+
// modules.
3+
4+
// RUN: %empty-directory(%t)
5+
// RUN: %target-build-swift %S/../Inputs/differentiable_attr_other_module.swift %s -o /dev/null -lm
6+
// NOTE(TF-892): `-lm` is necessary to prevent linker errors related to `ElementaryFunctions` on Ubuntu.
7+
8+
@differentiable(wrt: x)
9+
func testInitializer(_ x: Float) -> Float {
10+
return Foo(x).x
11+
}
12+
13+
@differentiable(wrt: foo)
14+
func testMethod(_ foo: Foo) -> Float {
15+
return foo.method()
16+
}
17+
18+
@differentiable(wrt: foo)
19+
func testComputedProperty(_ foo: Foo) -> Float {
20+
return foo.computedProperty
21+
}
22+
23+
@differentiable(wrt: foo)
24+
func testSubscript(_ foo: Foo) -> Float {
25+
return foo[]
26+
}

0 commit comments

Comments
 (0)