Skip to content

[AutoDiff] Eliminate flow-sensitivity in '@differentiable' closure conversions. #25688

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 23, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions include/swift/AST/DiagnosticsSema.def
Original file line number Diff line number Diff line change
Expand Up @@ -1100,10 +1100,6 @@ NOTE(string_index_not_integer_note,none,
"consider using an existing high level algorithm, "
"str.startIndex.advanced(by: n), or a projection like str.utf8", ())

// SWIFT_ENABLE_TENSORFLOW
ERROR(invalid_tensorflow_fn_conversion,none,
"TensorFlow functions cannot be converted to other function types", ())


ERROR(invalid_c_function_pointer_conversion_expr,none,
"a C function pointer can only be formed from a reference to a 'func' or "
Expand All @@ -1118,6 +1114,12 @@ ERROR(c_function_pointer_from_function_with_context,none,
"%select{local function|closure}0 that captures "
"%select{context|generic parameters|dynamic Self type|<<error>}1",
(bool, unsigned))
// SWIFT_ENABLE_TENSORFLOW
ERROR(invalid_differentiable_function_conversion_expr,none,
"a '@differentiable' function can only be formed from a reference to a "
"'func' or a literal closure", ())
NOTE(invalid_differentiable_function_conversion_parameter,none,
"did you mean to take a '%0' closure?", (StringRef))
ERROR(invalid_autoclosure_forwarding,none,
"add () to forward @autoclosure parameter", ())

Expand Down
52 changes: 52 additions & 0 deletions lib/Sema/CSApply.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5849,6 +5849,58 @@ maybeDiagnoseUnsupportedFunctionConversion(ConstraintSystem &cs, Expr *expr,
tc.diagnose(expr->getLoc(),
diag::invalid_c_function_pointer_conversion_expr);
}

// Conversion from a non-`@differentiable` function to a `@differentiable` is
// only allowed from a closure expression or a declaration/member reference.
if (toType->isDifferentiable() && !fromFnType->isDifferentiable()) {
auto maybeDiagnoseFunctionRef = [&](Expr *semanticExpr) {
if (auto *capture = dyn_cast<CaptureListExpr>(semanticExpr))
semanticExpr = capture->getClosureBody();
if (isa<ClosureExpr>(semanticExpr)) return;
if (auto *declRef = dyn_cast<DeclRefExpr>(semanticExpr)) {
if (isa<FuncDecl>(declRef->getDecl())) return;
// If the referenced decl is a function parameter, the user may want
// to change the declaration to be a '@differentiable' closure. Emit a
// note with a fix-it.
if (auto *paramDecl = dyn_cast<ParamDecl>(declRef->getDecl())) {
tc.diagnose(expr->getLoc(),
diag::invalid_differentiable_function_conversion_expr);
if (paramDecl->getType()->is<AnyFunctionType>()) {
auto *typeRepr = paramDecl->getTypeLoc().getTypeRepr();
while (auto *attributed = dyn_cast<AttributedTypeRepr>(typeRepr))
typeRepr = attributed->getTypeRepr();
std::string attributeString = "@differentiable";
switch (toType->getDifferentiabilityKind()) {
case DifferentiabilityKind::Linear:
attributeString += "(linear)";
break;
case DifferentiabilityKind::Normal:
case DifferentiabilityKind::NonDifferentiable:
break;
}
auto *funcTypeRepr = cast<FunctionTypeRepr>(typeRepr);
auto paramListLoc = funcTypeRepr->getArgsTypeRepr()->getStartLoc();
tc.diagnose(paramDecl->getLoc(),
diag::invalid_differentiable_function_conversion_parameter,
attributeString)
.highlight(paramDecl->getTypeLoc().getSourceRange())
.fixItInsert(paramListLoc, attributeString + " ");
}
return;
}
} else if (auto *memberRef = dyn_cast<MemberRefExpr>(semanticExpr)) {
if (isa<FuncDecl>(memberRef->getMember().getDecl())) return;
} else if (auto *dotSyntaxCall =
dyn_cast<DotSyntaxCallExpr>(semanticExpr)) {
if (isa<FuncDecl>(dotSyntaxCall->getFn()
->getSemanticsProvidingExpr()->getReferencedDecl().getDecl()))
return;
}
tc.diagnose(expr->getLoc(),
diag::invalid_differentiable_function_conversion_expr);
};
maybeDiagnoseFunctionRef(getSemanticExprForDeclOrMemberRef(expr));
}
}

