
Commit c468e33

---
yaml --- r: 277503 b: refs/heads/tensorflow-merge c: f4998a4 h: refs/heads/master i: 277501: 20948e3 277499: 926a272 277495: c3cca24 277487: 43421a2 277471: a86c550 277439: 9d0423a 277375: f02aaf9 277247: a5367cb 276991: b955419 276479: 34835d6
1 parent 73833c6 commit c468e33

8 files changed: +153 additions, -76 deletions


[refs]

Lines changed: 1 addition & 1 deletion
@@ -1124,7 +1124,7 @@ refs/tags/swift-4.2-DEVELOPMENT-SNAPSHOT-2018-10-29-a: 1b087071edaea398480fb778e
 refs/tags/swift-4.2-DEVELOPMENT-SNAPSHOT-2018-10-30-a: 8bc9e108e1480d9217299984e428c601c7aaac75
 refs/tags/swift-4.2.1-RELEASE: 02a6ca969ea1387475b6caeb69c31186df7d30b6
 refs/tags/swift-DEVELOPMENT-SNAPSHOT-2018-11-01-a: 3b0299288f8287094b9ef587f46df54f42a347af
-refs/heads/tensorflow-merge: b6774826a3dbb3e24fe0cdba2350144a10904bfe
+refs/heads/tensorflow-merge: f4998a40fc05876bdbeb2911273024f36357dfed
 refs/heads/TensorFlowLite: b91446471276e37bbfe64767c875f3c7f7102954
 refs/heads/ad-side-effects: 19e0c0de1f59b0929c381925df2e8c72cdf4a728
 refs/heads/add-test-for-asan-compiler-crash: 3cdeecffb47bf28707b299fa2b5bdf0769a4a826

branches/tensorflow-merge/include/swift/AST/DiagnosticsSIL.def

Lines changed: 7 additions & 0 deletions
@@ -375,6 +375,13 @@ NOTE(autodiff_nondifferentiable_argument,none,
 NOTE(autodiff_nondifferentiable_result,none,
      "cannot differentiate through a non-differentiable result; do you want to "
      "add '.withoutDerivative()'?", ())
+NOTE(autodiff_noderivative_stored_property,none,
+     "cannot differentiate through a '@noDerivative' stored property; do you "
+     "want to add '.withoutDerivative()'?", ())
+WARNING(autodiff_nonvaried_result_fixit,none,
+        "result does not depend on differentiation arguments and will always "
+        "have a zero derivative; do you want to add '.withoutDerivative()'?",
+        ())
 NOTE(autodiff_global_let_closure_not_differentiable,none,
      "global constant closure is not differentiable", ())
 NOTE(autodiff_cannot_differentiate_global_var_closures,none,
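For context, a minimal Swift sketch of user code that triggers each new diagnostic; it mirrors the tests added in autodiff_diagnostics.swift below:

struct NoDerivativeProperty : Differentiable {
  var x: Float
  @noDerivative var y: Float
}

// Triggers the new note: the closure differentiates through the
// `@noDerivative` stored property `y`.
_ = gradient(at: NoDerivativeProperty(x: 1, y: 1)) { s -> Float in
  var tmp = s
  tmp.y = tmp.x  // cannot differentiate through a '@noDerivative' stored property
  return tmp.x
}

// Triggers the new warning: the result depends only on `y`, so it is not
// varied with respect to the differentiation arguments and always has a
// zero derivative.
_ = gradient(at: NoDerivativeProperty(x: 1, y: 1)) { s in s.y }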

branches/tensorflow-merge/lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 91 additions & 67 deletions
@@ -1331,21 +1331,28 @@ void DifferentiableActivityInfo::analyze(DominanceInfo *di,
       if (isVaried(cai->getSrc(), i))
         recursivelySetVaried(cai->getDest(), i);
     }
-    // Handle `struct_extract`.
-    else if (auto *sei = dyn_cast<StructExtractInst>(&inst)) {
-      if (isVaried(sei->getOperand(), i)) {
-        // If `@noDerivative` exists on the field while the struct is
-        // `@_fieldwiseDifferentiable`, this field is not in the set of
-        // differentiable variables that we want to track the variedness of.
-        auto hasNoDeriv = sei->getField()->getAttrs()
-            .hasAttribute<NoDerivativeAttr>();
-        auto structIsFieldwiseDiffable = sei->getStructDecl()->getAttrs()
-            .hasAttribute<FieldwiseDifferentiableAttr>();
-        if (!(hasNoDeriv && structIsFieldwiseDiffable))
-          for (auto result : inst.getResults())
-            setVaried(result, i);
-      }
-    }
+
+    // Handle `struct_extract` and `struct_element_addr` instructions.
+    // - If the field is marked `@noDerivative` and belongs to a
+    //   `@_fieldwiseDifferentiable` struct, do not set the result as varied,
+    //   because it is not in the set of differentiable variables.
+    // - Otherwise, propagate variedness from operand to result as usual.
+#define PROPAGATE_VARIED_FOR_STRUCT_EXTRACTION(INST) \
+    else if (auto *sei = dyn_cast<INST##Inst>(&inst)) { \
+      if (isVaried(sei->getOperand(), i)) { \
+        auto hasNoDeriv = sei->getField()->getAttrs() \
+            .hasAttribute<NoDerivativeAttr>(); \
+        auto structIsFieldwiseDiffable = sei->getStructDecl()->getAttrs() \
+            .hasAttribute<FieldwiseDifferentiableAttr>(); \
+        if (!(hasNoDeriv && structIsFieldwiseDiffable)) \
+          for (auto result : inst.getResults()) \
+            setVaried(result, i); \
+      } \
+    }
+    PROPAGATE_VARIED_FOR_STRUCT_EXTRACTION(StructExtract)
+    PROPAGATE_VARIED_FOR_STRUCT_EXTRACTION(StructElementAddr)
+#undef PROPAGATE_VARIED_FOR_STRUCT_EXTRACTION
+
     // Handle everything else.
     else {
       for (auto &op : inst.getAllOperands())
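The macro covers both field-projection instructions because SIL uses `struct_extract` to read a field out of a loadable struct value and `struct_element_addr` to form the address of a field, e.g. when mutating through a `var`. A rough Swift illustration; the lowering noted in the comments is an assumption and depends on the type and optimization level:

struct Wrapper : Differentiable {
  var value: Float
  @noDerivative var flag: Float
}

func read(_ w: Wrapper) -> Float {
  return w.flag   // field read of a loadable value: `struct_extract`
}

func mutate(_ w: Wrapper) -> Float {
  var tmp = w
  tmp.flag = 1    // mutation through memory: `struct_element_addr`
  return tmp.value
}

// In both functions, the projection of `flag` must not be marked varied,
// because the field is `@noDerivative` on a fieldwise-differentiable struct.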
@@ -3630,6 +3637,37 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
     assert(insertion.second); (void)insertion;
   }

+  SILValue getAdjointProjection(SILValue originalProjection) {
+    // Handle `struct_element_addr`.
+    if (auto *seai = dyn_cast<StructElementAddrInst>(originalProjection)) {
+      auto adjBase = getAdjointBuffer(seai->getOperand());
+      auto *cotangentVectorDecl =
+          adjBase.getType().getStructOrBoundGenericStruct();
+      auto cotanFieldLookup =
+          cotangentVectorDecl->lookupDirect(seai->getField()->getName());
+      assert(cotanFieldLookup.size() == 1);
+      auto *cotanField = cast<VarDecl>(cotanFieldLookup.front());
+      return builder.createStructElementAddr(
+          seai->getLoc(), adjBase.getValue(), cotanField);
+    }
+    // Handle `tuple_element_addr`.
+    if (auto *teai = dyn_cast<TupleElementAddrInst>(originalProjection)) {
+      auto adjBase = getAdjointBuffer(teai->getOperand());
+      return builder.createTupleElementAddr(
+          teai->getLoc(), adjBase.getValue(), teai->getFieldNo());
+    }
+    // Handle `begin_access`.
+    if (auto *bai = dyn_cast<BeginAccessInst>(originalProjection)) {
+      auto adjBase = getAdjointBuffer(bai->getOperand());
+      if (errorOccurred)
+        return (bufferMap[originalProjection] = ValueWithCleanup());
+      return builder.createBeginAccess(
+          bai->getLoc(), adjBase, bai->getAccessKind(), bai->getEnforcement(),
+          /*noNestedConflict*/ false, /*fromBuiltin*/ false);
+    }
+    return SILValue();
+  }
+
   ValueWithCleanup &getAdjointBuffer(SILValue originalBuffer) {
     assert(originalBuffer->getType().isAddress());
     assert(originalBuffer->getFunction() == &getOriginal());
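`getAdjointProjection` mirrors an original address projection as the corresponding projection into the adjoint buffer. For struct fields, the matching `CotangentVector` field is found by name via `lookupDirect`, which relies on the synthesized cotangent type reusing the stored-property names of the original struct. Conceptually, in Swift — the synthesized member shown is illustrative, not the literal generated code:

struct Point : Differentiable {
  var x: Float
  var y: Float
  // Synthesized by derived conformances, roughly:
  // struct CotangentVector : ... { var x: Float; var y: Float }
}

// An original projection of `point.x` (a `struct_element_addr` of field `x`)
// is mapped to a projection of field `x` in the `CotangentVector` buffer.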
@@ -3638,59 +3676,24 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
     if (!insertion.second) // not inserted
       return insertion.first->getSecond();

-    // Diagnose non-differentiable buffers.
-    if (!originalBuffer->getType().isDifferentiable(getModule())) {
-      getContext().emitNondifferentiabilityError(
-          originalBuffer, getDifferentiationTask());
-      errorOccurred = true;
-      return (bufferMap[originalBuffer] = ValueWithCleanup());
+    // Diagnose `struct_element_addr` instructions to `@noDerivative` fields.
+    if (auto *seai = dyn_cast<StructElementAddrInst>(originalBuffer)) {
+      if (seai->getField()->getAttrs().hasAttribute<NoDerivativeAttr>()) {
+        getContext().emitNondifferentiabilityError(
+            originalBuffer, getDifferentiationTask(),
+            diag::autodiff_noderivative_stored_property);
+        errorOccurred = true;
+        return (bufferMap[originalBuffer] = ValueWithCleanup());
+      }
     }

-    // Check whether the original buffer is an address-to-address projection.
-    // If so, recurse until the buffer is such a projection but its operand is
-    // not. Then, get the adjoint buffer of the operand and return a
-    // corresponding projection into it.
-    if (Projection::isAddressProjection(originalBuffer) &&
-        !Projection::isObjectToAddressProjection(originalBuffer)) {
-      // Get operand of the projection (i.e. the base memory).
-      auto *inst = cast<SingleValueInstruction>(originalBuffer);
-      Projection proj(inst);
-      auto loc = inst->getLoc();
-      auto base = inst->getOperand(0);
-      // Get the corresponding projection into the adjoint buffer.
-      SILValue adjProj;
-      auto adjBase = getAdjointBuffer(base);
-      if (proj.getKind() == ProjectionKind::Struct) {
-        auto *origField = proj.getVarDecl(base->getType());
-        auto *cotangentVectorDecl =
-            adjBase.getType().getStructOrBoundGenericStruct();
-        auto cotanFieldLookup =
-            cotangentVectorDecl->lookupDirect(origField->getName());
-        assert(cotanFieldLookup.size() == 1);
-        auto *cotanField = cast<VarDecl>(cotanFieldLookup.front());
-        adjProj = builder.createStructElementAddr(loc, adjBase.getValue(),
-                                                  cotanField);
-      } else {
-        adjProj = proj.createAddressProjection(builder, loc, adjBase.getValue())
-            .get();
-      }
-    }
+    // If the original buffer is a projection, return a corresponding projection
+    // into the adjoint buffer.
+    if (auto adjProj = getAdjointProjection(originalBuffer)) {
       ValueWithCleanup projWithCleanup(
-          adjProj, makeCleanupFromChildren({adjBase.getCleanup()}));
+          adjProj, makeCleanup(adjProj, /*cleanup*/ nullptr));
       return (bufferMap[originalBuffer] = projWithCleanup);
     }
-    // If the original buffer is a `begin_access` instruction, get the adjoint
-    // buffer of its operand and return a corresponding `begin_access` into it.
-    if (auto *bai = dyn_cast<BeginAccessInst>(originalBuffer)) {
-      auto adjBase = getAdjointBuffer(bai->getOperand());
-      if (errorOccurred)
-        return (bufferMap[originalBuffer] = ValueWithCleanup());
-      auto *adjAccess = builder.createBeginAccess(
-          bai->getLoc(), adjBase, bai->getAccessKind(), bai->getEnforcement(),
-          /*noNestedConflict*/ false, /*fromBuiltin*/ false);
-      ValueWithCleanup accessWithCleanup(
-          adjAccess, makeCleanupFromChildren({adjBase.getCleanup()}));
-      return (bufferMap[originalBuffer] = accessWithCleanup);
-    }

     // Set insertion point for local allocation builder: before the last local
     // allocation, or at the start of the adjoint entry BB if no local
@@ -3803,6 +3806,17 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
     SmallVector<SILValue, 8> origFormalResults;
     collectAllFormalResultsInTypeOrder(original, origFormalResults);
     auto origResult = origFormalResults[task->getIndices().source];
+    // Emit a warning if the original result is not varied, because it will
+    // always have a zero derivative.
+    if (!activityInfo.isVaried(origResult, task->getIndices().source)) {
+      // Emit a fix-it if the original result has a valid source location.
+      auto sourceLoc = origResult.getLoc().getSourceLoc();
+      if (sourceLoc.isValid()) {
+        getContext()
+            .diagnose(sourceLoc, diag::autodiff_nonvaried_result_fixit)
+            .fixItInsertAfter(sourceLoc, ".withoutDerivative()");
+      }
+    }

     builder.setInsertionPoint(adjointEntry);
     if (seed->getType().isAddress()) {
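The fix-it inserts `.withoutDerivative()` immediately after the result expression, making the zero derivative explicit in source. Applied to one of the new test cases below, the warning-free form would be:

_ = gradient(at: NoDerivativeProperty(x: 1, y: 1)) { s in
  return s.y.withoutDerivative()  // zero derivative, now explicit
}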
@@ -4229,6 +4243,9 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
   }

   void visitStructExtractInst(StructExtractInst *sei) {
+    assert(!sei->getField()->getAttrs().hasAttribute<NoDerivativeAttr>() &&
+           "`struct_extract` with `@noDerivative` field should not be "
+           "differentiated; activity analysis should not have marked it varied");
     auto loc = sei->getLoc();
     auto &differentiationStrategies =
         getDifferentiationTask()->getStructExtractDifferentiationStrategies();
@@ -4562,6 +4579,17 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
         ValueWithCleanup(adjAccess, makeCleanupFromChildren({})));
   }

+#define PROPAGATE_BUFFER_CLEANUP(INST) \
+  void visit##INST##Inst(INST##Inst *inst) { \
+    auto &adjBase = getAdjointBuffer(inst->getOperand()); \
+    auto &adjProj = getAdjointBuffer(inst); \
+    adjProj.setCleanup(makeCleanupFromChildren( \
+        {adjProj.getCleanup(), adjBase.getCleanup()})); \
+  }
+  PROPAGATE_BUFFER_CLEANUP(StructElementAddr)
+  PROPAGATE_BUFFER_CLEANUP(TupleElementAddr)
+#undef PROPAGATE_BUFFER_CLEANUP
+
 #define NOT_DIFFERENTIABLE(INST, DIAG) \
   void visit##INST##Inst(INST##Inst *inst) { \
     getContext().emitNondifferentiabilityError( \
@@ -4587,10 +4615,6 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
   NO_ADJOINT(StrongRetainUnowned)
   NO_ADJOINT(DestroyValue)
   NO_ADJOINT(DestroyAddr)
-  // Projection operations have no adjoint visitor.
-  // Corresponding adjoint projections are created in `getAdjointBuffer`.
-  NO_ADJOINT(StructElementAddr)
-  NO_ADJOINT(TupleElementAddr)
 #undef NO_DERIVATIVE
 };
 } // end anonymous namespace

branches/tensorflow-merge/lib/Sema/DerivedConformanceDifferentiable.cpp

Lines changed: 1 addition & 1 deletion
@@ -872,7 +872,7 @@ static void checkAndDiagnoseImplicitNoDerivative(TypeChecker &TC,
     if (conformsToDifferentiable && !isConstantProperty)
       continue;
     // Otherwise, add an implicit `@noDerivative` attribute.
-    nominal->getAttrs().add(
+    vd->getAttrs().add(
         new (TC.Context) NoDerivativeAttr(/*Implicit*/ true));
     auto loc =
         vd->getLoc().isValid() ? vd->getLoc() : DC->getAsDecl()->getLoc();
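This is a one-line bug fix: the implicit `@noDerivative` attribute was previously attached to the enclosing nominal type's attribute list instead of the stored property (`vd`) being diagnosed. A hypothetical example of a property that receives the implicit attribute — tuples such as `(Int, Int)` do not conform to Differentiable:

struct MyLayer : Differentiable {
  var weight: Float
  let strides: (Int, Int)  // the type checker adds an implicit `@noDerivative`
                           // here, now correctly on this VarDecl rather than
                           // on `MyLayer` itself
}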

branches/tensorflow-merge/test/AutoDiff/autodiff_diagnostics.swift

Lines changed: 21 additions & 3 deletions
@@ -43,6 +43,26 @@ extension S : Differentiable, VectorNumeric {
 // expected-note @+1 {{property is not differentiable}}
 _ = gradient(at: S(p: 0)) { s in 2 * s.p }

+struct NoDerivativeProperty : Differentiable {
+  var x: Float
+  @noDerivative var y: Float
+}
+// expected-error @+1 {{function is not differentiable}}
+_ = gradient(at: NoDerivativeProperty(x: 1, y: 1)) { s -> Float in
+  var tmp = s
+  // expected-note @+1 {{cannot differentiate through a '@noDerivative' stored property; do you want to add '.withoutDerivative()'?}}
+  tmp.y = tmp.x
+  return tmp.x
+}
+_ = gradient(at: NoDerivativeProperty(x: 1, y: 1)) { s in
+  // expected-warning @+1 {{result does not depend on differentiation arguments and will always have a zero derivative; do you want to add '.withoutDerivative()'?}} {{13-13=.withoutDerivative()}}
+  return s.y
+}
+_ = gradient(at: NoDerivativeProperty(x: 1, y: 1)) {
+  // expected-warning @+1 {{result does not depend on differentiation arguments and will always have a zero derivative; do you want to add '.withoutDerivative()'?}} {{7-7=.withoutDerivative()}}
+  $0.y
+}
+
 //===----------------------------------------------------------------------===//
 // Function composition
 //===----------------------------------------------------------------------===//
@@ -141,8 +161,6 @@ struct TF_305 : Differentiable {
   @noDerivative let activation: Activation
   @noDerivative let strides: (Int, Int)

-  // expected-error @+2 {{function is not differentiable}}
-  // expected-note @+2 {{when differentiating this function definition}}
   @differentiable
   public init(
     filter: Float,
@@ -153,7 +171,7 @@ struct TF_305 : Differentiable {
     self.filter = filter
     self.bias = bias
     self.activation = activation
-    self.strides = strides // expected-note {{expression is not differentiable}}
+    self.strides = strides
   }
 }

branches/tensorflow-merge/test/AutoDiff/generics.swift

Lines changed: 2 additions & 1 deletion
@@ -49,7 +49,8 @@ struct SupervisedTrainer<Model : Layer> {
   var model: Model
   var lossFunction: @differentiable (Model.Output, Model.Output) -> Float
   func fit(y: Model.Output) {
-    _ = gradient(at: Float(1)) { _ in return lossFunction(y, y) }
+    // expected-warning @+1 {{result does not depend on differentiation arguments and will always have a zero derivative; do you want to add '.withoutDerivative()'?}} {{58-58=.withoutDerivative()}}
+    _ = gradient(at: Float(1)) { _ in return lossFunction(y, y) }
   }
 }

branches/tensorflow-merge/test/AutoDiff/simple_math.swift

Lines changed: 25 additions & 0 deletions
@@ -183,6 +183,31 @@ SimpleMathTests.test("StructMemberwiseInitializer") {
   expectEqual(2, 𝛁foo)
 }

+// Tests TF-319: struct with non-differentiable constant stored property.
+SimpleMathTests.test("StructConstantStoredProperty") {
+  struct TF_319 : Differentiable {
+    var x: Float
+    @noDerivative let constant = Float(2)
+
+    @differentiable
+    init(x: Float) {
+      self.x = x
+    }
+
+    @differentiable(wrt: (self, input))
+    func applied(to input: Float) -> Float {
+      return x * constant * input
+    }
+  }
+  func testStructInit(to input: Float) -> Float {
+    let model = TF_319(x: 10)
+    return model.applied(to: input)
+  }
+  expectEqual(TF_319.CotangentVector(x: 6),
+              gradient(at: TF_319(x: 10), in: { $0.applied(to: 3) }))
+  expectEqual(20, gradient(at: 3, in: testStructInit))
+}
+
 SimpleMathTests.test("StructSideEffects") {
   struct Point : AdditiveArithmetic, Differentiable {
     var x: Float
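As a sanity check on the expected values in this test: applied(to:) computes f(x, input) = x * constant * input with constant = 2, so at x = 10 and input = 3:

  ∂f/∂x     = constant * input = 2 * 3  = 6   // CotangentVector(x: 6)
  ∂f/∂input = x * constant     = 10 * 2 = 20  // gradient(at: 3, in: testStructInit)

The `@noDerivative` constant contributes no cotangent component, so the struct's gradient contains only `x`.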

branches/tensorflow-merge/test/TensorFlowRuntime/model_autodiff_runtime.swift

Lines changed: 5 additions & 3 deletions
@@ -60,7 +60,6 @@ public struct Conv2D<Scalar: TensorFlowFloatingPoint>: Layer {
   @noDerivative public let strides: (Int32, Int32)
   @noDerivative public let padding: Padding

-  // TODO(TF-309): Add `@differentiable` initializer using assignments, when supported.
   @differentiable
   public init(
     filter: Tensor<Scalar>,
@@ -69,8 +68,11 @@ public struct Conv2D<Scalar: TensorFlowFloatingPoint>: Layer {
     strides: (Int, Int),
     padding: Padding
   ) {
-    self.init(filter: filter, bias: bias, activation: activation,
-              strides: (Int32(strides.0), Int32(strides.1)), padding: padding)
+    self.filter = filter
+    self.bias = bias
+    self.activation = activation
+    self.strides = (Int32(strides.0), Int32(strides.1))
+    self.padding = padding
   }

   @differentiable
