Skip to content

Commit e83f0e5

Browse files
authored
---
yaml --- r: 340726 b: refs/heads/rxwei-patch-1 c: 6b822d8 h: refs/heads/master
1 parent 12cd6bf commit e83f0e5

File tree

9 files changed

+133
-60
lines changed

9 files changed

+133
-60
lines changed

[refs]

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1015,7 +1015,7 @@ refs/tags/swift-DEVELOPMENT-SNAPSHOT-2018-08-18-a: b10b1fce14385faa6d44f6b933e95
10151015
refs/heads/rdar-43033749-fix-batch-mode-no-diags-swift-5.0-branch: a14e64eaad30de89f0f5f0b2a782eed7ecdcb255
10161016
refs/heads/revert-19006-error-bridging-integer-type: 8a9065a3696535305ea53fe9b71f91cbe6702019
10171017
refs/heads/revert-19050-revert-19006-error-bridging-integer-type: ecf752d54b05dd0a20f510f0bfa54a3fec3bcaca
1018-
refs/heads/rxwei-patch-1: 71ba224ad4d4b6ce04bc41e521fc9bc166d2b684
1018+
refs/heads/rxwei-patch-1: 6b822d8ac5ae878278e5f53c669d9917367f029f
10191019
refs/heads/shahmishal-patch-1: e58ec0f7488258d42bef51bc3e6d7b3dc74d7b2a
10201020
refs/heads/typelist-existential: 4046359efd541fb5c72d69a92eefc0a784df8f5e
10211021
refs/tags/swift-4.2-DEVELOPMENT-SNAPSHOT-2018-08-20-a: 4319ba09e4fb8650ee86061075c74a016b6baab9

branches/rxwei-patch-1/include/swift/AST/DiagnosticsSema.def

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1100,10 +1100,6 @@ NOTE(string_index_not_integer_note,none,
11001100
"consider using an existing high level algorithm, "
11011101
"str.startIndex.advanced(by: n), or a projection like str.utf8", ())
11021102

1103-
// SWIFT_ENABLE_TENSORFLOW
1104-
ERROR(invalid_tensorflow_fn_conversion,none,
1105-
"TensorFlow functions cannot be converted to other function types", ())
1106-
11071103