/// Build the conversion of an element in a collection upcast.
Expand Down
10 changes: 0 additions & 10 deletions test/AutoDiff/autodiff_diagnostics.swift
Original file line number Diff line number Diff line change
@@ -1,15 +1,5 @@
// RUN: %target-swift-frontend -emit-sil -verify %s

//===----------------------------------------------------------------------===//
// Top-level (before VJP/adjoint synthesis)
//===----------------------------------------------------------------------===//

// expected-note @+1 {{opaque non-'@differentiable' function is not differentiable}}
func foo(_ f: (Float) -> Float) -> Float {
// expected-error @+1 {{function is not differentiable}}
return gradient(at: 0, in: f)
}

//===----------------------------------------------------------------------===//
// Basic function
//===----------------------------------------------------------------------===//
Expand Down
5 changes: 2 additions & 3 deletions test/AutoDiff/closures.swift
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,14 @@ func diffableClosureInStruct(s: Foo) {

public func closureCaptureMutable() {
var val: Float = 10
let clo: (Float) -> Float = { x in
_ = gradient(at: 0) { (x: Float) -> Float in
val += 2
return val * x
}
_ = gradient(at: 0, in: clo)
}

// CHECK-LABEL: @AD__{{.*}}closureCaptureMutable{{.*}}___vjp_src_0_wrt_0
// CHECK: bb0({{%.*}} : $Float, [[BOXED_ARG:%.*]] : ${ var Float }):
// CHECK: bb0({{%.*}} : $Float, [[INOUT_ARG:%.*]] : $*Float):
// CHECK: [[ADJOINT:%.*]] = function_ref @AD__{{.*}}closureCaptureMutabley{{.*}}___adjoint_src_0_wrt_0
// CHECK: {{.*}} = partial_apply [callee_guaranteed] [[ADJOINT]]({{.*}})

Expand Down
58 changes: 44 additions & 14 deletions test/AutoDiff/differentiable_func_type_type_checking.swift
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
let _: @differentiable (Float) -> Float
let _: @differentiable (Float) throws -> Float

//
//===----------------------------------------------------------------------===//
// Type differentiability
//
//===----------------------------------------------------------------------===//

struct NonDiffType { var x: Int }
// FIXME: Properly type-check parameters and the result's differentiability
Expand All @@ -14,9 +14,46 @@ let _: @differentiable (NonDiffType) -> Float
// expected-error @+1 {{result is not differentiable, but the function type is marked '@differentiable'}}
let _: @differentiable (Float) -> NonDiffType

//
// Argument selection (@nondiff)
//
// expected-error @+1 {{cannot mark types as linear differentiable}}
let _: @differentiable(linear) (Float) -> Float

//===----------------------------------------------------------------------===//
// Function conversion
//===----------------------------------------------------------------------===//

func takesOpaqueClosure(f: @escaping (Float) -> Float) {
// expected-note @-1 {{did you mean to take a '@differentiable' closure?}} {{38-38=@differentiable }}
// expected-error @+1 {{a '@differentiable' function can only be formed from a reference to a 'func' or a literal closure}}
_ = gradient(of: f)
}

let globalAddOne: (Float) -> Float = { $0 + 1 }
// expected-error @+1 {{a '@differentiable' function can only be formed from a reference to a 'func' or a literal closure}}
_ = gradient(of: globalAddOne)

func someScope() {
let localAddOne: (Float) -> Float = { $0 + 1 }
// expected-error @+1 {{a '@differentiable' function can only be formed from a reference to a 'func' or a literal closure}}
_ = gradient(of: globalAddOne)
// expected-error @+1 {{a '@differentiable' function can only be formed from a reference to a 'func' or a literal closure}}
_ = gradient(of: localAddOne)
// The following case is okay during type checking, but will fail in the AD transform.
_ = gradient { localAddOne($0) }
}

func addOne(x: Float) -> Float { x + 1 }
_ = gradient(of: addOne) // okay

extension Float {
static func addOne(x: Float) -> Float { x + 1 }
func addOne(x: Float) -> Float { x + 1 }
}
_ = gradient(of: Float.addOne) // okay
_ = gradient(of: Float(1.0).addOne) // okay

//===----------------------------------------------------------------------===//
// Parameter selection (@nondiff)
//===----------------------------------------------------------------------===//

// expected-error @+1 {{'nondiff' cannot be applied to arguments of a non-differentiable function}}
let _: (@nondiff Float, Float) -> Float
Expand Down Expand Up @@ -55,18 +92,11 @@ extension Vector: Differentiable where T: Differentiable {}
// expected-note @+1 {{where 'U' = 'Int'}}
func inferredConformancesGeneric<T, U>(_: @differentiable (Vector<T>) -> Vector<U>) {}

let nondiffVectorFunc: (Vector<Int>) -> Vector<Int>
func nondiffVectorFunc(x: Vector<Int>) -> Vector<Int> {}
// expected-error @+1 2 {{global function 'inferredConformancesGeneric' requires that 'Int' conform to 'Differentiable}}
inferredConformancesGeneric(nondiffVectorFunc)

let diffVectorFunc: (Vector<Float>) -> Vector<Float>
func diffVectorFunc(x: Vector<Float>) -> Vector<Float> {}
inferredConformancesGeneric(diffVectorFunc) // okay!

func inferredConformancesGenericResult<T, U>() -> @differentiable (Vector<T>) -> Vector<U> {}

//
// linear function type
//

// expected-error @+1 {{cannot mark types as linear differentiable}}
let _: @differentiable(linear) (Float) -> Float
8 changes: 4 additions & 4 deletions test/AutoDiff/method.swift
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ MethodTests.test("instance method with generated adjoint, called from differenta
MethodTests.test("instance method with generated adjoint, differentiated directly") {
// This is our current syntax for taking gradients of instance methods
// directly. If/when we develop nicer syntax for this, change this test.
let g = { (p: Parameter) in p.squared() }
func g(p: Parameter) -> Float { p.squared() }
expectEqual(Parameter(x: 4), gradient(at: Parameter(x: 2), in: g))
expectEqual(Parameter(x: 40), gradient(at: Parameter(x: 20), in: g))
}
Expand Down Expand Up @@ -136,7 +136,7 @@ MethodTests.test("static method with generated adjoint, wrt only second param")
}

MethodTests.test("static method with generated adjoint, wrt all params") {
let g = { (a: Parameter, b: Parameter) in a * b }
func g(a: Parameter, b: Parameter) -> Float { a * b }
expectEqual((Parameter(x: 100), Parameter(x: 200)),
gradient(at: Parameter(x: 200), Parameter(x: 100), in: g))
expectEqual((Parameter(x: 200), Parameter(x: 100)),
Expand Down Expand Up @@ -291,7 +291,7 @@ MethodTests.test("instance method with custom adjoint, called from differentated
MethodTests.test("instance method with generated adjoint, differentated directly") {
// This is our current syntax for taking gradients of instance methods
// directly. If/when we develop nicer syntax for this, change this test.
let g = { (p: CustomParameter) in p.squared() }
func g(p: CustomParameter) -> Float { p.squared() }
expectEqual(CustomParameter(x: 4), gradient(at: CustomParameter(x: 2), in: g))
expectEqual(CustomParameter(x: 10), gradient(at: CustomParameter(x: 20), in: g))
}
Expand Down Expand Up @@ -328,7 +328,7 @@ MethodTests.test("instance method with custom adjoint, wrt only non-self") {
}

MethodTests.test("instance method with custom adjoint, wrt self and non-self") {
let g = { (p: CustomParameter, o: Float) in p.multiplied(with: o) }
func g(p: CustomParameter, o: Float) -> Float { p.multiplied(with: o) }
expectEqual((CustomParameter(x: 5), 10), gradient(at: CustomParameter(x: 100), 5, in: g))
expectEqual((CustomParameter(x: 10), 5), gradient(at: CustomParameter(x: 5), 100, in: g))
}
Expand Down
34 changes: 17 additions & 17 deletions test/AutoDiff/SIMD.swift → test/AutoDiff/simd.swift
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ var SIMDTests = TestSuite("SIMD")
SIMDTests.test("init(repeating:)") {
let g = SIMD4<Float>(1, 1, 1, 1)

let foo1 = { (x: Float) -> SIMD4<Float> in
func foo1(x: Float) -> SIMD4<Float> {
return SIMD4<Float>(repeating: 2 * x)
}
let (val1, bp1) = valueWithPullback(at: 5, in: foo1)
Expand All @@ -24,7 +24,7 @@ SIMDTests.test("init(repeating:)") {
SIMDTests.test("Sum") {
let a = SIMD4<Float>(1, 2, 3, 4)

let foo1 = { (x: SIMD4<Float>) -> Float in
func foo1(x: SIMD4<Float>) -> Float {
return x.sum()
}
let (val1, bp1) = valueWithPullback(at: a, in: foo1)
Expand All @@ -36,7 +36,7 @@ SIMDTests.test("Identity") {
let a = SIMD4<Float>(1, 2, 3, 4)
let g = SIMD4<Float>(1, 1, 1, 1)

let foo1 = { (x: SIMD4<Float>) -> SIMD4<Float> in
func foo1(x: SIMD4<Float>) -> SIMD4<Float> {
return x
}
let (val1, bp1) = valueWithPullback(at: a, in: foo1)
Expand All @@ -48,7 +48,7 @@ SIMDTests.test("Negate") {
let a = SIMD4<Float>(1, 2, 3, 4)
let g = SIMD4<Float>(1, 1, 1, 1)

let foo1 = { (x: SIMD4<Float>) -> SIMD4<Float> in
func foo1(x: SIMD4<Float>) -> SIMD4<Float> {
return -x
}
let (val1, bp1) = valueWithPullback(at: a, in: foo1)
Expand All @@ -59,7 +59,7 @@ SIMDTests.test("Negate") {
SIMDTests.test("subscript") {
let a = SIMD4<Float>(1, 2, 3, 4)

let foo1 = { (x: SIMD4<Float>) -> Float in
func foo1(x: SIMD4<Float>) -> Float {
return x[3]
}

Expand All @@ -73,23 +73,23 @@ SIMDTests.test("Addition") {
let g = SIMD4<Float>(1, 1, 1, 1)

// SIMD + SIMD
let foo1 = { (x: SIMD4<Float>, y: SIMD4<Float>) -> SIMD4<Float> in
func foo1(x: SIMD4<Float>, y: SIMD4<Float>) -> SIMD4<Float> {
return x + y
}
let (val1, bp1) = valueWithPullback(at: a, a, in: foo1)
expectEqual(SIMD4<Float>(2, 4, 6, 8), val1)
expectEqual((g, g), bp1(g))

// SIMD + Scalar
let foo2 = { (x: SIMD4<Float>, y: Float) -> SIMD4<Float> in
func foo2(x: SIMD4<Float>, y: Float) -> SIMD4<Float> {
return x + y
}
let (val2, bp2) = valueWithPullback(at: a, 5, in: foo2)
expectEqual(SIMD4<Float>(6, 7, 8, 9), val2)
expectEqual((g, 4), bp2(g))

// Scalar + SIMD
let foo3 = { (x: SIMD4<Float>, y: Float) -> SIMD4<Float> in
func foo3(x: SIMD4<Float>, y: Float) -> SIMD4<Float> {
return y + x
}
let (val3, bp3) = valueWithPullback(at: a, 5, in: foo3)
Expand All @@ -102,23 +102,23 @@ SIMDTests.test("Subtraction") {
let g = SIMD4<Float>(1, 1, 1, 1)

// SIMD - SIMD
let foo1 = { (x: SIMD4<Float>, y: SIMD4<Float>) -> SIMD4<Float> in
func foo1(x: SIMD4<Float>, y: SIMD4<Float>) -> SIMD4<Float> {
return x - y
}
let (val1, bp1) = valueWithPullback(at: a, a, in: foo1)
expectEqual(SIMD4<Float>(0, 0, 0, 0), val1)
expectEqual((g, -g), bp1(g))

// SIMD - Scalar
let foo2 = { (x: SIMD4<Float>, y: Float) -> SIMD4<Float> in
func foo2(x: SIMD4<Float>, y: Float) -> SIMD4<Float> {
return x - y
}
let (val2, bp2) = valueWithPullback(at: a, 5, in: foo2)
expectEqual(SIMD4<Float>(-4, -3, -2, -1), val2)
expectEqual((g, -4), bp2(g))

// Scalar - SIMD
let foo3 = { (x: SIMD4<Float>, y: Float) -> SIMD4<Float> in
func foo3(x: SIMD4<Float>, y: Float) -> SIMD4<Float> {
return y - x
}
let (val3, bp3) = valueWithPullback(at: a, 5, in: foo3)
Expand All @@ -131,23 +131,23 @@ SIMDTests.test("Multiplication") {
let g = SIMD4<Float>(1, 1, 1, 1)

// SIMD * SIMD
let foo1 = { (x: SIMD4<Float>, y: SIMD4<Float>) -> SIMD4<Float> in
func foo1(x: SIMD4<Float>, y: SIMD4<Float>) -> SIMD4<Float> {
return x * y
}
let (val1, bp1) = valueWithPullback(at: a, a, in: foo1)
expectEqual(a * a, val1)
expectEqual((a, a), bp1(g))

// SIMD * Scalar
let foo2 = { (x: SIMD4<Float>, y: Float) -> SIMD4<Float> in
func foo2(x: SIMD4<Float>, y: Float) -> SIMD4<Float> {
return x * y
}
let (val2, bp2) = valueWithPullback(at: a, 5, in: foo2)
expectEqual(a * 5, val2)
expectEqual((SIMD4<Float>(5, 5, 5, 5), 10), bp2(g))

// Scalar * SIMD
let foo3 = { (x: SIMD4<Float>, y: Float) -> SIMD4<Float> in
func foo3(x: SIMD4<Float>, y: Float) -> SIMD4<Float> {
return y * x
}
let (val3, bp3) = valueWithPullback(at: a, 5, in: foo3)
Expand All @@ -160,7 +160,7 @@ SIMDTests.test("Division") {
let g = SIMD4<Float>(1, 1, 1, 1)

// SIMD / SIMD
let foo1 = { (x: SIMD4<Float>, y: SIMD4<Float>) -> SIMD4<Float> in
func foo1(x: SIMD4<Float>, y: SIMD4<Float>) -> SIMD4<Float> {
return x / y
}
let dlhs1 = g / a
Expand All @@ -170,7 +170,7 @@ SIMDTests.test("Division") {
expectEqual((dlhs1, drhs1), bp1(g))

// SIMD / Scalar
let foo2 = { (x: SIMD4<Float>, y: Float) -> SIMD4<Float> in
func foo2(x: SIMD4<Float>, y: Float) -> SIMD4<Float> {
return x / y
}
let dlhs2 = g / 5
Expand All @@ -180,7 +180,7 @@ SIMDTests.test("Division") {
expectEqual((dlhs2, drhs2), bp2(g))

// Scalar / SIMD
let foo3 = { (x: Float, y: SIMD4<Float>) -> SIMD4<Float> in
func foo3(x: Float, y: SIMD4<Float>) -> SIMD4<Float> {
return x / y
}
let dlhs3 = (g / a).sum()
Expand Down
Loading