
[AutoDiff] Fix AdjointEmitter::getAdjointBuffer, clean up tests. #24793

Merged (1 commit, May 15, 2019)
2 changes: 1 addition & 1 deletion include/swift/AST/Attr.h
@@ -1587,7 +1587,7 @@ class DifferentiableAttr final

// Print the attribute to the given stream.
void print(llvm::raw_ostream &OS, const Decl *D,
ModuleDecl *prettyPrintInModule) const;
ModuleDecl *prettyPrintInModule = nullptr) const;

static bool classof(const DeclAttribute *DA) {
return DA->getKind() == DAK_Differentiable;
8 changes: 7 additions & 1 deletion lib/SILOptimizer/Mandatory/Differentiation.cpp
@@ -3050,7 +3050,8 @@ class PrimalGenCloner final
context.getTypeConverter(), canGenSig);
auto loweredPullbackType =
getOpType(context.getTypeConverter().getLoweredType(
pullbackDecl->getInterfaceType()->getCanonicalType(), ResilienceExpansion::Minimal))
pullbackDecl->getInterfaceType()->getCanonicalType(),
ResilienceExpansion::Minimal))
.castTo<SILFunctionType>();
if (!loweredPullbackType->isEqual(actualPullbackType)) {
// Set non-reabstracted original pullback type in nested apply info.
@@ -3790,7 +3791,12 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
newBuf->getLoc(), newBuf, SILAccessKind::Init,
SILAccessEnforcement::Static, /*noNestedConflict*/ true,
/*fromBuiltin*/ false);
// Temporarily change global builder insertion point and emit zero into the
// local buffer.
auto insertionPoint = builder.getInsertionBB();
builder.setInsertionPoint(localAllocBuilder.getInsertionPoint());
emitZeroIndirect(access->getType().getASTType(), access, access->getLoc());
builder.setInsertionPoint(insertionPoint);
localAllocBuilder.createEndAccess(
access->getLoc(), access, /*aborted*/ false);
// Create cleanup for local buffer.
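For context, a minimal sketch of the kind of user code that exercises this path, mirroring the `identity` test added to test/AutoDiff/generics.swift below (assuming the Swift for TensorFlow toolchain, where `gradient(at:in:)` and `Differentiable` are available without imports): differentiating a generic function forces the adjoint to accumulate into an address-only `T.CotangentVector` buffer, and that local buffer must be zeroed right where it is allocated, which is what the insertion-point switch above arranges.

// The adjoint of a generic function works with indirect (address-only)
// cotangent buffers; each freshly allocated local buffer must start at zero.
func identity<T : Differentiable>(_ x: T) -> T {
  return x
}
_ = gradient(at: Float(1)) { x in identity(x) }
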
43 changes: 22 additions & 21 deletions test/AutoDiff/autodiff_diagnostics.swift
@@ -4,8 +4,8 @@
// Top-level (before primal/adjoint synthesis)
//===----------------------------------------------------------------------===//

// expected-note @+2 {{opaque non-'@differentiable' function is not differentiable}}
// expected-error @+1 {{expression is not differentiable}}
// expected-error @+2 {{expression is not differentiable}}
// expected-note @+1 {{opaque non-'@differentiable' function is not differentiable}}
func foo(_ f: (Float) -> Float) -> Float {
return gradient(at: 0, in: f)
}
@@ -67,42 +67,44 @@ _ = gradient(at: NoDerivativeProperty(x: 1, y: 1)) {
// Function composition
//===----------------------------------------------------------------------===//

// FIXME: Figure out why diagnostics no longer accumulate after we removed
// gradient synthesis. When it's fixed, replace "xpected" with "expected" below.
#if false

func uses_optionals(_ x: Float) -> Float {
var maybe: Float? = 10
maybe = x
// xpected-note @+1 {{differentiating control flow is not yet supported}}
// expected-note @+1 {{differentiating control flow is not yet supported}}
return maybe!
}

_ = gradient(at: 0, in: uses_optionals) // xpected-error {{function is not differentiable}}
// expected-error @+1 {{function is not differentiable}}
_ = gradient(at: 0, in: uses_optionals)

func f0(_ x: Float) -> Float {
return x // okay!
func base(_ x: Float) -> Float {
// expected-note @+1 {{cannot differentiate through a non-differentiable result; do you want to add '.withoutDerivative()'?}}
return Float(Int(x))
}

func nested(_ x: Float) -> Float {
return gradient(at: x, in: f0) // xpected-note {{nested differentiation is not supported yet}}
// expected-note @+1 {{when differentiating this function call}}
return base(x)
}

func middle(_ x: Float) -> Float {
let y = uses_optionals(x)
return nested(y) // xpected-note {{when differentiating this function call}}
// expected-note @+1 {{when differentiating this function call}}
return nested(x)
}

func middle2(_ x: Float) -> Float {
return middle(x) // xpected-note {{when differentiating this function call}}
// expected-note @+1 {{when differentiating this function call}}
return middle(x)
}

func func_to_diff(_ x: Float) -> Float {
return middle2(x) // xpected-note {{expression is not differentiable}}
// expected-note @+1 {{expression is not differentiable}}
return middle2(x)
}

func calls_grad_of_nested(_ x: Float) -> Float {
return gradient(at: x, in: func_to_diff) // xpected-error {{function is not differentiable}}
// expected-error @+1 {{function is not differentiable}}
return gradient(at: x, in: func_to_diff)
}

//===----------------------------------------------------------------------===//
@@ -111,7 +113,7 @@ func calls_grad_of_nested(_ x: Float) -> Float {

func if_else(_ x: Float, _ flag: Bool) -> Float {
let y: Float
// xpected-note @+1 {{differentiating control flow is not supported yet}}
// expected-note @+1 {{differentiating control flow is not yet supported}}
if flag {
y = x + 1
} else {
@@ -120,10 +122,9 @@ func if_else(_ x: Float, _ flag: Bool) -> Float {
return y
}

// xpected-error @+1 {{function is not differentiable}}
_ = gradient(at: 0) { x in if_else(0, true) }

#endif
// expected-error @+2 {{function is not differentiable}}
// expected-note @+1 {{expression is not differentiable}}
_ = gradient(at: 0) { x in if_else(x, true) }

//===----------------------------------------------------------------------===//
// @differentiable attributes
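The note above suggests `.withoutDerivative()` for results produced through non-differentiable operations such as `Float(Int(x))`. As an illustrative sketch (hypothetical `roundedBase`, not part of this test), applying that suggestion removes the diagnostic and makes the derivative through that path zero:

func roundedBase(_ x: Float) -> Float {
  // The Int round-trip is not differentiable; cutting it off explicitly
  // gives it a zero derivative instead of a diagnostic.
  return Float(Int(x)).withoutDerivative()
}
_ = gradient(at: 0, in: roundedBase)  // 0
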
32 changes: 30 additions & 2 deletions test/AutoDiff/autodiff_indirect_diagnostics.swift
@@ -4,6 +4,8 @@
// due to direct differentiation of reabstraction thunks, which emits errors
// with unknown location.

// Test unmet generic requirements.

// expected-error @+1 {{function is not differentiable}}
@differentiable
// expected-note @+1 {{when differentiating this function definition}}
@@ -13,11 +15,37 @@ func generic<T: Differentiable & FloatingPoint>(_ x: T) -> T {
}
_ = gradient(at: 1.0, in: generic) // expected-error {{function is not differentiable}}

// okay!
@differentiable(
vjp: vjpWeirdExtraRequirements
where T : Differentiable & CaseIterable, T.AllCases : ExpressibleByStringLiteral
)
func weird<T>(_ x: T) -> T {
return x
}
func vjpWeirdExtraRequirements<
T : Differentiable & CaseIterable
>(_ x: T) -> (T, (T.CotangentVector) -> T.CotangentVector)
where T.AllCases : ExpressibleByStringLiteral
{
return (x, { $0 })
}
func weirdWrapper<T : Differentiable>(_ x: T) -> T {
// expected-note @+1 {{function call is not differentiable because generic requirements are not met}}
return weird(x)
}
// expected-note @+2 {{expression is not differentiable}}
// expected-error @+1 {{function is not differentiable}}
_ = gradient(at: Float(1), in: { x in weirdWrapper(x) })

/*
// FIXME(TF-482): This currently crashes during differentiation transform
// because `T` is not constrained to `Differentiable` in the generated
// `[differentiable]` attribute.
@differentiable
func directMissingConformance<T : Differentiable>(_ x: T) -> T {
func directMissingConformance<T>(_ x: T) -> T {
return x
}
*/

@differentiable
func direct<T : Differentiable>(_ x: T) -> T {
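For comparison, an illustrative sketch (hypothetical names, not part of this test) of the same `@differentiable(vjp:)` registration pattern when no extra generic requirements are involved, so the call differentiates without diagnostics:

@differentiable(vjp: vjpDouble)
func double(_ x: Float) -> Float {
  return 2 * x
}
// The VJP returns the original result together with a pullback that
// scales the incoming cotangent by the same factor.
func vjpDouble(_ x: Float) -> (Float, (Float) -> Float) {
  return (2 * x, { v in 2 * v })
}
_ = gradient(at: 3, in: double)  // 2
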
42 changes: 0 additions & 42 deletions test/AutoDiff/generic_real_vector.swift

This file was deleted.

38 changes: 20 additions & 18 deletions test/AutoDiff/generics.swift
@@ -1,4 +1,22 @@
// RUN: %target-swift-frontend -emit-sil -verify %s
// RUN: %target-swift-frontend -emit-sil -verify %s | %FileCheck %s -check-prefix=CHECK-SIL

@_silgen_name("identity")
func identity<T : Differentiable>(_ x: T) -> T {
return x
}
_ = gradient(at: Float(1), in: { x in identity(x) })

// Test AdjointEmitter local buffer allocation.
// Verify that local buffers are immediately set to zero.

// CHECK-SIL-LABEL: sil hidden @AD__identity__adjoint_src_0_wrt_0
// CHECK-SIL: [[ORIG_COTAN:%.*]] = alloc_stack $τ_0_0.CotangentVector
// CHECK-SIL-NEXT: [[ORIG_COTAN_BEGIN:%.*]] = begin_access [init] [static] [no_nested_conflict] [[ORIG_COTAN]]
// CHECK-SIL-NEXT: [[ZERO_WITNESS:%.*]] = witness_method $τ_0_0.CotangentVector, #AdditiveArithmetic.zero!getter.1
// CHECK-SIL-NEXT: [[ORIG_COTAN_METATYPE:%.*]] = metatype $@thick τ_0_0.CotangentVector.Type
// CHECK-SIL-NEXT: [[EMIT_ZERO_INDIRECT:%.*]] = apply [[ZERO_WITNESS]]<τ_0_0.CotangentVector>([[ORIG_COTAN_BEGIN]], [[ORIG_COTAN_METATYPE]])
// CHECK-SIL-NEXT: end_access [[ORIG_COTAN_BEGIN]]
// CHECK-SIL: }

struct Tensor<Scalar : FloatingPoint & Differentiable> : VectorNumeric, Differentiable {
// NOTE: `value` must have type with known size (e.g. `Float`, not `Scalar`)
@@ -12,21 +30,6 @@ func generic<T : FloatingPoint & Differentiable>(_ x: Tensor<T>) -> Float {
}
_ = gradient(at: Tensor<Float>(1), in: generic)

// Test case where associated derivative function's requirements are unmet.

@differentiable(vjp: vjpWeirdExtraRequirements where T : CaseIterable, T.AllCases : ExpressibleByStringLiteral)
func weird<T : FloatingPoint & Differentiable>(_ x: Tensor<T>) -> Tensor<T> {
return x
}
func vjpWeirdExtraRequirements<T : FloatingPoint & Differentiable>(_ x: Tensor<T>) -> (Tensor<T>, (Tensor<T>) -> Tensor<T>) where T : CaseIterable, T.AllCases : ExpressibleByStringLiteral {
return (x, { $0 })
}
func weirdWrapper<T : FloatingPoint & Differentiable>(_ x: Tensor<T>) -> Tensor<T> {
return weird(x) // expected-note {{function call is not differentiable because generic requirements are not met}}
}
_ = pullback(at: Tensor<Float>(1), in: weirdWrapper) // expected-error {{function is not differentiable}}
_ = pullback(at: Tensor<Float>(3), in: weirdWrapper)

// Test case where associated derivative function's requirements are met.
extension Tensor where Scalar : Numeric {
@differentiable(wrt: self where Scalar : Differentiable & FloatingPoint)
@@ -49,8 +52,7 @@ struct SupervisedTrainer<Model : Layer> {
var model: Model
var lossFunction: @differentiable (Model.Output, Model.Output) -> Float
func fit(y: Model.Output) {
// expected-warning @+1 {{result does not depend on differentiation arguments and will always have a zero derivative; do you want to add '.withoutDerivative()'?}} {{64-64=.withoutDerivative()}}
_ = gradient(at: Float(1)) { _ in return lossFunction(y, y) }
_ = gradient(at: y) { y in return lossFunction(y, y) }
}
}

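The expectation removed above was the zero-derivative warning. As a small illustration (hypothetical `square`, not part of the test) of why `fit(y:)` now differentiates at `y` rather than at a constant: when the closure's result does not depend on the value being differentiated, the gradient is always zero.

func square(_ x: Float) -> Float {
  return x * x
}
// Result depends on the differentiation argument: gradient is 2 * 3 = 6.
_ = gradient(at: Float(3)) { x in square(x) }
// Result ignores the differentiation argument: gradient is always 0,
// which is what the removed warning pointed out.
_ = gradient(at: Float(3)) { _ in square(2) }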