Skip to content

[AutoDiff] [stdlib] Derive conformances for 'EuclideanDifferentiable'. #26867

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions include/swift/AST/DiagnosticsSema.def
Original file line number Diff line number Diff line change
Expand Up @@ -2576,6 +2576,8 @@ ERROR(broken_vector_protocol_requirement,none,
"VectorProtocol protocol is broken: unexpected requirement", ())
ERROR(broken_differentiable_requirement,none,
"Differentiable protocol is broken: unexpected requirement", ())
ERROR(broken_euclidean_differentiable_requirement,none,
"EuclideanDifferentiable protocol is broken: unexpected requirement", ())
ERROR(broken_key_path_iterable_requirement,none,
"KeyPathIterable protocol is broken: unexpected requirement", ())
ERROR(broken_tensor_array_protocol_requirement,none,
Expand Down
1 change: 1 addition & 0 deletions include/swift/AST/KnownIdentifiers.def
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ IDENTIFIER(x)
// Differentiable
IDENTIFIER(TangentVector)
IDENTIFIER(move)
IDENTIFIER(vectorView)

// Kinds of layout constraints
IDENTIFIER_WITH_NAME(UnknownLayout, "_UnknownLayout")
Expand Down
1 change: 1 addition & 0 deletions include/swift/AST/KnownProtocols.def
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ PROTOCOL_(TensorFlowDataTypeCompatible)
PROTOCOL(TensorProtocol)
PROTOCOL(VectorProtocol)
PROTOCOL(Differentiable)
PROTOCOL(EuclideanDifferentiable)

PROTOCOL_(ObjectiveCBridgeable)
PROTOCOL_(DestructorSafeContainer)
Expand Down
1 change: 1 addition & 0 deletions lib/IRGen/GenMeta.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4223,6 +4223,7 @@ SpecialProtocol irgen::getSpecialProtocolID(ProtocolDecl *P) {
case KnownProtocolKind::TensorProtocol:
case KnownProtocolKind::VectorProtocol:
case KnownProtocolKind::Differentiable:
case KnownProtocolKind::EuclideanDifferentiable:
return SpecialProtocol::None;
}

Expand Down
185 changes: 155 additions & 30 deletions lib/Sema/DerivedConformanceDifferentiable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@

using namespace swift;

// Return the protocol requirement with the specified name.
// TODO: Move function to shared place for use with other derived conformances.
/// Return the protocol requirement with the specified name.
/// TODO: Move function to shared place for use with other derived conformances.
static ValueDecl *getProtocolRequirement(ProtocolDecl *proto, Identifier name) {
auto lookup = proto->lookupDirect(name);
// Erase declarations that are not protocol requirements.
Expand All @@ -46,8 +46,8 @@ static ValueDecl *getProtocolRequirement(ProtocolDecl *proto, Identifier name) {
return lookup.front();
}

// Get the stored properties of a nominal type that are relevant for
// differentiation, except the ones tagged `@noDerivative`.
/// Get the stored properties of a nominal type that are relevant for
/// differentiation, except the ones tagged `@noDerivative`.
static void
getStoredPropertiesForDifferentiation(NominalTypeDecl *nominal,
DeclContext *DC,
Expand All @@ -71,8 +71,8 @@ getStoredPropertiesForDifferentiation(NominalTypeDecl *nominal,
}
}

// Convert the given `ValueDecl` to a `StructDecl` if it is a `StructDecl` or a
// `TypeDecl` with an underlying struct type. Otherwise, return `nullptr`.
/// Convert the given `ValueDecl` to a `StructDecl` if it is a `StructDecl` or a
/// `TypeDecl` with an underlying struct type. Otherwise, return `nullptr`.
static StructDecl *convertToStructDecl(ValueDecl *v) {
if (auto *structDecl = dyn_cast<StructDecl>(v))
return structDecl;
Expand All @@ -83,10 +83,10 @@ static StructDecl *convertToStructDecl(ValueDecl *v) {
typeDecl->getDeclaredInterfaceType()->getAnyNominal());
}

// Get the `Differentiable` protocol `TangentVector` associated type for the
// given `VarDecl`.
// TODO: Generalize and move function to shared place for use with other derived
// conformances.
/// Get the `Differentiable` protocol `TangentVector` associated type for the
/// given `VarDecl`.
/// TODO: Generalize and move function to shared place for use with other derived
/// conformances.
static Type getTangentVectorType(VarDecl *decl, DeclContext *DC) {
auto &C = decl->getASTContext();
auto *diffableProto = C.getProtocol(KnownProtocolKind::Differentiable);
Expand Down Expand Up @@ -196,7 +196,34 @@ bool DerivedConformance::canDeriveDifferentiable(NominalTypeDecl *nominal,
});
}

// Synthesize body for a `Differentiable` method requirement.
/// Determine if a EuclideanDifferentiable requirement can be derived for a type.
///
/// \returns True if the requirement can be derived.
bool DerivedConformance::canDeriveEuclideanDifferentiable(
NominalTypeDecl *nominal, DeclContext *DC) {
if (!canDeriveDifferentiable(nominal, DC))
return false;
auto &C = nominal->getASTContext();
auto *lazyResolver = C.getLazyResolver();
auto *addArithProto = C.getProtocol(KnownProtocolKind::AdditiveArithmetic);
// Return true if all differentiation stored properties conform to
// `AdditiveArithmetic` and their `TangentVector` equals themselves.
SmallVector<VarDecl *, 16> diffProperties;
getStoredPropertiesForDifferentiation(nominal, DC, diffProperties);
return llvm::all_of(diffProperties, [&](VarDecl *member) {
if (!member->hasInterfaceType())
lazyResolver->resolveDeclSignature(member);
if (!member->hasInterfaceType())
return false;
auto varType = DC->mapTypeIntoContext(member->getValueInterfaceType());
if (!TypeChecker::conformsToProtocol(varType, addArithProto, DC, None))
return false;
auto memberAssocType = getTangentVectorType(member, DC);
return member->getType()->isEqual(memberAssocType);
});
}

/// Synthesize body for a `Differentiable` method requirement.
static std::pair<BraceStmt *, bool>
deriveBodyDifferentiable_method(AbstractFunctionDecl *funcDecl,
Identifier methodName,
Expand Down Expand Up @@ -283,15 +310,15 @@ deriveBodyDifferentiable_method(AbstractFunctionDecl *funcDecl,
return std::pair<BraceStmt *, bool>(braceStmt, false);
}

// Synthesize body for `move(along:)`.
/// Synthesize body for `move(along:)`.
static std::pair<BraceStmt *, bool>
deriveBodyDifferentiable_move(AbstractFunctionDecl *funcDecl, void *) {
auto &C = funcDecl->getASTContext();
return deriveBodyDifferentiable_method(funcDecl, C.Id_move,
C.getIdentifier("along"));
}

// Synthesize function declaration for a `Differentiable` method requirement.
/// Synthesize function declaration for a `Differentiable` method requirement.
static ValueDecl *deriveDifferentiable_method(
DerivedConformance &derived, Identifier methodName, Identifier argumentName,
Identifier parameterName, Type parameterType, Type returnType,
Expand Down Expand Up @@ -329,7 +356,7 @@ static ValueDecl *deriveDifferentiable_method(
return funcDecl;
}

// Synthesize the `move(along:)` function declaration.
/// Synthesize the `move(along:)` function declaration.
static ValueDecl *deriveDifferentiable_move(DerivedConformance &derived) {
auto &C = derived.TC.Context;
auto *parentDC = derived.getConformanceContext();
Expand All @@ -343,8 +370,92 @@ static ValueDecl *deriveDifferentiable_move(DerivedConformance &derived) {
{deriveBodyDifferentiable_move, nullptr});
}

// Return associated `TangentVector` struct for a nominal type, if it exists.
// If not, synthesize the struct.
/// Synthesize the `vectorView` property declaration.
static ValueDecl *deriveEuclideanDifferentiable_vectorView(
DerivedConformance &derived) {
auto &C = derived.TC.Context;
auto *parentDC = derived.getConformanceContext();

auto *tangentDecl = getTangentVectorStructDecl(parentDC);
auto tangentType = tangentDecl->getDeclaredInterfaceType();
auto tangentContextualType = parentDC->mapTypeIntoContext(tangentType);

VarDecl *vectorViewDecl;
PatternBindingDecl *pbDecl;
std::tie(vectorViewDecl, pbDecl) = derived.declareDerivedProperty(
C.Id_vectorView, tangentType, tangentContextualType, /*isStatic*/ false,
/*isFinal*/ true);

struct GetterSynthesizerContext {
StructDecl *tangentDecl;
Type tangentContextualType;
};

auto getterSynthesizer = [](AbstractFunctionDecl *getterDecl, void *ctx)
-> std::pair<BraceStmt *, bool> {
auto *context = reinterpret_cast<GetterSynthesizerContext *>(ctx);
assert(context && "Invalid context");
auto *parentDC = getterDecl->getParent();
auto *nominal = parentDC->getSelfNominalTypeDecl();
auto &C = nominal->getASTContext();
SmallVector<VarDecl *, 8> diffProperties;
getStoredPropertiesForDifferentiation(nominal, nominal->getDeclContext(),
diffProperties);

// Create a reference to the memberwise initializer: `TangentVector.init`.
auto *memberwiseInitDecl =
context->tangentDecl->getEffectiveMemberwiseInitializer();
assert(memberwiseInitDecl && "Memberwise initializer must exist");
// `TangentVector`
auto *tangentTypeExpr =
TypeExpr::createImplicit(context->tangentContextualType, C);
// `TangentVector.init`
auto *initDRE = new (C) DeclRefExpr(memberwiseInitDecl, DeclNameLoc(),
/*Implicit*/ true);
initDRE->setFunctionRefKind(FunctionRefKind::SingleApply);
auto *initExpr = new (C) ConstructorRefCallExpr(initDRE, tangentTypeExpr);
initExpr->setThrows(false);
initExpr->setImplicit();

// Create a call:
// TangentVector.init(
// <property_name_1...>: self.<property_name_1>,
// <property_name_2...>: self.<property_name_2>,
// ...
// )
SmallVector<Identifier, 8> argLabels;
SmallVector<Expr *, 8> memberRefs;
auto *selfDRE = new (C) DeclRefExpr(getterDecl->getImplicitSelfDecl(),
DeclNameLoc(),
/*Implicit*/ true);
for (auto *member : diffProperties) {
argLabels.push_back(member->getName());
memberRefs.push_back(
new (C) MemberRefExpr(selfDRE, SourceLoc(), member, DeclNameLoc(),
/*Implicit*/ true));
}
assert(memberRefs.size() == argLabels.size());
CallExpr *callExpr =
CallExpr::createImplicit(C, initExpr, memberRefs, argLabels);

// Create a return statement: `return TangentVector.init(...)`.
ASTNode retStmt =
new (C) ReturnStmt(SourceLoc(), callExpr, /*implicit*/ true);
auto *braceStmt = BraceStmt::create(C, SourceLoc(), retStmt, SourceLoc(),
/*implicit*/ true);
return std::make_pair(braceStmt, false);
};
auto *getterDecl = derived.addGetterToReadOnlyDerivedProperty(
vectorViewDecl, tangentContextualType);
getterDecl->setBodySynthesizer(
getterSynthesizer, /*context*/ C.AllocateObjectCopy(
GetterSynthesizerContext{tangentDecl, tangentContextualType}));
derived.addMembersToConformanceContext({vectorViewDecl, pbDecl});
return vectorViewDecl;
}

/// Return associated `TangentVector` struct for a nominal type, if it exists.
/// If not, synthesize the struct.
static StructDecl *
getOrSynthesizeTangentVectorStruct(DerivedConformance &derived, Identifier id) {
auto &TC = derived.TC;
Expand All @@ -362,8 +473,7 @@ getOrSynthesizeTangentVectorStruct(DerivedConformance &derived, Identifier id) {
return structDecl;
}

// Otherwise, synthesize a new struct. The struct must conform to
// `Differentiable`.
// Otherwise, synthesize a new struct.
auto *diffableProto = C.getProtocol(KnownProtocolKind::Differentiable);
auto diffableType = TypeLoc::withoutLoc(diffableProto->getDeclaredType());
auto *addArithProto = C.getProtocol(KnownProtocolKind::AdditiveArithmetic);
Expand All @@ -378,9 +488,9 @@ getOrSynthesizeTangentVectorStruct(DerivedConformance &derived, Identifier id) {
auto *kpIterableProto = C.getProtocol(KnownProtocolKind::KeyPathIterable);
auto kpIterableType = TypeLoc::withoutLoc(kpIterableProto->getDeclaredType());

SmallVector<TypeLoc, 4> inherited{diffableType};
// `TangentVector` must conform to `AdditiveArithmetic`.
inherited.push_back(addArithType);
// By definition, `TangentVector` must conform to `Differentiable` and
// `AdditiveArithmetic`.
SmallVector<TypeLoc, 4> inherited{diffableType, addArithType};

// Cache original members and their associated types for later use.
SmallVector<VarDecl *, 8> diffProperties;
Expand Down Expand Up @@ -551,8 +661,8 @@ getOrSynthesizeTangentVectorStruct(DerivedConformance &derived, Identifier id) {
return structDecl;
}

// Add a typealias declaration with the given name and underlying target
// struct type to the given source nominal declaration context.
/// Add a typealias declaration with the given name and underlying target
/// struct type to the given source nominal declaration context.
static void addAssociatedTypeAliasDecl(Identifier name,
DeclContext *sourceDC,
StructDecl *target,
Expand Down Expand Up @@ -585,11 +695,11 @@ static void addAssociatedTypeAliasDecl(Identifier name,
C.addSynthesizedDecl(aliasDecl);
};

// Diagnose stored properties in the nominal that do not have an explicit
// `@noDerivative` attribute, but either:
// - Do not conform to `Differentiable`.
// - Are a `let` stored property.
// Emit a warning and a fixit so that users will make the attribute explicit.
/// Diagnose stored properties in the nominal that do not have an explicit
/// `@noDerivative` attribute, but either:
/// - Do not conform to `Differentiable`.
/// - Are a `let` stored property.
/// Emit a warning and a fixit so that users will make the attribute explicit.
static void checkAndDiagnoseImplicitNoDerivative(TypeChecker &TC,
NominalTypeDecl *nominal,
DeclContext* DC) {
Expand Down Expand Up @@ -637,7 +747,7 @@ static void checkAndDiagnoseImplicitNoDerivative(TypeChecker &TC,
}
}

// Get or synthesize `TangentVector` struct type.
/// Get or synthesize `TangentVector` struct type.
static Type
getOrSynthesizeTangentVectorStructType(DerivedConformance &derived) {
auto &TC = derived.TC;
Expand Down Expand Up @@ -669,7 +779,7 @@ getOrSynthesizeTangentVectorStructType(DerivedConformance &derived) {
tangentStruct->getDeclaredInterfaceType());
}

// Synthesize the `TangentVector` struct type.
/// Synthesize the `TangentVector` struct type.
static Type
deriveDifferentiable_TangentVectorStruct(DerivedConformance &derived) {
auto &TC = derived.TC;
Expand Down Expand Up @@ -756,3 +866,18 @@ Type DerivedConformance::deriveDifferentiable(AssociatedTypeDecl *requirement) {
TC.diagnose(requirement->getLoc(), diag::broken_differentiable_requirement);
return nullptr;
}

/// Derive a EuclideanDifferentiable requirement for a nominal type.
///
/// \returns the derived member, which will also be added to the type.
ValueDecl *DerivedConformance::deriveEuclideanDifferentiable(
ValueDecl *requirement) {
// Diagnose conformances in disallowed contexts.
if (checkAndDiagnoseDisallowedContext(requirement))
return nullptr;
if (requirement->getFullName() == TC.Context.Id_vectorView)
return deriveEuclideanDifferentiable_vectorView(*this);
TC.diagnose(requirement->getLoc(),
diag::broken_euclidean_differentiable_requirement);
return nullptr;
}
9 changes: 9 additions & 0 deletions lib/Sema/DerivedConformances.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,10 @@ bool DerivedConformance::derivesProtocolConformance(DeclContext *DC,
if (*knownProtocol == KnownProtocolKind::Differentiable)
return canDeriveDifferentiable(Nominal, DC);

// SWIFT_ENABLE_TENSORFLOW
if (*knownProtocol == KnownProtocolKind::EuclideanDifferentiable)
return canDeriveEuclideanDifferentiable(Nominal, DC);

if (auto *enumDecl = dyn_cast<EnumDecl>(Nominal)) {
switch (*knownProtocol) {
// The presence of a raw type is an explicit declaration that
Expand Down Expand Up @@ -255,6 +259,11 @@ ValueDecl *DerivedConformance::getDerivableRequirement(NominalTypeDecl *nominal,
if (name.isSimpleName(ctx.Id_zero))
return getRequirement(KnownProtocolKind::AdditiveArithmetic);

// SWIFT_ENABLE_TENSORFLOW
// EuclideanDifferentiable.vectorView
if (name.isSimpleName(ctx.Id_vectorView))
return getRequirement(KnownProtocolKind::EuclideanDifferentiable);

// SWIFT_ENABLE_TENSORFLOW
// PointwiseMultiplicative.one
if (name.isSimpleName(ctx.Id_one))
Expand Down
11 changes: 11 additions & 0 deletions lib/Sema/DerivedConformances.h
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,17 @@ class DerivedConformance {
/// \returns the derived member, which will also be added to the type.
ValueDecl *deriveDifferentiable(ValueDecl *requirement);

/// Determine if a Differentiable requirement can be derived for a type.
///
/// \returns True if the requirement can be derived.
static bool canDeriveEuclideanDifferentiable(NominalTypeDecl *type,
DeclContext *DC);

/// Derive a EuclideanDifferentiable requirement for a nominal type.
///
/// \returns the derived member, which will also be added to the type.
ValueDecl *deriveEuclideanDifferentiable(ValueDecl *requirement);

/// Derive a Differentiable type witness for a nominal type.
///
/// \returns the derived member, which will also be added to the type.
Expand Down
4 changes: 4 additions & 0 deletions lib/Sema/TypeCheckProtocol.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5363,6 +5363,10 @@ ValueDecl *TypeChecker::deriveProtocolRequirement(DeclContext *DC,
case KnownProtocolKind::Differentiable:
return derived.deriveDifferentiable(Requirement);

// SWIFT_ENABLE_TENSORFLOW
case KnownProtocolKind::EuclideanDifferentiable:
return derived.deriveEuclideanDifferentiable(Requirement);

default:
return nullptr;
}
Expand Down
7 changes: 2 additions & 5 deletions stdlib/public/core/AutoDiff.swift
Original file line number Diff line number Diff line change
Expand Up @@ -229,14 +229,11 @@ public extension Differentiable where TangentVector == Self {
/// `TangentVector` is equal to its vector space component.
public protocol EuclideanDifferentiable: Differentiable {
/// The differentiable vector component of `self`.
var vectorView: TangentVector { get set }
var vectorView: TangentVector { get }
}

public extension EuclideanDifferentiable where TangentVector == Self {
var vectorView: TangentVector {
_read { yield self }
_modify { yield &self }
}
var vectorView: TangentVector { _read { yield self } }
}

/// Returns `x` like an identity function. When used in a context where `x` is
Expand Down
Loading