
Commit 65d14d0

[AutoDiff] Fix AdjointEmitter::getAdjointBuffer, clean up tests. (#24793)
- Fix `emitZeroIndirect` for the local buffer in `AdjointEmitter::getAdjointBuffer`.
- Add a test to test/AutoDiff/generics.swift.
- Clean up tests.
- Re-enable parts of test/AutoDiff/autodiff_diagnostics.swift.
- Move the generic diagnostic test to test/AutoDiff/autodiff_indirect_diagnostics.swift.
- Gardening.
1 parent f64718d commit 65d14d0
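For context, the fixed code path is exercised whenever a generic (indirectly passed, address-only) value is differentiated. A minimal sketch, mirroring the new test in test/AutoDiff/generics.swift (assumes a Swift for TensorFlow toolchain of this era with `gradient(at:in:)`):

// Differentiating a generic function forces the adjoint emitter to
// allocate a local adjoint buffer for `T.CotangentVector`, which must be
// zero-initialized before any adjoint values are accumulated into it.
func identity<T : Differentiable>(_ x: T) -> T {
  return x
}
_ = gradient(at: Float(1), in: { x in identity(x) })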

File tree

6 files changed: +80 -85 lines changed


include/swift/AST/Attr.h

Lines changed: 1 addition & 1 deletion
@@ -1587,7 +1587,7 @@ class DifferentiableAttr final

   // Print the attribute to the given stream.
   void print(llvm::raw_ostream &OS, const Decl *D,
-             ModuleDecl *prettyPrintInModule) const;
+             ModuleDecl *prettyPrintInModule = nullptr) const;

   static bool classof(const DeclAttribute *DA) {
     return DA->getKind() == DAK_Differentiable;

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 7 additions & 1 deletion
@@ -3050,7 +3050,8 @@ class PrimalGenCloner final
         context.getTypeConverter(), canGenSig);
     auto loweredPullbackType =
         getOpType(context.getTypeConverter().getLoweredType(
-            pullbackDecl->getInterfaceType()->getCanonicalType(), ResilienceExpansion::Minimal))
+            pullbackDecl->getInterfaceType()->getCanonicalType(),
+            ResilienceExpansion::Minimal))
             .castTo<SILFunctionType>();
     if (!loweredPullbackType->isEqual(actualPullbackType)) {
       // Set non-reabstracted original pullback type in nested apply info.
@@ -3790,7 +3791,12 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
         newBuf->getLoc(), newBuf, SILAccessKind::Init,
         SILAccessEnforcement::Static, /*noNestedConflict*/ true,
         /*fromBuiltin*/ false);
+    // Temporarily change global builder insertion point and emit zero into the
+    // local buffer.
+    auto insertionPoint = builder.getInsertionBB();
+    builder.setInsertionPoint(localAllocBuilder.getInsertionPoint());
     emitZeroIndirect(access->getType().getASTType(), access, access->getLoc());
+    builder.setInsertionPoint(insertionPoint);
     localAllocBuilder.createEndAccess(
         access->getLoc(), access, /*aborted*/ false);
     // Create cleanup for local buffer.
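The key point of the hunk above: `emitZeroIndirect` emits code at the global `builder`'s current insertion point, which before this change could be far from the `begin_access [init]`, leaving the buffer uninitialized at its first use. Saving the insertion point, redirecting it to the local-allocation builder, and restoring it pins the zero store immediately after the access begins. The invariant being protected is the usual one for adjoint accumulation; an illustrative Swift analogy (hypothetical helper, not compiler code):

// Adjoint accumulation is only correct if the buffer starts at zero,
// which is what `emitZeroIndirect` guarantees for the local SIL buffer.
func accumulated<T : AdditiveArithmetic>(_ contributions: [T]) -> T {
  var adjoint = T.zero        // the "emit zero into the local buffer" step
  for c in contributions {    // later adjoint contributions
    adjoint += c              // garbage here if not zeroed first
  }
  return adjoint
}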

test/AutoDiff/autodiff_diagnostics.swift

Lines changed: 22 additions & 21 deletions
@@ -4,8 +4,8 @@
 // Top-level (before primal/adjoint synthesis)
 //===----------------------------------------------------------------------===//

-// expected-note @+2 {{opaque non-'@differentiable' function is not differentiable}}
-// expected-error @+1 {{expression is not differentiable}}
+// expected-error @+2 {{expression is not differentiable}}
+// expected-note @+1 {{opaque non-'@differentiable' function is not differentiable}}
 func foo(_ f: (Float) -> Float) -> Float {
   return gradient(at: 0, in: f)
 }
@@ -67,42 +67,44 @@ _ = gradient(at: NoDerivativeProperty(x: 1, y: 1)) {
 // Function composition
 //===----------------------------------------------------------------------===//

-// FIXME: Figure out why diagnostics no longer accumulate after we removed
-// gradient synthesis. When it's fixed, replace "xpected" with "expected" below.
-#if false
-
 func uses_optionals(_ x: Float) -> Float {
   var maybe: Float? = 10
   maybe = x
-  // xpected-note @+1 {{differentiating control flow is not yet supported}}
+  // expected-note @+1 {{differentiating control flow is not yet supported}}
   return maybe!
 }

-_ = gradient(at: 0, in: uses_optionals) // xpected-error {{function is not differentiable}}
+// expected-error @+1 {{function is not differentiable}}
+_ = gradient(at: 0, in: uses_optionals)

-func f0(_ x: Float) -> Float {
-  return x // okay!
+func base(_ x: Float) -> Float {
+  // expected-note @+1 {{cannot differentiate through a non-differentiable result; do you want to add '.withoutDerivative()'?}}
+  return Float(Int(x))
 }

 func nested(_ x: Float) -> Float {
-  return gradient(at: x, in: f0) // xpected-note {{nested differentiation is not supported yet}}
+  // expected-note @+1 {{when differentiating this function call}}
+  return base(x)
 }

 func middle(_ x: Float) -> Float {
-  let y = uses_optionals(x)
-  return nested(y) // xpected-note {{when differentiating this function call}}
+  // expected-note @+1 {{when differentiating this function call}}
+  return nested(x)
 }

 func middle2(_ x: Float) -> Float {
-  return middle(x) // xpected-note {{when differentiating this function call}}
+  // expected-note @+1 {{when differentiating this function call}}
+  return middle(x)
 }

 func func_to_diff(_ x: Float) -> Float {
-  return middle2(x) // xpected-note {{expression is not differentiable}}
+  // expected-note @+1 {{expression is not differentiable}}
+  return middle2(x)
 }

 func calls_grad_of_nested(_ x: Float) -> Float {
-  return gradient(at: x, in: func_to_diff) // xpected-error {{function is not differentiable}}
+  // expected-error @+1 {{function is not differentiable}}
+  return gradient(at: x, in: func_to_diff)
 }

 //===----------------------------------------------------------------------===//
@@ -111,7 +113,7 @@ func calls_grad_of_nested(_ x: Float) -> Float {

 func if_else(_ x: Float, _ flag: Bool) -> Float {
   let y: Float
-  // xpected-note @+1 {{differentiating control flow is not supported yet}}
+  // expected-note @+1 {{differentiating control flow is not yet supported}}
   if flag {
     y = x + 1
   } else {
@@ -120,10 +122,9 @@ func if_else(_ x: Float, _ flag: Bool) -> Float {
   return y
 }

-// xpected-error @+1 {{function is not differentiable}}
-_ = gradient(at: 0) { x in if_else(0, true) }
-
-#endif
+// expected-error @+2 {{function is not differentiable}}
+// expected-note @+1 {{expression is not differentiable}}
+_ = gradient(at: 0) { x in if_else(x, true) }

 //===----------------------------------------------------------------------===//
 // @differentiable attributes
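The `.withoutDerivative()` fix-it referenced in the notes above is applied like this; a sketch assuming this toolchain's `withoutDerivative()` API:

// Explicitly cut the derivative through the non-differentiable
// Float(Int(x)) round-trip so the rest of the expression still
// differentiates (with zero contribution from that term).
func baseFixed(_ x: Float) -> Float {
  return x + Float(Int(x)).withoutDerivative()
}
_ = gradient(at: 0, in: baseFixed)  // no diagnostic; derivative is 1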

test/AutoDiff/autodiff_indirect_diagnostics.swift

Lines changed: 30 additions & 2 deletions
@@ -4,6 +4,8 @@
 // due to direct differentiation of reabstraction thunks, which emits errors
 // with unknown location.

+// Test unmet generic requirements.
+
 // expected-error @+1 {{function is not differentiable}}
 @differentiable
 // expected-note @+1 {{when differentiating this function definition}}
@@ -13,11 +15,37 @@ func generic<T: Differentiable & FloatingPoint>(_ x: T) -> T {
 }
 _ = gradient(at: 1.0, in: generic) // expected-error {{function is not differentiable}}

-// okay!
+@differentiable(
+  vjp: vjpWeirdExtraRequirements
+  where T : Differentiable & CaseIterable, T.AllCases : ExpressibleByStringLiteral
+)
+func weird<T>(_ x: T) -> T {
+  return x
+}
+func vjpWeirdExtraRequirements<
+  T : Differentiable & CaseIterable
+>(_ x: T) -> (T, (T.CotangentVector) -> T.CotangentVector)
+  where T.AllCases : ExpressibleByStringLiteral
+{
+  return (x, { $0 })
+}
+func weirdWrapper<T : Differentiable>(_ x: T) -> T {
+  // expected-note @+1 {{function call is not differentiable because generic requirements are not met}}
+  return weird(x)
+}
+// expected-note @+2 {{expression is not differentiable}}
+// expected-error @+1 {{function is not differentiable}}
+_ = gradient(at: Float(1), in: { x in weirdWrapper(x) })
+
+/*
+// FIXME(TF-482): This currently crashes during differentiation transform
+// because `T` is not constrained to `Differentiable` in the generated
+// `[differentiable]` attribute.
 @differentiable
-func directMissingConformance<T : Differentiable>(_ x: T) -> T {
+func directMissingConformance<T>(_ x: T) -> T {
   return x
 }
+*/

 @differentiable
 func direct<T : Differentiable>(_ x: T) -> T {
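For contrast, a hypothetical counterpart (not part of the test) in which the wrapper propagates the VJP's extra requirements, so they are met at the call site and the "generic requirements are not met" note would not fire. A caller would still need a conforming type, since `Float` is not `CaseIterable`:

// Hypothetical sketch: constraining the wrapper's T to the derivative's
// full requirement set makes the call to `weird` differentiable.
func constrainedWrapper<T>(_ x: T) -> T
  where T : Differentiable & CaseIterable,
        T.AllCases : ExpressibleByStringLiteral {
  return weird(x)
}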

test/AutoDiff/generic_real_vector.swift

Lines changed: 0 additions & 42 deletions
This file was deleted.

test/AutoDiff/generics.swift

Lines changed: 20 additions & 18 deletions
@@ -1,4 +1,22 @@
-// RUN: %target-swift-frontend -emit-sil -verify %s
+// RUN: %target-swift-frontend -emit-sil -verify %s | %FileCheck %s -check-prefix=CHECK-SIL
+
+@_silgen_name("identity")
+func identity<T : Differentiable>(_ x: T) -> T {
+  return x
+}
+_ = gradient(at: Float(1), in: { x in identity(x) })
+
+// Test AdjointEmitter local buffer allocation.
+// Verify that local buffers are immediately set to zero.
+
+// CHECK-SIL-LABEL: sil hidden @AD__identity__adjoint_src_0_wrt_0
+// CHECK-SIL: [[ORIG_COTAN:%.*]] = alloc_stack $τ_0_0.CotangentVector
+// CHECK-SIL-NEXT: [[ORIG_COTAN_BEGIN:%.*]] = begin_access [init] [static] [no_nested_conflict] [[ORIG_COTAN]]
+// CHECK-SIL-NEXT: [[ZERO_WITNESS:%.*]] = witness_method $τ_0_0.CotangentVector, #AdditiveArithmetic.zero!getter.1
+// CHECK-SIL-NEXT: [[ORIG_COTAN_METATYPE:%.*]] = metatype $@thick τ_0_0.CotangentVector.Type
+// CHECK-SIL-NEXT: [[EMIT_ZERO_INDIRECT:%.*]] = apply [[ZERO_WITNESS]]<τ_0_0.CotangentVector>([[ORIG_COTAN_BEGIN]], [[ORIG_COTAN_METATYPE]])
+// CHECK-SIL-NEXT: end_access [[ORIG_COTAN_BEGIN]]
+// CHECK-SIL: }

 struct Tensor<Scalar : FloatingPoint & Differentiable> : VectorNumeric, Differentiable {
   // NOTE: `value` must have type with known size (e.g. `Float`, not `Scalar`)
@@ -12,21 +30,6 @@ func generic<T : FloatingPoint & Differentiable>(_ x: Tensor<T>) -> Float {
 }
 _ = gradient(at: Tensor<Float>(1), in: generic)

-// Test case where associated derivative function's requirements are unmet.
-
-@differentiable(vjp: vjpWeirdExtraRequirements where T : CaseIterable, T.AllCases : ExpressibleByStringLiteral)
-func weird<T : FloatingPoint & Differentiable>(_ x: Tensor<T>) -> Tensor<T> {
-  return x
-}
-func vjpWeirdExtraRequirements<T : FloatingPoint & Differentiable>(_ x: Tensor<T>) -> (Tensor<T>, (Tensor<T>) -> Tensor<T>) where T : CaseIterable, T.AllCases : ExpressibleByStringLiteral {
-  return (x, { $0 })
-}
-func weirdWrapper<T : FloatingPoint & Differentiable>(_ x: Tensor<T>) -> Tensor<T> {
-  return weird(x) // expected-note {{function call is not differentiable because generic requirements are not met}}
-}
-_ = pullback(at: Tensor<Float>(1), in: weirdWrapper) // expected-error {{function is not differentiable}}
-_ = pullback(at: Tensor<Float>(3), in: weirdWrapper)
-
 // Test case where associated derivative function's requirements are met.
 extension Tensor where Scalar : Numeric {
   @differentiable(wrt: self where Scalar : Differentiable & FloatingPoint)
@@ -49,8 +52,7 @@ struct SupervisedTrainer<Model : Layer> {
   var model: Model
   var lossFunction: @differentiable (Model.Output, Model.Output) -> Float
   func fit(y: Model.Output) {
-    // expected-warning @+1 {{result does not depend on differentiation arguments and will always have a zero derivative; do you want to add '.withoutDerivative()'?}} {{64-64=.withoutDerivative()}}
-    _ = gradient(at: Float(1)) { _ in return lossFunction(y, y) }
+    _ = gradient(at: y) { y in return lossFunction(y, y) }
   }
 }
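The last hunk replaces a gradient taken at an unused `Float(1)` with a gradient taken at `y` itself, so the result actually depends on the differentiation argument and the zero-derivative warning no longer fires. The same point in isolation, as a small sketch (`square` is a hypothetical stand-in):

// Differentiate with respect to a value the result depends on.
func square(_ x: Float) -> Float {
  return x * x
}
_ = gradient(at: Float(1)) { _ in square(3) }  // warns: result ignores the argument, derivative is always zero
_ = gradient(at: Float(3)) { x in square(x) }  // fine: derivative is 2 * x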
