Skip to content

Commit f163072

Browse files
committed
[AutoDiff] Add TangentStoredPropertyRequest.
Add request that resolves the "tangent stored property" corresponding to an original stored property in a `Differentiable`-conforming type. Enables better non-differentiability differentiation transform diagnostics.
1 parent 1052d3c commit f163072

File tree

7 files changed

+271
-22
lines changed

7 files changed

+271
-22
lines changed

include/swift/AST/ASTTypeIDZone.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ SWIFT_TYPEID(PropertyWrapperTypeInfo)
2727
SWIFT_TYPEID(Requirement)
2828
SWIFT_TYPEID(ResilienceExpansion)
2929
SWIFT_TYPEID(FragileFunctionKind)
30+
SWIFT_TYPEID(TangentPropertyInfo)
3031
SWIFT_TYPEID(Type)
3132
SWIFT_TYPEID(TypePair)
3233
SWIFT_TYPEID(TypeWitnessAndDecl)

include/swift/AST/ASTTypeIDs.h

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
#include "swift/Basic/LLVM.h"
2121
#include "swift/Basic/TypeID.h"
22+
2223
namespace swift {
2324

2425
class AbstractFunctionDecl;
@@ -58,14 +59,14 @@ class Requirement;
5859
enum class ResilienceExpansion : unsigned;
5960
struct FragileFunctionKind;
6061
class SourceFile;
62+
struct TangentPropertyInfo;
6163
class Type;
62-
class ValueDecl;
63-
class VarDecl;
64-
class Witness;
6564
class TypeAliasDecl;
66-
class Type;
6765
struct TypePair;
6866
struct TypeWitnessAndDecl;
67+
class ValueDecl;
68+
class VarDecl;
69+
class Witness;
6970
enum class AncestryFlags : uint8_t;
7071
enum class ImplicitMemberAction : uint8_t;
7172
struct FingerprintAndMembers;

include/swift/AST/AutoDiff.h

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ class AnyFunctionType;
3535
class SourceFile;
3636
class SILFunctionType;
3737
class TupleType;
38+
class VarDecl;
3839

3940
/// A function type differentiability kind.
4041
enum class DifferentiabilityKind : uint8_t {
@@ -459,6 +460,99 @@ class DerivativeFunctionTypeError
459460
}
460461
};
461462

463+
/// Describes the "tangent stored property" corresponding to an original stored
464+
/// property in a `Differentiable`-conforming type.
465+
///
466+
/// The tangent stored property is the stored property in the `TangentVector`
467+
/// struct of the `Differentiable`-conforming type, with the same name as the
468+
/// original stored property and with the original stored property's
469+
/// `TangentVector` type.
470+
struct TangentPropertyInfo {
471+
struct Error {
472+
enum class Kind {
473+
/// The original property is `@noDerivative`.
474+
NoDerivativeOriginalProperty,
475+
/// The nominal parent type does not conform to `Differentiable`.
476+
NominalParentNotDifferentiable,
477+
/// The original property's type does not conform to `Differentiable`.
478+
OriginalPropertyNotDifferentiable,
479+
/// The parent `TangentVector` type is not a struct.
480+
ParentTangentVectorNotStruct,
481+
/// The parent `TangentVector` struct does not declare a stored property
482+
/// with the same name as the original property.
483+
TangentPropertyNotFound,
484+
/// The tangent property's type is not equal to the original property's
485+
/// `TangentVector` type.
486+
TangentPropertyWrongType,
487+
/// The tangent property is not a stored property.
488+
TangentPropertyNotStored
489+
};
490+
491+
/// The error kind.
492+
Kind kind;
493+
494+
private:
495+
union Value {
496+
Type type;
497+
Value(Type type) : type(type) {}
498+
Value() {}
499+
} value;
500+
501+
public:
502+
Error(Kind kind) : kind(kind), value() {
503+
assert(kind == Kind::NoDerivativeOriginalProperty ||
504+
kind == Kind::NominalParentNotDifferentiable ||
505+
kind == Kind::OriginalPropertyNotDifferentiable ||
506+
kind == Kind::ParentTangentVectorNotStruct ||
507+
kind == Kind::TangentPropertyNotFound ||
508+
kind == Kind::TangentPropertyNotStored);
509+
};
510+
511+
Error(Kind kind, Type type) : kind(kind), value(type) {
512+
assert(kind == Kind::TangentPropertyWrongType);
513+
};
514+
515+
Type getType() const {
516+
assert(kind == Kind::TangentPropertyWrongType);
517+
return value.type;
518+
}
519+
520+
friend bool operator==(const Error &lhs, const Error &rhs);
521+
};
522+
523+
/// The tangent stored property.
524+
VarDecl *tangentProperty = nullptr;
525+
526+
/// An optional error.
527+
Optional<Error> error = None;
528+
529+
private:
530+
TangentPropertyInfo(VarDecl *tangentProperty, Optional<Error> error)
531+
: tangentProperty(tangentProperty), error(error) {}
532+
533+
public:
534+
TangentPropertyInfo(VarDecl *tangentProperty)
535+
: TangentPropertyInfo(tangentProperty, None) {}
536+
537+
TangentPropertyInfo(Error::Kind errorKind)
538+
: TangentPropertyInfo(nullptr, Error(errorKind)) {}
539+
540+
TangentPropertyInfo(Error::Kind errorKind, Type errorType)
541+
: TangentPropertyInfo(nullptr, Error(errorKind, errorType)) {}
542+
543+
/// Returns `true` iff this tangent property info is valid.
544+
bool isValid() const { return tangentProperty && !error; }
545+
546+
explicit operator bool() const { return isValid(); }
547+
548+
friend bool operator==(const TangentPropertyInfo &lhs,
549+
const TangentPropertyInfo &rhs) {
550+
return lhs.tangentProperty == rhs.tangentProperty && lhs.error == rhs.error;
551+
}
552+
};
553+
554+
void simple_display(llvm::raw_ostream &OS, TangentPropertyInfo info);
555+
462556
/// The key type used for uniquing `SILDifferentiabilityWitness` in
463557
/// `SILModule`: original function name, parameter indices, result indices, and
464558
/// derivative generic signature.

