Skip to content

Commit a98fe74

Browse files
authored
[AutoDiff upstream] Add TangentSpace. (#29107)
`TangentSpace` represents the tangent space of a type. - For `Differentiable`-conforming types: - The tangent space is the `TangentVector` associated type. - For tuple types: - The tangent space is a tuple of the elements' tangent space types, for the elements that have a tangent space. - Other types have no tangent space. `TypeBase::getAutoDiffTangentSpace` gets the tangent space of a type. `TangentSpace` is used to: - Compute the derivative function type of a given original function type. - Compute the type of tangent/adjoint values during automatic differentiation. Progress towards TF-828: upstream `@differentiable` attribute type-checking.
1 parent 1486d6b commit a98fe74

File tree

5 files changed

+150
-0
lines changed

5 files changed

+150
-0
lines changed

include/swift/AST/ASTContext.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,9 @@ class ASTContext final {
284284
/// across invocations of both the parser and the type-checker.
285285
unsigned NextAutoClosureDiscriminator = 0;
286286

287+
/// Cached mapping from types to their associated tangent spaces.
288+
llvm::DenseMap<Type, Optional<TangentSpace>> AutoDiffTangentSpaces;
289+
287290
/// Cache of `@derivative` attributes keyed by parameter indices and
288291
/// derivative function kind. Used to diagnose duplicate `@derivative`
289292
/// attributes for the same key.

include/swift/AST/AutoDiff.h

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
namespace swift {
3030

3131
class AnyFunctionType;
32+
class TupleType;
3233

3334
/// A function type differentiability kind.
3435
enum class DifferentiabilityKind : uint8_t {
@@ -153,6 +154,74 @@ class ParsedAutoDiffParameter {
153154
}
154155
};
155156

157+
/// The tangent space of a type.
158+
///
159+
/// For `Differentiable`-conforming types:
160+
/// - The tangent space is the `TangentVector` associated type.
161+
///
162+
/// For tuple types:
163+
/// - The tangent space is a tuple of the elements' tangent space types, for the
164+
/// elements that have a tangent space.
165+
///
166+
/// Other types have no tangent space.
167+
class TangentSpace {
168+
public:
169+
/// A tangent space kind.
170+
enum class Kind {
171+
/// The `TangentVector` associated type of a `Differentiable`-conforming
172+
/// type.
173+
TangentVector,
174+
/// A product of tangent spaces as a tuple.
175+
Tuple
176+
};
177+
178+
private:
179+
Kind kind;
180+
union Value {
181+
// TangentVector
182+
Type tangentVectorType;
183+
// Tuple
184+
TupleType *tupleType;
185+
186+
Value(Type tangentVectorType) : tangentVectorType(tangentVectorType) {}
187+
Value(TupleType *tupleType) : tupleType(tupleType) {}
188+
} value;
189+
190+
TangentSpace(Kind kind, Value value) : kind(kind), value(value) {}
191+
192+
public:
193+
TangentSpace() = delete;
194+
195+
static TangentSpace getTangentVector(Type tangentVectorType) {
196+
return {Kind::TangentVector, tangentVectorType};
197+
}
198+
static TangentSpace getTuple(TupleType *tupleTy) {
199+
return {Kind::Tuple, tupleTy};
200+
}
201+
202+
bool isTangentVector() const { return kind == Kind::TangentVector; }
203+
bool isTuple() const { return kind == Kind::Tuple; }
204+
205+
Kind getKind() const { return kind; }
206+
Type getTangentVector() const {
207+
assert(kind == Kind::TangentVector);
208+
return value.tangentVectorType;
209+
}
210+
TupleType *getTuple() const {
211+
assert(kind == Kind::Tuple);
212+
return value.tupleType;
213+
}
214+
215+
/// Get the tangent space type.
216+
Type getType() const;
217+
218+
/// Get the tangent space canonical type.
219+
CanType getCanonicalType() const;
220+
221+
/// Get the underlying nominal type declaration of the tangent space type.
222+
NominalTypeDecl *getNominal() const;
223+
};
224+
156225
/// Automatic differentiation utility namespace.
157226
namespace autodiff {
158227

include/swift/AST/Types.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1161,6 +1161,11 @@ class alignas(1 << TypeAlignInBits) TypeBase {
11611161
/// object type.
11621162
TypeTraitResult canBeClass();
11631163

1164+
/// Return the tangent space of the given type, if it exists. Otherwise,
1165+
/// return `None`.
1166+
Optional<TangentSpace>
1167+
getAutoDiffTangentSpace(LookupConformanceFn lookupConformance);
1168+
11641169
private:
11651170
// Make vanilla new/delete illegal for Types.
11661171
void *operator new(size_t Bytes) throw() = delete;

lib/AST/AutoDiff.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,3 +66,21 @@ void autodiff::getSubsetParameterTypes(IndexSubset *subset,
6666
results.push_back(curryLevel->getParams()[paramIndex].getOldType());
6767
}
6868
}
69+
70+
Type TangentSpace::getType() const {
71+
switch (kind) {
72+
case Kind::TangentVector:
73+
return value.tangentVectorType;
74+
case Kind::Tuple:
75+
return value.tupleType;
76+
}
77+
}
78+
79+
CanType TangentSpace::getCanonicalType() const {
80+
return getType()->getCanonicalType();
81+
}
82+
83+
NominalTypeDecl *TangentSpace::getNominal() const {
84+
assert(isTangentVector());
85+
return getTangentVector()->getNominalOrBoundGenericNominal();
86+
}

lib/AST/Type.cpp

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4855,6 +4855,61 @@ CanType swift::substOpaqueTypesWithUnderlyingTypes(CanType ty,
48554855
return ty.subst(replacer, replacer, flags)->getCanonicalType();
48564856
}
48574857

4858+
Optional<TangentSpace>
4859+
TypeBase::getAutoDiffTangentSpace(LookupConformanceFn lookupConformance) {
4860+
assert(lookupConformance);
4861+
auto &ctx = getASTContext();
4862+
4863+
Type cacheKey = this;
4864+
auto lookup = ctx.AutoDiffTangentSpaces.find(cacheKey);
4865+
if (lookup != ctx.AutoDiffTangentSpaces.end())
4866+
return lookup->getSecond();
4867+
auto cache = [&](Optional<TangentSpace> tangentSpace) {
4868+
ctx.AutoDiffTangentSpaces.insert({cacheKey, tangentSpace});
4869+
return tangentSpace;
4870+
};
4871+
4872+
// For tuple types: the tangent space is a tuple of the elements' tangent
4873+
// space types, for the elements that have a tangent space.
4874+
if (auto *tupleTy = getAs<TupleType>()) {
4875+
SmallVector<TupleTypeElt, 8> newElts;
4876+
for (auto elt : tupleTy->getElements()) {
4877+
auto eltSpace = elt.getType()->getAutoDiffTangentSpace(lookupConformance);
4878+
if (!eltSpace)
4879+
continue;
4880+
newElts.push_back(elt.getWithType(eltSpace->getType()));
4881+
}
4882+
if (newElts.empty())
4883+
return cache(
4884+
TangentSpace::getTuple(ctx.TheEmptyTupleType->castTo<TupleType>()));
4885+
if (newElts.size() == 1)
4886+
return cache(TangentSpace::getTangentVector(newElts.front().getType()));
4887+
auto *tupleType = TupleType::get(newElts, ctx)->castTo<TupleType>();
4888+
return cache(TangentSpace::getTuple(tupleType));
4889+
}
4890+
4891+
// For `Differentiable`-conforming types: the tangent space is the
4892+
// `TangentVector` associated type.
4893+
auto *differentiableProtocol =
4894+
ctx.getProtocol(KnownProtocolKind::Differentiable);
4895+
assert(differentiableProtocol && "`Differentiable` protocol not found");
4896+
auto associatedTypeLookup =
4897+
differentiableProtocol->lookupDirect(ctx.Id_TangentVector);
4898+
assert(associatedTypeLookup.size() == 1);
4899+
auto *dependentType = DependentMemberType::get(
4900+
differentiableProtocol->getDeclaredInterfaceType(),
4901+
cast<AssociatedTypeDecl>(associatedTypeLookup[0]));
4902+
4903+
// Try to get the `TangentVector` associated type of `base`.
4904+
// Return the associated type if it is valid.
4905+
auto assocTy = dependentType->substBaseType(this, lookupConformance);
4906+
if (!assocTy->hasError())
4907+
return cache(TangentSpace::getTangentVector(assocTy));
4908+
4909+
// Otherwise, there is no associated tangent space. Return `None`.
4910+
return cache(None);
4911+
}
4912+
48584913
CanSILFunctionType
48594914
SILFunctionType::withSubstitutions(SubstitutionMap subs) const {
48604915
return SILFunctionType::get(getSubstGenericSignature(),

0 commit comments

Comments
 (0)