Skip to content

Commit ad4ae1e

Browse files
authored
---
yaml --- r: 262141 b: refs/heads/tensorflow c: a123f9b h: refs/heads/master i: 262139: cd68cd1
1 parent 07c9c5b commit ad4ae1e

File tree

12 files changed

+118
-328
lines changed

12 files changed

+118
-328
lines changed

[refs]

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -818,7 +818,7 @@ refs/tags/swift-DEVELOPMENT-SNAPSHOT-2018-04-25-a: 22f738a831d43aff2b9c9773bcb65
818818
refs/tags/swift-DEVELOPMENT-SNAPSHOT-2018-05-08-a: 7d98cc16689baba5c8a3b90a9329bdcc1a12b4e9
819819
refs/heads/cherr42: a566ad54b073c2c56ac0a705d0a5bed9743135a5
820820
"refs/heads/codable_test_comment_fix": fc8f6824f7f347e1e8db55bff62db385c5728b5a
821-
refs/heads/tensorflow: b9a98185310f63bd530103e2ec6ea0f12501b572
821+
refs/heads/tensorflow: a123f9b8b242dc1455334e4f9e7ca34118953e97
822822
refs/tags/swift-4.1-DEVELOPMENT-SNAPSHOT-2018-05-11-a: 8126fd7a652e2f70ad6d76505239e34fb2ef3e1a
823823
refs/tags/swift-4.1-DEVELOPMENT-SNAPSHOT-2018-05-12-a: b3fd3dd84df6717f2e2e9df58c6d7e99fed57086
824824
refs/tags/swift-4.1-DEVELOPMENT-SNAPSHOT-2018-05-13-a: 71135119579039dc321c5f65d870050fe36efda2

branches/tensorflow/include/swift/AST/Attr.def

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -395,7 +395,9 @@ SIMPLE_DECL_ATTR(TensorFlowGraph, TensorFlowGraph,
395395
SIMPLE_DECL_ATTR(TFParameter, TFParameter,
396396
OnVar, 83)
397397
SIMPLE_DECL_ATTR(_fieldwiseProductSpace, FieldwiseProductSpace,
398-
OnTypeAlias | UserInaccessible, 84)
398+
OnTypeAlias | OnNominalType | UserInaccessible, 84)
399+
SIMPLE_DECL_ATTR(noDerivative, NoDerivative,
400+
OnVar, 85)
399401

400402
#undef TYPE_ATTR
401403
#undef DECL_ATTR_ALIAS