include/swift/AST/TypeCheckRequests.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2191,6 +2191,27 @@ class DerivativeAttrOriginalDeclRequest
21912191
bool isCached() const { return true; }
21922192
};
21932193

2194+
/// Resolves the "tangent stored property" corresponding to an original stored
2195+
/// property in a `Differentiable`-conforming type.
2196+
class TangentStoredPropertyRequest
2197+
: public SimpleRequest<TangentStoredPropertyRequest,
2198+
TangentPropertyInfo(VarDecl *),
2199+
RequestFlags::Cached> {
2200+
public:
2201+
using SimpleRequest::SimpleRequest;
2202+
2203+
private:
2204+
friend SimpleRequest;
2205+
2206+
// Evaluation.
2207+
TangentPropertyInfo evaluate(Evaluator &evaluator,
2208+
VarDecl *originalField) const;
2209+
2210+
public:
2211+
// Caching.
2212+
bool isCached() const { return true; }
2213+
};
2214+
21942215
/// Checks whether a type eraser has a viable initializer.
21952216
class TypeEraserHasViableInitRequest
21962217
: public SimpleRequest<TypeEraserHasViableInitRequest,

include/swift/AST/TypeCheckerTypeIDZone.def

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,8 @@ SWIFT_REQUEST(TypeChecker, SuperclassTypeRequest,
203203
SWIFT_REQUEST(TypeChecker, SynthesizeAccessorRequest,
204204
AccessorDecl *(AbstractStorageDecl *, AccessorKind),
205205
SeparatelyCached, NoLocationInfo)
206+
SWIFT_REQUEST(TypeChecker, TangentStoredPropertyRequest,
207+
llvm::Expected<VarDecl *>(VarDecl *), Cached, NoLocationInfo)
206208
SWIFT_REQUEST(TypeChecker, TypeCheckFunctionBodyRequest,
207209
bool(AbstractFunctionDecl *), Cached, NoLocationInfo)
208210
SWIFT_REQUEST(TypeChecker, TypeCheckFunctionBodyAtLocRequest,

lib/AST/AutoDiff.cpp

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -421,3 +421,128 @@ void DerivativeFunctionTypeError::log(raw_ostream &OS) const {
421421
}
422422
}
423423
}
424+
425+
bool swift::operator==(const TangentPropertyInfo::Error &lhs,
426+
const TangentPropertyInfo::Error &rhs) {
427+
if (lhs.kind != rhs.kind)
428+
return false;
429+
switch (lhs.kind) {
430+
case TangentPropertyInfo::Error::Kind::NoDerivativeOriginalProperty:
431+
case TangentPropertyInfo::Error::Kind::NominalParentNotDifferentiable:
432+
case TangentPropertyInfo::Error::Kind::OriginalPropertyNotDifferentiable:
433+
case TangentPropertyInfo::Error::Kind::ParentTangentVectorNotStruct:
434+
case TangentPropertyInfo::Error::Kind::TangentPropertyNotFound:
435+
case TangentPropertyInfo::Error::Kind::TangentPropertyNotStored:
436+
return true;
437+
case TangentPropertyInfo::Error::Kind::TangentPropertyWrongType:
438+
return lhs.getType()->isEqual(rhs.getType());
439+
}
440+
}
441+
442+
void swift::simple_display(llvm::raw_ostream &os, TangentPropertyInfo info) {
443+
os << "{ ";
444+
os << "tangent property: "
445+
<< (info.tangentProperty ? info.tangentProperty->printRef() : "null");
446+
if (info.error) {
447+
os << ", error: ";
448+
switch (info.error->kind) {
449+
case TangentPropertyInfo::Error::Kind::NoDerivativeOriginalProperty:
450+
os << "'@noDerivative' original property has no tangent property";
451+
break;
452+
case TangentPropertyInfo::Error::Kind::NominalParentNotDifferentiable:
453+
os << "nominal parent does not conform to 'Differentiable'";
454+
break;
455+
case TangentPropertyInfo::Error::Kind::OriginalPropertyNotDifferentiable:
456+
os << "original property type does not conform to 'Differentiable'";
457+
break;
458+
case TangentPropertyInfo::Error::Kind::ParentTangentVectorNotStruct:
459+
os << "'TangentVector' type is not a struct";
460+
break;
461+
case TangentPropertyInfo::Error::Kind::TangentPropertyNotFound:
462+
os << "'TangentVector' struct does not have stored property with the "
463+
"same name as the original property";
464+
break;
465+
case TangentPropertyInfo::Error::Kind::TangentPropertyWrongType:
466+
os << "tangent property's type is not equal to the original property's "
467+
"'TangentVector' type";
468+
break;
469+
case TangentPropertyInfo::Error::Kind::TangentPropertyNotStored:
470+
os << "'TangentVector' property '" << info.tangentProperty->getName()
471+
<< "' is not a stored property";
472+
break;
473+
}
474+
}
475+
os << " }";
476+
}
477+
478+
TangentPropertyInfo
479+
TangentStoredPropertyRequest::evaluate(Evaluator &evaluator,
480+
VarDecl *originalField) const {
481+
assert(originalField->hasStorage() && originalField->isInstanceMember() &&
482+
"Expected stored property");
483+
auto *parentDC = originalField->getDeclContext();
484+
assert(parentDC->isTypeContext());
485+
auto parentType = parentDC->getDeclaredTypeInContext();
486+
auto *moduleDecl = originalField->getModuleContext();
487+
auto parentTan = parentType->getAutoDiffTangentSpace(
488+
LookUpConformanceInModule(moduleDecl));
489+
// Error if parent nominal type does not conform to `Differentiable`.
490+
if (!parentTan) {
491+
return TangentPropertyInfo(
492+
TangentPropertyInfo::Error::Kind::NominalParentNotDifferentiable);
493+
}
494+
// Error if original stored property is `@noDerivative`.
495+
if (originalField->getAttrs().hasAttribute<NoDerivativeAttr>()) {
496+
return TangentPropertyInfo(
497+
TangentPropertyInfo::Error::Kind::NoDerivativeOriginalProperty);
498+
}
499+
// Error if original property's type does not conform to `Differentiable`.
500+
auto originalFieldTan = originalField->getType()->getAutoDiffTangentSpace(
501+
LookUpConformanceInModule(moduleDecl));
502+
if (!originalFieldTan) {
503+
return TangentPropertyInfo(
504+
TangentPropertyInfo::Error::Kind::OriginalPropertyNotDifferentiable);
505+
}
506+
auto parentTanType = parentTan->getType();
507+
auto *parentTanStruct = parentTanType->getStructOrBoundGenericStruct();
508+
// Error if parent `TangentVector` is not a struct.
509+
if (!parentTanStruct) {
510+
return TangentPropertyInfo(
511+
TangentPropertyInfo::Error::Kind::ParentTangentVectorNotStruct);
512+
}
513+
// Find the corresponding field in the tangent space.
514+
VarDecl *tanField = nullptr;
515+
// If `TangentVector` is the original struct, then the tangent property is the
516+
// original property.
517+
if (parentTanStruct == parentDC->getSelfStructDecl()) {
518+
tanField = originalField;
519+
}
520+
// Otherwise, look up the field by name.
521+
else {
522+
auto tanFieldLookup =
523+
parentTanStruct->lookupDirect(originalField->getName());
524+
llvm::erase_if(tanFieldLookup,
525+
[](ValueDecl *v) { return !isa<VarDecl>(v); });
526+
// Error if tangent property could not be found.
527+
if (tanFieldLookup.empty()) {
528+
return TangentPropertyInfo(
529+
TangentPropertyInfo::Error::Kind::TangentPropertyNotFound);
530+
}
531+
tanField = cast<VarDecl>(tanFieldLookup.front());
532+
}
533+
// Error if tangent property's type is not equal to the original property's
534+
// `TangentVector` type.
535+
auto originalFieldTanType = originalFieldTan->getType();
536+
if (!originalFieldTanType->isEqual(tanField->getType())) {
537+
return TangentPropertyInfo(
538+
TangentPropertyInfo::Error::Kind::TangentPropertyWrongType,
539+
originalFieldTanType);
540+
}
541+
// Error if tangent property is not a stored property.
542+
if (!tanField->hasStorage()) {
543+
return TangentPropertyInfo(
544+
TangentPropertyInfo::Error::Kind::TangentPropertyNotStored);
545+
}
546+
// Otherwise, tangent property is valid.
547+
return TangentPropertyInfo(tanField);
548+
}

