Skip to content

[AutoDiff upstream] Add TangentSpace. #29107

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions include/swift/AST/ASTContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,9 @@ class ASTContext final {
/// across invocations of both the parser and the type-checker.
unsigned NextAutoClosureDiscriminator = 0;

/// Cached mapping from types to their associated tangent spaces.
llvm::DenseMap<Type, Optional<TangentSpace>> AutoDiffTangentSpaces;

/// Cache of `@derivative` attributes keyed by parameter indices and
/// derivative function kind. Used to diagnose duplicate `@derivative`
/// attributes for the same key.
Expand Down
69 changes: 69 additions & 0 deletions include/swift/AST/AutoDiff.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
namespace swift {

class AnyFunctionType;
class TupleType;

/// A function type differentiability kind.
enum class DifferentiabilityKind : uint8_t {
Expand Down Expand Up @@ -133,6 +134,74 @@ class ParsedAutoDiffParameter {
}
};

/// The tangent space of a type.
///
/// For `Differentiable`-conforming types:
/// - The tangent space is the `TangentVector` associated type.
///
/// For tuple types:
/// - The tangent space is a tuple of the elements' tangent space types, for the
/// elements that have a tangent space.
///
/// Other types have no tangent space.
class TangentSpace {
public:
/// A tangent space kind.
enum class Kind {
/// The `TangentVector` associated type of a `Differentiable`-conforming
/// type.
TangentVector,
/// A product of tangent spaces as a tuple.
Tuple
};

private:
Kind kind;
union Value {
// TangentVector
Type tangentVectorType;
// Tuple
TupleType *tupleType;

Value(Type tangentVectorType) : tangentVectorType(tangentVectorType) {}
Value(TupleType *tupleType) : tupleType(tupleType) {}
} value;

TangentSpace(Kind kind, Value value) : kind(kind), value(value) {}

public:
TangentSpace() = delete;

static TangentSpace getTangentVector(Type tangentVectorType) {
return {Kind::TangentVector, tangentVectorType};
}
static TangentSpace getTuple(TupleType *tupleTy) {
return {Kind::Tuple, tupleTy};
}

bool isTangentVector() const { return kind == Kind::TangentVector; }
bool isTuple() const { return kind == Kind::Tuple; }

Kind getKind() const { return kind; }
Type getTangentVector() const {
assert(kind == Kind::TangentVector);
return value.tangentVectorType;
}
TupleType *getTuple() const {
assert(kind == Kind::Tuple);
return value.tupleType;
}

/// Get the tangent space type.
Type getType() const;

/// Get the tangent space canonical type.
CanType getCanonicalType() const;

/// Get the underlying nominal type declaration of the tangent space type.
NominalTypeDecl *getNominal() const;
};

/// Automatic differentiation utility namespace.
namespace autodiff {

Expand Down
5 changes: 5 additions & 0 deletions include/swift/AST/Types.h
Original file line number Diff line number Diff line change
Expand Up @@ -1161,6 +1161,11 @@ class alignas(1 << TypeAlignInBits) TypeBase {
/// object type.
TypeTraitResult canBeClass();

/// Return the tangent space of the given type, if it exists. Otherwise,
/// return `None`.
Optional<TangentSpace>
getAutoDiffTangentSpace(LookupConformanceFn lookupConformance);

private:
// Make vanilla new/delete illegal for Types.
void *operator new(size_t Bytes) throw() = delete;
Expand Down
18 changes: 18 additions & 0 deletions lib/AST/AutoDiff.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,21 @@ void autodiff::getSubsetParameterTypes(IndexSubset *subset,
results.push_back(curryLevel->getParams()[paramIndex].getOldType());
}
}

Type TangentSpace::getType() const {
switch (kind) {
case Kind::TangentVector:
return value.tangentVectorType;
case Kind::Tuple:
return value.tupleType;
}
}

CanType TangentSpace::getCanonicalType() const {
return getType()->getCanonicalType();
}

NominalTypeDecl *TangentSpace::getNominal() const {
assert(isTangentVector());
return getTangentVector()->getNominalOrBoundGenericNominal();
}
55 changes: 55 additions & 0 deletions lib/AST/Type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4855,6 +4855,61 @@ CanType swift::substOpaqueTypesWithUnderlyingTypes(CanType ty,
return ty.subst(replacer, replacer, flags)->getCanonicalType();
}

Optional<TangentSpace>
TypeBase::getAutoDiffTangentSpace(LookupConformanceFn lookupConformance) {
assert(lookupConformance);
auto &ctx = getASTContext();

Type cacheKey = this;
auto lookup = ctx.AutoDiffTangentSpaces.find(cacheKey);
if (lookup != ctx.AutoDiffTangentSpaces.end())
return lookup->getSecond();
auto cache = [&](Optional<TangentSpace> tangentSpace) {
ctx.AutoDiffTangentSpaces.insert({cacheKey, tangentSpace});
return tangentSpace;
};

// For tuple types: the tangent space is a tuple of the elements' tangent
// space types, for the elements that have a tangent space.
if (auto *tupleTy = getAs<TupleType>()) {
SmallVector<TupleTypeElt, 8> newElts;
for (auto elt : tupleTy->getElements()) {
auto eltSpace = elt.getType()->getAutoDiffTangentSpace(lookupConformance);
if (!eltSpace)
continue;
newElts.push_back(elt.getWithType(eltSpace->getType()));
}
if (newElts.empty())
return cache(
TangentSpace::getTuple(ctx.TheEmptyTupleType->castTo<TupleType>()));
if (newElts.size() == 1)
return cache(TangentSpace::getTangentVector(newElts.front().getType()));
auto *tupleType = TupleType::get(newElts, ctx)->castTo<TupleType>();
return cache(TangentSpace::getTuple(tupleType));
}

// For `Differentiable`-conforming types: the tangent space is the
// `TangentVector` associated type.
auto *differentiableProtocol =
ctx.getProtocol(KnownProtocolKind::Differentiable);
assert(differentiableProtocol && "`Differentiable` protocol not found");
auto associatedTypeLookup =
differentiableProtocol->lookupDirect(ctx.Id_TangentVector);
assert(associatedTypeLookup.size() == 1);
auto *dependentType = DependentMemberType::get(
differentiableProtocol->getDeclaredInterfaceType(),
cast<AssociatedTypeDecl>(associatedTypeLookup[0]));

// Try to get the `TangentVector` associated type of `base`.
// Return the associated type if it is valid.
auto assocTy = dependentType->substBaseType(this, lookupConformance);
if (!assocTy->hasError())
return cache(TangentSpace::getTangentVector(assocTy));

// Otherwise, there is no associated tangent space. Return `None`.
return cache(None);
}

CanSILFunctionType
SILFunctionType::withSubstitutions(SubstitutionMap subs) const {
return SILFunctionType::get(getSubstGenericSignature(),
Expand Down