branches/tensorflow/include/swift/AST/DiagnosticsSema.def

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2735,6 +2735,11 @@ ERROR(tfparameter_attr_instance_stored_property_only,none,
27352735
ERROR(tfparameter_attr_not_in_parameterized,none,
27362736
"@%0 is allowed only in types that conform to 'Parameterized'", (StringRef))
27372737

2738+
// @noDerivative attribute
2739+
ERROR(noderivative_only_on_stored_properties_in_differentiable_structs,none,
2740+
"@noDerivative is only allowed on stored properties in structure types "
2741+
"that declare a conformance to 'Differentiable'", ())
2742+
27382743
//------------------------------------------------------------------------------
27392744
// MARK: Type Check Expressions
27402745
//------------------------------------------------------------------------------

branches/tensorflow/lib/SILOptimizer/Mandatory/TFDifferentiation.cpp

Lines changed: 15 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2290,12 +2290,15 @@ class PrimalGenCloner final : public SILClonerWithScopes<PrimalGenCloner> {
22902290

22912291
// Use the FieldwiseProductSpace strategy, if appropriate.
22922292
auto *structDecl = sei->getStructDecl();
2293-
auto aliasLookup = structDecl->lookupDirect(astCtx.Id_CotangentVector);
2294-
if (aliasLookup.size() >= 1) {
2295-
assert(aliasLookup.size() == 1);
2296-
assert(isa<TypeAliasDecl>(aliasLookup[0]));
2297-
auto *aliasDecl = cast<TypeAliasDecl>(aliasLookup[0]);
2298-
if (aliasDecl->getAttrs().hasAttribute<FieldwiseProductSpaceAttr>()) {
2293+
auto cotangentDeclLookup =
2294+
structDecl->lookupDirect(astCtx.Id_CotangentVector);
2295+
if (cotangentDeclLookup.size() >= 1) {
2296+
assert(cotangentDeclLookup.size() == 1);
2297+
auto cotangentTypeDecl = cotangentDeclLookup.front();
2298+
assert(isa<TypeAliasDecl>(cotangentTypeDecl) ||
2299+
isa<StructDecl>(cotangentTypeDecl));
2300+
if (cotangentTypeDecl->getAttrs()
2301+
.hasAttribute<FieldwiseProductSpaceAttr>()) {
22992302
structExtractDifferentiationStrategies.insert(
23002303
{sei, StructExtractDifferentiationStrategy::FieldwiseProductSpace});
23012304
SILClonerWithScopes::visitStructExtractInst(sei);
@@ -3666,8 +3669,6 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
36663669

36673670
void visitStructExtractInst(StructExtractInst *sei) {
36683671
auto loc = remapLocation(sei->getLoc());
3669-
auto &astCtx = getContext().getASTContext();
3670-
36713672
auto &differentiationStrategies =
36723673
getDifferentiationTask()->getStructExtractDifferentiationStrategies();
36733674
auto differentiationStrategyLookUp = differentiationStrategies.find(sei);
@@ -3689,14 +3690,11 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
36893690
// `key`.
36903691

36913692
// Find the decl of the cotangent space type.
3692-
auto *structDecl = sei->getStructDecl();
3693-
auto aliasLookup = structDecl->lookupDirect(astCtx.Id_CotangentVector);
3694-
assert(aliasLookup.size() == 1);
3695-
assert(isa<TypeAliasDecl>(aliasLookup[0]));
3696-
auto *aliasDecl = cast<TypeAliasDecl>(aliasLookup[0]);
3697-
assert(aliasDecl->getAttrs().hasAttribute<FieldwiseProductSpaceAttr>());
3698-
auto cotangentVectorTy =
3699-
aliasDecl->getUnderlyingTypeLoc().getType()->getCanonicalType();
3693+
auto structTy = sei->getOperand()->getType().getASTType();
3694+
auto cotangentVectorTy = structTy->getAutoDiffAssociatedVectorSpace(
3695+
AutoDiffAssociatedVectorSpaceKind::Cotangent,
3696+
LookUpConformanceInModule(getModule().getSwiftModule()))
3697+
->getType()->getCanonicalType();
37003698
assert(!getModule()
37013699
.Types.getTypeLowering(cotangentVectorTy)
37023700
.isAddressOnly());
@@ -3708,7 +3706,7 @@ class AdjointEmitter final : public SILInstructionVisitor<AdjointEmitter> {
37083706

37093707
// Find the corresponding field in the cotangent space.
37103708
VarDecl *correspondingField = nullptr;
3711-
if (cotangentVectorDecl == structDecl)
3709+
if (cotangentVectorDecl == sei->getStructDecl())
37123710
correspondingField = sei->getField();
37133711
else {
37143712
auto correspondingFieldLookup =

branches/tensorflow/lib/Sema/DerivedConformanceDifferentiable.cpp

Lines changed: 13 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -58,24 +58,14 @@ static ValueDecl *getProtocolRequirement(ProtocolDecl *proto, Identifier name) {
5858
}
5959

6060
// Get the stored properties of a nominal type that are relevant for
61-
// differentiation.
62-
// - If the nominal conforms to `Parameterized`, return only the stored
63-
// properties marked with `@TFParameter`.
64-
// - Otherwise, return all stored properties.
65-
static SmallVector<VarDecl *, 4>
66-
getStoredPropertiesForDifferentiation(NominalTypeDecl *nominal) {
67-
auto &C = nominal->getASTContext();
68-
auto *parameterizedProto = C.getProtocol(KnownProtocolKind::Parameterized);
69-
SmallVector<VarDecl *, 4> storedProperties;
70-
if (TypeChecker::conformsToProtocol(
71-
nominal->getDeclaredInterfaceType(), parameterizedProto,
72-
nominal->getDeclContext(), ConformanceCheckFlags::Used)) {
73-
nominal->getAllTFParameters(storedProperties);
74-
} else {
75-
storedProperties.append(nominal->getStoredProperties().begin(),
76-
nominal->getStoredProperties().end());
61+
// differentiation, except the ones tagged `@noDerivative`.
62+
static void getStoredPropertiesForDifferentiation(
63+
NominalTypeDecl *nominal, SmallVectorImpl<VarDecl *> &result) {
64+
for (auto *vd : nominal->getStoredProperties()) {
65+
if (vd->getAttrs().hasAttribute<NoDerivativeAttr>())
66+
continue;
67+
result.push_back(vd);
7768
}
78-
return storedProperties;
7969
}
8070

8171
bool DerivedConformance::canDeriveDifferentiable(NominalTypeDecl *nominal) {
@@ -122,17 +112,11 @@ bool DerivedConformance::canDeriveDifferentiable(NominalTypeDecl *nominal) {
122112

123113
// If there are no valid vector space types, `Self` must conform to either:
124114
// - `VectorNumeric`. Vector space types will be set to `Self`.
125-
// - `Parameterized`. Vector space types will be set to `Parameters` member
126-
// struct type.
127115
// TODO(dan-zheng): Lift this restriction.
128116
if (validTangentDeclCount == 0 || validCotangentDeclCount == 0) {
129117
auto *vectorNumericProto = C.getProtocol(KnownProtocolKind::VectorNumeric);
130-
auto *parameterizedProto = C.getProtocol(KnownProtocolKind::Parameterized);
131118
if (!TypeChecker::conformsToProtocol(
132119
nominal->getDeclaredInterfaceType(), vectorNumericProto,
133-
nominal->getDeclContext(), ConformanceCheckFlags::Used) &&
134-
!TypeChecker::conformsToProtocol(
135-
nominal->getDeclaredInterfaceType(), parameterizedProto,
136120
nominal->getDeclContext(), ConformanceCheckFlags::Used))
137121
return false;
138122
}
@@ -141,8 +125,9 @@ bool DerivedConformance::canDeriveDifferentiable(NominalTypeDecl *nominal) {
141125
// Currently, all stored properties must also have
142126
// `Self == TangentVector == CotangentVector`.
143127
// TODO(dan-zheng): Lift this restriction.
144-
return llvm::all_of(
145-
getStoredPropertiesForDifferentiation(structDecl), [&](VarDecl *v) {
128+
SmallVector<VarDecl *, 16> diffProperties;
129+
getStoredPropertiesForDifferentiation(structDecl, diffProperties);
130+
return llvm::all_of(diffProperties, [&](VarDecl *v) {
146131
if (!v->hasType())
147132
lazyResolver->resolveDeclSignature(v);
148133
if (!v->hasType())
@@ -273,21 +258,9 @@ static void deriveBodyDifferentiable_method(AbstractFunctionDecl *funcDecl,
273258
// Create call expression applying a member method to a parameter member.
274259
// Format: `<member>.method(<parameter>.<member>)`.
275260
// Example: `x.moved(along: direction.x)`.
276-
auto parameterizedProto = C.getProtocol(KnownProtocolKind::Parameterized);
277-
auto retNominalIsParameterized = TypeChecker::conformsToProtocol(
278-
retNominal->getDeclaredInterfaceType(), parameterizedProto,
279-
retNominal->getDeclContext(), ConformanceCheckFlags::Used);
280-
281261
auto createMemberMethodCallExpr = [&](VarDecl *member) -> Expr * {
282262
auto module = nominal->getModuleContext();
283263
auto confRef = module->lookupConformance(member->getType(), diffProto);
284-
// If the returned nominal is `Parameterized` and the member does not have
285-
// `@TFParameter`, create direct reference to member.
286-
if (retNominalIsParameterized &&
287-
!member->getAttrs().hasAttribute<TFParameterAttr>()) {
288-
return new (C) MemberRefExpr(selfDRE, SourceLoc(), member, DeclNameLoc(),
289-
/*Implicit*/ true);
290-
}
291264
assert(confRef && "Member does not conform to 'Differentiable'");
292265

293266
// Get member type's method, e.g. `Member.moved(along:)`.
@@ -463,12 +436,11 @@ deriveDifferentiable_VectorSpace(DerivedConformance &derived,
463436
auto nominal = derived.Nominal;
464437
auto &C = nominal->getASTContext();
465438

466-
// TODO: Check if nominal type conforms to `Parameterized` in addition to
467-
// `Differentiable`. If so, return the associated `Parameters` struct.
468-
469439
// Check if all members have vector space associated types equal to `Self`.
440+
SmallVector<VarDecl *, 16> diffProperties;
441+
getStoredPropertiesForDifferentiation(nominal, diffProperties);
470442
bool allMembersVectorSpaceEqualsSelf = llvm::all_of(
471-
getStoredPropertiesForDifferentiation(nominal), [&](VarDecl *member) {
443+
diffProperties, [&](VarDecl *member) {
472444
auto memberAssocType =
473445
nominal->mapTypeIntoContext(getVectorSpaceType(member, kind));
474446
return member->getType()->isEqual(memberAssocType);

branches/tensorflow/lib/Sema/TypeCheckAttr.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@ class AttributeEarlyChecker : public AttributeVisitor<AttributeEarlyChecker> {
127127
IGNORED_ATTR(TensorFlowGraph)
128128
IGNORED_ATTR(TFParameter)
129129
IGNORED_ATTR(FieldwiseProductSpace)
130+
IGNORED_ATTR(NoDerivative)
130131
#undef IGNORED_ATTR
131132

132133
// @noreturn has been replaced with a 'Never' return type.
@@ -886,6 +887,7 @@ class AttributeChecker : public AttributeVisitor<AttributeChecker> {
886887
void visitTensorFlowGraphAttr(TensorFlowGraphAttr *attr);
887888
void visitTFParameterAttr(TFParameterAttr *attr);
888889
void visitFieldwiseProductSpaceAttr(FieldwiseProductSpaceAttr *attr);
890+
void visitNoDerivativeAttr(NoDerivativeAttr *attr);
889891
};
890892
} // end anonymous namespace
891893

@@ -2720,6 +2722,27 @@ void AttributeChecker::visitFieldwiseProductSpaceAttr(
27202722
// here: the assertions in TFDifferentiation suffice.
27212723
}
27222724

2725+
void AttributeChecker::visitNoDerivativeAttr(NoDerivativeAttr *attr) {
2726+
auto *vd = dyn_cast<VarDecl>(D);
2727+
if (!vd) {
2728+
diagnoseAndRemoveAttr(attr,
2729+
diag::noderivative_only_on_stored_properties_in_differentiable_structs);
2730+
return;
2731+
}
2732+
auto *structDecl = dyn_cast<StructDecl>(vd->getDeclContext());
2733+
if (!structDecl) {
2734+
diagnoseAndRemoveAttr(attr,
2735+
diag::noderivative_only_on_stored_properties_in_differentiable_structs);
2736+
return;
2737+
}
2738+
auto *diffable = TC.Context.getProtocol(KnownProtocolKind::Differentiable);
2739+
if (!TC.conformsToProtocol(structDecl->getDeclaredInterfaceType(), diffable,
2740+
structDecl->getDeclContext(),
2741+
ConformanceCheckFlags::Used))
2742+
diagnoseAndRemoveAttr(attr,
2743+
diag::noderivative_only_on_stored_properties_in_differentiable_structs);
2744+
}
2745+
27232746
void TypeChecker::checkDeclAttributes(Decl *D) {
27242747
AttributeChecker Checker(*this, D);
27252748

branches/tensorflow/lib/Sema/TypeCheckDeclOverride.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1219,6 +1219,7 @@ namespace {
12191219
UNINTERESTING_ATTR(TensorFlowGraph)
12201220
UNINTERESTING_ATTR(TFParameter)
12211221
UNINTERESTING_ATTR(FieldwiseProductSpace)
1222+
UNINTERESTING_ATTR(NoDerivative)
12221223

12231224
// These can't appear on overridable declarations.
12241225
UNINTERESTING_ATTR(Prefix)
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
// RUN: %target-swift-frontend -typecheck -verify %s
2+
3+
// expected-error @+1 {{@noDerivative is only allowed on stored properties in structure types that declare a conformance to 'Differentiable'}}
4+
@noDerivative var flag: Bool
5+
6+
struct Foo {
7+
// expected-error @+1 {{@noDerivative is only allowed on stored properties in structure types that declare a conformance to 'Differentiable'}}
8+
@noDerivative var flag: Bool
9+
}
10+
11+
// expected-error @+1 {{type 'Bar' does not conform to protocol 'Differentiable'}}
12+
struct Bar : Differentiable {
13+
@noDerivative var flag: Bool
14+
}

branches/tensorflow/test/IDE/complete_decl_attribute.swift

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ struct S{}
9595
@#^KEYWORD_LAST^#
9696

9797
// SWIFT_ENABLE_TENSORFLOW
98-
// KEYWORD_LAST: Begin completions, 26 items
98+
// KEYWORD_LAST: Begin completions, 27 items
9999
// KEYWORD_LAST-NEXT: Keyword/None: available[#Declaration Attribute#]; name=available{{$}}
100100
// KEYWORD_LAST-NEXT: Keyword/None: objc[#Declaration Attribute#]; name=objc{{$}}
101101
// SWIFT_ENABLE_TENSORFLOW
@@ -124,4 +124,5 @@ struct S{}
124124
// KEYWORD_LAST-NEXT: Keyword/None: compilerEvaluable[#Declaration Attribute#]; name=compilerEvaluable
125125
// KEYWORD_LAST-NEXT: Keyword/None: TensorFlowGraph[#Declaration Attribute#]; name=TensorFlowGraph
126126
// KEYWORD_LAST-NEXT: Keyword/None: TFParameter[#Declaration Attribute#]; name=TFParameter
127+
// KEYWORD_LAST-NEXT: Keyword/None: noDerivative[#Declaration Attribute#]; name=noDerivative
127128
// KEYWORD_LAST-NEXT: End completions

0 commit comments

Comments
 (0)