lib/Sema/DerivedConformanceDifferentiable.cpp

Lines changed: 23 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#include "swift/AST/PropertyWrappers.h"
2727
#include "swift/AST/ProtocolConformance.h"
2828
#include "swift/AST/Stmt.h"
29+
#include "swift/AST/TypeCheckRequests.h"
2930
#include "swift/AST/Types.h"
3031
#include "DerivedConformances.h"
3132

@@ -646,39 +647,43 @@ getOrSynthesizeTangentVectorStruct(DerivedConformance &derived, Identifier id) {
646647
structDecl->setImplicit();
647648
structDecl->copyFormalAccessFrom(nominal, /*sourceIsParentContext*/ true);
648649

649-
// Add members to `TangentVector` struct.
650+
// Add stored properties to the `TangentVector` struct.
650651
for (auto *member : diffProperties) {
651-
// Add this member's corresponding `TangentVector` type to the parent's
652-
// `TangentVector` struct.
653-
// Note: `newMember` is not marked as implicit here, because that
654-
// incorrectly affects memberwise initializer synthesis.
655-
auto *newMember = new (C) VarDecl(
652+
// Add a tangent stored property to the `TangentVector` struct, with the
653+
// name and `TangentVector` type of the original property.
654+
auto *tangentProperty = new (C) VarDecl(
656655
member->isStatic(), member->getIntroducer(), member->isCaptureList(),
657656
/*NameLoc*/ SourceLoc(), member->getName(), structDecl);
658-
657+
// Note: `tangentProperty` is not marked as implicit here, because that
658+
// incorrectly affects memberwise initializer synthesis.
659659
auto memberContextualType =
660660
parentDC->mapTypeIntoContext(member->getValueInterfaceType());
661661
auto memberTanType =
662662
getTangentVectorInterfaceType(memberContextualType, parentDC);
663-
newMember->setInterfaceType(memberTanType);
664-
Pattern *memberPattern = NamedPattern::createImplicit(C, newMember);
663+
tangentProperty->setInterfaceType(memberTanType);
664+
Pattern *memberPattern = NamedPattern::createImplicit(C, tangentProperty);
665665
memberPattern->setType(memberTanType);
666666
memberPattern =
667667
TypedPattern::createImplicit(C, memberPattern, memberTanType);
668668
memberPattern->setType(memberTanType);
669669
auto *memberBinding = PatternBindingDecl::createImplicit(
670670
C, StaticSpellingKind::None, memberPattern, /*initExpr*/ nullptr,
671671
structDecl);
672-
structDecl->addMember(newMember);
672+
structDecl->addMember(tangentProperty);
673673
structDecl->addMember(memberBinding);
674-
newMember->copyFormalAccessFrom(member, /*sourceIsParentContext*/ true);
675-
newMember->setSetterAccess(member->getFormalAccess());
676-
677-
// Now that this member is in the `TangentVector` type, it should be marked
678-
// `@differentiable` so that the differentiation transform will synthesize
679-
// derivative functions for it. We only add this to public stored
680-
// properties, because their access outside the module will go through a
681-
// call to the getter.
674+
tangentProperty->copyFormalAccessFrom(member,
675+
/*sourceIsParentContext*/ true);
676+
tangentProperty->setSetterAccess(member->getFormalAccess());
677+
678+
// Cache the tangent property.
679+
C.evaluator.cacheOutput(TangentStoredPropertyRequest{member},
680+
TangentPropertyInfo(tangentProperty));
681+
682+
// Now that the original property has a corresponding tangent property, it
683+
// should be marked `@differentiable` so that the differentiation transform
684+
// will synthesize derivative functions for its accessors. We only add this
685+
// to public stored properties, because their access outside the module will
686+
// go through accessor declarations.
682687
if (member->getEffectiveAccess() > AccessLevel::Internal &&
683688
!member->getAttrs().hasAttribute<DifferentiableAttr>()) {
684689
auto *getter = member->getSynthesizedAccessor(AccessorKind::Get);

0 commit comments

Comments
 (0)