11081104
ERROR(invalid_c_function_pointer_conversion_expr,none,
11091105
"a C function pointer can only be formed from a reference to a 'func' or "
@@ -1118,6 +1114,12 @@ ERROR(c_function_pointer_from_function_with_context,none,
11181114
"%select{local function|closure}0 that captures "
11191115
"%select{context|generic parameters|dynamic Self type|<<error>}1",
11201116
(bool, unsigned))
1117+
// SWIFT_ENABLE_TENSORFLOW
1118+
ERROR(invalid_differentiable_function_conversion_expr,none,
1119+
"a '@differentiable' function can only be formed from a reference to a "
1120+
"'func' or a literal closure", ())
1121+
NOTE(invalid_differentiable_function_conversion_parameter,none,
1122+
"did you mean to take a '%0' closure?", (StringRef))
11211123
ERROR(invalid_autoclosure_forwarding,none,
11221124
"add () to forward @autoclosure parameter", ())
11231125

branches/rxwei-patch-1/lib/Sema/CSApply.cpp

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5849,6 +5849,58 @@ maybeDiagnoseUnsupportedFunctionConversion(ConstraintSystem &cs, Expr *expr,
58495849
tc.diagnose(expr->getLoc(),
58505850
diag::invalid_c_function_pointer_conversion_expr);
58515851
}
5852+
5853+
// Conversion from a non-`@differentiable` function to a `@differentiable` is
5854+
// only allowed from a closure expression or a declaration/member reference.
5855+
if (toType->isDifferentiable() && !fromFnType->isDifferentiable()) {
5856+
auto maybeDiagnoseFunctionRef = [&](Expr *semanticExpr) {
5857+
if (auto *capture = dyn_cast<CaptureListExpr>(semanticExpr))
5858+
semanticExpr = capture->getClosureBody();
5859+
if (isa<ClosureExpr>(semanticExpr)) return;
5860+
if (auto *declRef = dyn_cast<DeclRefExpr>(semanticExpr)) {
5861+
if (isa<FuncDecl>(declRef->getDecl())) return;
5862+
// If the referenced decl is a function parameter, the user may want
5863+
// to change the declaration to be a '@differentiable' closure. Emit a
5864+
// note with a fix-it.
5865+
if (auto *paramDecl = dyn_cast<ParamDecl>(declRef->getDecl())) {
5866+
tc.diagnose(expr->getLoc(),
5867+
diag::invalid_differentiable_function_conversion_expr);
5868+
if (paramDecl->getType()->is<AnyFunctionType>()) {
5869+
auto *typeRepr = paramDecl->getTypeLoc().getTypeRepr();
5870+
while (auto *attributed = dyn_cast<AttributedTypeRepr>(typeRepr))
5871+
typeRepr = attributed->getTypeRepr();
5872+
std::string attributeString = "@differentiable";
5873+
switch (toType->getDifferentiabilityKind()) {
5874+
case DifferentiabilityKind::Linear:
5875+
attributeString += "(linear)";
5876+
break;
5877+
case DifferentiabilityKind::Normal:
5878+
case DifferentiabilityKind::NonDifferentiable:
5879+
break;
5880+
}
5881+
auto *funcTypeRepr = cast<FunctionTypeRepr>(typeRepr);
5882+
auto paramListLoc = funcTypeRepr->getArgsTypeRepr()->getStartLoc();
5883+
tc.diagnose(paramDecl->getLoc(),
5884+
diag::invalid_differentiable_function_conversion_parameter,
5885+
attributeString)
5886+
.highlight(paramDecl->getTypeLoc().getSourceRange())
5887+
.fixItInsert(paramListLoc, attributeString + " ");
5888+
}
5889+
return;
5890+
}
5891+
} else if (auto *memberRef = dyn_cast<MemberRefExpr>(semanticExpr)) {
5892+
if (isa<FuncDecl>(memberRef->getMember().getDecl())) return;
5893+
} else if (auto *dotSyntaxCall =
5894+
dyn_cast<DotSyntaxCallExpr>(semanticExpr)) {
5895+
if (isa<FuncDecl>(dotSyntaxCall->getFn()
5896+
->getSemanticsProvidingExpr()->getReferencedDecl().getDecl()))
5897+
return;
5898+
}
5899+
tc.diagnose(expr->getLoc(),
5900+
diag::invalid_differentiable_function_conversion_expr);
5901+
};
5902+
maybeDiagnoseFunctionRef(getSemanticExprForDeclOrMemberRef(expr));
5903+
}
58525904
}
58535905

58545906
/// Build the conversion of an element in a collection upcast.

branches/rxwei-patch-1/test/AutoDiff/autodiff_diagnostics.swift

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,5 @@
11
// RUN: %target-swift-frontend -emit-sil -verify %s
22

3-
//===----------------------------------------------------------------------===//
4-
// Top-level (before VJP/adjoint synthesis)
5-
//===----------------------------------------------------------------------===//
6-
7-
// expected-note @+1 {{opaque non-'@differentiable' function is not differentiable}}
8-
func foo(_ f: (Float) -> Float) -> Float {
9-
// expected-error @+1 {{function is not differentiable}}
10-
return gradient(at: 0, in: f)
11-
}
12-
133
//===----------------------------------------------------------------------===//
144
// Basic function
155
//===----------------------------------------------------------------------===//

branches/rxwei-patch-1/test/AutoDiff/closures.swift

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,14 @@ func diffableClosureInStruct(s: Foo) {
1616

1717
public func closureCaptureMutable() {
1818
var val: Float = 10
19-
let clo: (Float) -> Float = { x in
19+
_ = gradient(at: 0) { (x: Float) -> Float in
2020
val += 2
2121
return val * x
2222
}
23-
_ = gradient(at: 0, in: clo)
2423
}
2524

2625
// CHECK-LABEL: @AD__{{.*}}closureCaptureMutable{{.*}}___vjp_src_0_wrt_0
27-
// CHECK: bb0({{%.*}} : $Float, [[BOXED_ARG:%.*]] : ${ var Float }):
26+
// CHECK: bb0({{%.*}} : $Float, [[INOUT_ARG:%.*]] : $*Float):
2827
// CHECK: [[ADJOINT:%.*]] = function_ref @AD__{{.*}}closureCaptureMutabley{{.*}}___adjoint_src_0_wrt_0
2928
// CHECK: {{.*}} = partial_apply [callee_guaranteed] [[ADJOINT]]({{.*}})
3029

branches/rxwei-patch-1/test/AutoDiff/differentiable_func_type_type_checking.swift

Lines changed: 44 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
let _: @differentiable (Float) -> Float
44
let _: @differentiable (Float) throws -> Float
55

6-
//
6+
//===----------------------------------------------------------------------===//
77
// Type differentiability
8-
//
8+
//===----------------------------------------------------------------------===//
99

1010
struct NonDiffType { var x: Int }
1111
// FIXME: Properly type-check parameters and the result's differentiability
@@ -14,9 +14,46 @@ let _: @differentiable (NonDiffType) -> Float
1414
// expected-error @+1 {{result is not differentiable, but the function type is marked '@differentiable'}}
1515
let _: @differentiable (Float) -> NonDiffType
1616

17-
//
18-
// Argument selection (@nondiff)
19-
//
17+
// expected-error @+1 {{cannot mark types as linear differentiable}}
18+
let _: @differentiable(linear) (Float) -> Float
19+
20+
//===----------------------------------------------------------------------===//
21+
// Function conversion
22+
//===----------------------------------------------------------------------===//
23+
24+
func takesOpaqueClosure(f: @escaping (Float) -> Float) {
25+
// expected-note @-1 {{did you mean to take a '@differentiable' closure?}} {{38-38=@differentiable }}
26+
// expected-error @+1 {{a '@differentiable' function can only be formed from a reference to a 'func' or a literal closure}}
27+
_ = gradient(of: f)
28+
}
29+
30+
let globalAddOne: (Float) -> Float = { $0 + 1 }
31+
// expected-error @+1 {{a '@differentiable' function can only be formed from a reference to a 'func' or a literal closure}}
32+
_ = gradient(of: globalAddOne)
33+
34+
func someScope() {
35+
let localAddOne: (Float) -> Float = { $0 + 1 }
36+
// expected-error @+1 {{a '@differentiable' function can only be formed from a reference to a 'func' or a literal closure}}
37+
_ = gradient(of: globalAddOne)
38+
// expected-error @+1 {{a '@differentiable' function can only be formed from a reference to a 'func' or a literal closure}}
39+
_ = gradient(of: localAddOne)
40+
// The following case is okay during type checking, but will fail in the AD transform.
41+
_ = gradient { localAddOne($0) }
42+
}
43+
44+
func addOne(x: Float) -> Float { x + 1 }
45+
_ = gradient(of: addOne) // okay
46+
47+
extension Float {
48+
static func addOne(x: Float) -> Float { x + 1 }
49+
func addOne(x: Float) -> Float { x + 1 }
50+
}
51+
_ = gradient(of: Float.addOne) // okay
52+
_ = gradient(of: Float(1.0).addOne) // okay
53+
54+
//===----------------------------------------------------------------------===//
55+
// Parameter selection (@nondiff)
56+
//===----------------------------------------------------------------------===//
2057

2158
// expected-error @+1 {{'nondiff' cannot be applied to arguments of a non-differentiable function}}
2259
let _: (@nondiff Float, Float) -> Float
@@ -55,18 +92,11 @@ extension Vector: Differentiable where T: Differentiable {}
5592
// expected-note @+1 {{where 'U' = 'Int'}}
5693
func inferredConformancesGeneric<T, U>(_: @differentiable (Vector<T>) -> Vector<U>) {}
5794

58-
let nondiffVectorFunc: (Vector<Int>) -> Vector<Int>
95+
func nondiffVectorFunc(x: Vector<Int>) -> Vector<Int> {}
5996
// expected-error @+1 2 {{global function 'inferredConformancesGeneric' requires that 'Int' conform to 'Differentiable}}
6097
inferredConformancesGeneric(nondiffVectorFunc)
6198

62-
let diffVectorFunc: (Vector<Float>) -> Vector<Float>
99+
func diffVectorFunc(x: Vector<Float>) -> Vector<Float> {}
63100
inferredConformancesGeneric(diffVectorFunc) // okay!
64101

65102
func inferredConformancesGenericResult<T, U>() -> @differentiable (Vector<T>) -> Vector<U> {}
66-
67-
//
68-
// linear function type
69-
//
70-
71-
// expected-error @+1 {{cannot mark types as linear differentiable}}
72-
let _: @differentiable(linear) (Float) -> Float

branches/rxwei-patch-1/test/AutoDiff/method.swift

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ MethodTests.test("instance method with generated adjoint, called from differenta
7575
MethodTests.test("instance method with generated adjoint, differentiated directly") {
7676
// This is our current syntax for taking gradients of instance methods
7777
// directly. If/when we develop nicer syntax for this, change this test.
78-
let g = { (p: Parameter) in p.squared() }
78+
func g(p: Parameter) -> Float { p.squared() }
7979
expectEqual(Parameter(x: 4), gradient(at: Parameter(x: 2), in: g))
8080
expectEqual(Parameter(x: 40), gradient(at: Parameter(x: 20), in: g))
8181
}
@@ -136,7 +136,7 @@ MethodTests.test("static method with generated adjoint, wrt only second param")
136136
}
137137

138138
MethodTests.test("static method with generated adjoint, wrt all params") {
139-
let g = { (a: Parameter, b: Parameter) in a * b }
139+
func g(a: Parameter, b: Parameter) -> Float { a * b }
140140
expectEqual((Parameter(x: 100), Parameter(x: 200)),
141141
gradient(at: Parameter(x: 200), Parameter(x: 100), in: g))
142142
expectEqual((Parameter(x: 200), Parameter(x: 100)),
@@ -291,7 +291,7 @@ MethodTests.test("instance method with custom adjoint, called from differentated
291291
MethodTests.test("instance method with generated adjoint, differentated directly") {
292292
// This is our current syntax for taking gradients of instance methods
293293
// directly. If/when we develop nicer syntax for this, change this test.
294-
let g = { (p: CustomParameter) in p.squared() }
294+
func g(p: CustomParameter) -> Float { p.squared() }
295295
expectEqual(CustomParameter(x: 4), gradient(at: CustomParameter(x: 2), in: g))
296296
expectEqual(CustomParameter(x: 10), gradient(at: CustomParameter(x: 20), in: g))
297297
}
@@ -328,7 +328,7 @@ MethodTests.test("instance method with custom adjoint, wrt only non-self") {
328328
}
329329

330330
MethodTests.test("instance method with custom adjoint, wrt self and non-self") {
331-
let g = { (p: CustomParameter, o: Float) in p.multiplied(with: o) }
331+
func g(p: CustomParameter, o: Float) -> Float { p.multiplied(with: o) }
332332
expectEqual((CustomParameter(x: 5), 10), gradient(at: CustomParameter(x: 100), 5, in: g))
333333
expectEqual((CustomParameter(x: 10), 5), gradient(at: CustomParameter(x: 5), 100, in: g))
334334
}

branches/rxwei-patch-1/test/AutoDiff/SIMD.swift renamed to branches/rxwei-patch-1/test/AutoDiff/simd.swift

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ var SIMDTests = TestSuite("SIMD")
1313
SIMDTests.test("init(repeating:)") {
1414
let g = SIMD4<Float>(1, 1, 1, 1)
1515

16-
let foo1 = { (x: Float) -> SIMD4<Float> in
16+
func foo1(x: Float) -> SIMD4<Float> {
1717
return SIMD4<Float>(repeating: 2 * x)
1818
}
1919
let (val1, bp1) = valueWithPullback(at: 5, in: foo1)
@@ -24,7 +24,7 @@ SIMDTests.test("init(repeating:)") {
2424
SIMDTests.test("Sum") {
2525
let a = SIMD4<Float>(1, 2, 3, 4)
2626

27-
let foo1 = { (x: SIMD4<Float>) -> Float in
27+
func foo1(x: SIMD4<Float>) -> Float {
2828
return x.sum()
2929
}
3030
let (val1, bp1) = valueWithPullback(at: a, in: foo1)
@@ -36,7 +36,7 @@ SIMDTests.test("Identity") {
3636
let a = SIMD4<Float>(1, 2, 3, 4)
3737
let g = SIMD4<Float>(1, 1, 1, 1)
3838

39-
let foo1 = { (x: SIMD4<Float>) -> SIMD4<Float> in
39+
func foo1(x: SIMD4<Float>) -> SIMD4<Float> {
4040
return x
4141
}
4242
let (val1, bp1) = valueWithPullback(at: a, in: foo1)
@@ -48,7 +48,7 @@ SIMDTests.test("Negate") {
4848
let a = SIMD4<Float>(1, 2, 3, 4)
4949
let g = SIMD4<Float>(1, 1, 1, 1)
5050

51-
let foo1 = { (x: SIMD4<Float>) -> SIMD4<Float> in
51+
func foo1(x: SIMD4<Float>) -> SIMD4<Float> {
5252
return -x
5353
}
5454
let (val1, bp1) = valueWithPullback(at: a, in: foo1)
@@ -59,7 +59,7 @@ SIMDTests.test("Negate") {
5959
SIMDTests.test("subscript") {
6060
let a = SIMD4<Float>(1, 2, 3, 4)
6161

62-
let foo1 = { (x: SIMD4<Float>) -> Float in
62+
func foo1(x: SIMD4<Float>) -> Float {
6363
return x[3]
6464
}
6565

@@ -73,23 +73,23 @@ SIMDTests.test("Addition") {
7373
let g = SIMD4<Float>(1, 1, 1, 1)
7474

7575
// SIMD + SIMD
76-
let foo1 = { (x: SIMD4<Float>, y: SIMD4<Float>) -> SIMD4<Float> in
76+
func foo1(x: SIMD4<Float>, y: SIMD4<Float>) -> SIMD4<Float> {
7777
return x + y
7878
}
7979
let (val1, bp1) = valueWithPullback(at: a, a, in: foo1)
8080
expectEqual(SIMD4<Float>(2, 4, 6, 8), val1)
8181
expectEqual((g, g), bp1(g))
8282

8383
// SIMD + Scalar
84-
let foo2 = { (x: SIMD4<Float>, y: Float) -> SIMD4<Float> in
84+
func foo2(x: SIMD4<Float>, y: Float) -> SIMD4<Float> {
8585
return x + y
8686
}
8787
let (val2, bp2) = valueWithPullback(at: a, 5, in: foo2)
8888
expectEqual(SIMD4<Float>(6, 7, 8, 9), val2)
8989
expectEqual((g, 4), bp2(g))
9090

9191
// Scalar + SIMD
92-
let foo3 = { (x: SIMD4<Float>, y: Float) -> SIMD4<Float> in
92+
func foo3(x: SIMD4<Float>, y: Float) -> SIMD4<Float> {
9393
return y + x
9494
}
9595
let (val3, bp3) = valueWithPullback(at: a, 5, in: foo3)
@@ -102,23 +102,23 @@ SIMDTests.test("Subtraction") {
102102
let g = SIMD4<Float>(1, 1, 1, 1)
103103

104104
// SIMD - SIMD
105-
let foo1 = { (x: SIMD4<Float>, y: SIMD4<Float>) -> SIMD4<Float> in
105+
func foo1(x: SIMD4<Float>, y: SIMD4<Float>) -> SIMD4<Float> {
106106
return x - y
107107
}
108108
let (val1, bp1) = valueWithPullback(at: a, a, in: foo1)
109109
expectEqual(SIMD4<Float>(0, 0, 0, 0), val1)
110110
expectEqual((g, -g), bp1(g))
111111

112112
// SIMD - Scalar
113-
let foo2 = { (x: SIMD4<Float>, y: Float) -> SIMD4<Float> in
113+
func foo2(x: SIMD4<Float>, y: Float) -> SIMD4<Float> {
114114
return x - y
115115
}
116116
let (val2, bp2) = valueWithPullback(at: a, 5, in: foo2)
117117
expectEqual(SIMD4<Float>(-4, -3, -2, -1), val2)
118118
expectEqual((g, -4), bp2(g))
119119

120120
// Scalar - SIMD
121-
let foo3 = { (x: SIMD4<Float>, y: Float) -> SIMD4<Float> in
121+
func foo3(x: SIMD4<Float>, y: Float) -> SIMD4<Float> {
122122
return y - x
123123
}
124124
let (val3, bp3) = valueWithPullback(at: a, 5, in: foo3)
@@ -131,23 +131,23 @@ SIMDTests.test("Multiplication") {
131131
let g = SIMD4<Float>(1, 1, 1, 1)
132132

133133
// SIMD * SIMD
134-
let foo1 = { (x: SIMD4<Float>, y: SIMD4<Float>) -> SIMD4<Float> in
134+
func foo1(x: SIMD4<Float>, y: SIMD4<Float>) -> SIMD4<Float> {
135135
return x * y
136136
}
137137
let (val1, bp1) = valueWithPullback(at: a, a, in: foo1)
138138
expectEqual(a * a, val1)
139139
expectEqual((a, a), bp1(g))
140140

141141
// SIMD * Scalar
142-
let foo2 = { (x: SIMD4<Float>, y: Float) -> SIMD4<Float> in
142+
func foo2(x: SIMD4<Float>, y: Float) -> SIMD4<Float> {
143143
return x * y
144144
}
145145
let (val2, bp2) = valueWithPullback(at: a, 5, in: foo2)
146146
expectEqual(a * 5, val2)
147147
expectEqual((SIMD4<Float>(5, 5, 5, 5), 10), bp2(g))
148148

149149
// Scalar * SIMD
150-
let foo3 = { (x: SIMD4<Float>, y: Float) -> SIMD4<Float> in
150+
func foo3(x: SIMD4<Float>, y: Float) -> SIMD4<Float> {
151151
return y * x
152152
}
153153
let (val3, bp3) = valueWithPullback(at: a, 5, in: foo3)
@@ -160,7 +160,7 @@ SIMDTests.test("Division") {
160160
let g = SIMD4<Float>(1, 1, 1, 1)
161161

162162
// SIMD / SIMD
163-
let foo1 = { (x: SIMD4<Float>, y: SIMD4<Float>) -> SIMD4<Float> in
163+
func foo1(x: SIMD4<Float>, y: SIMD4<Float>) -> SIMD4<Float> {
164164
return x / y
165165
}
166166
let dlhs1 = g / a
@@ -170,7 +170,7 @@ SIMDTests.test("Division") {
170170
expectEqual((dlhs1, drhs1), bp1(g))
171171

172172
// SIMD / Scalar
173-
let foo2 = { (x: SIMD4<Float>, y: Float) -> SIMD4<Float> in
173+
func foo2(x: SIMD4<Float>, y: Float) -> SIMD4<Float> {
174174
return x / y
175175
}
176176
let dlhs2 = g / 5
@@ -180,7 +180,7 @@ SIMDTests.test("Division") {
180180
expectEqual((dlhs2, drhs2), bp2(g))
181181

182182
// Scalar / SIMD
183-
let foo3 = { (x: Float, y: SIMD4<Float>) -> SIMD4<Float> in
183+
func foo3(x: Float, y: SIMD4<Float>) -> SIMD4<Float> {
184184
return x / y
185185
}
186186
let dlhs3 = (g / a).sum()

0 commit comments

Comments
 (0)