Skip to content

Commit 7903809

Browse files
authored
[AutoDiff] NFC: Refactor TangentSpace (#21564)
* Fix SR-8770 * Refactor TangentSpace
1 parent 317664d commit 7903809

File tree

4 files changed

+153
-187
lines changed

4 files changed

+153
-187
lines changed

include/swift/AST/ASTContext.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ namespace swift {
104104
class VarDecl;
105105
class UnifiedStatsReporter;
106106
// SWIFT_ENABLE_TENSORFLOW
107-
class TangentSpace;
107+
class VectorSpace;
108108

109109
enum class KnownProtocolKind : uint8_t;
110110

@@ -955,7 +955,7 @@ class ASTContext final {
955955
// SWIFT_ENABLE_TENSORFLOW
956956
/// Compute the tangent space of this manifold, if the given type represents a
957957
/// differentiable manifold.
958-
Optional<TangentSpace> getTangentSpace(CanType type, ModuleDecl *module);
958+
Optional<VectorSpace> getTangentSpace(CanType type, ModuleDecl *module);
959959

960960
private:
961961
friend Decl;

include/swift/AST/AutoDiff.h

Lines changed: 39 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -357,32 +357,29 @@ class StructDecl;
357357
class TupleType;
358358
class EnumDecl;
359359

360-
/// A type that represents the tangent space of a differentiable type.
361-
class TangentSpace {
360+
/// A type that represents a vector space.
361+
class VectorSpace {
362362
public:
363363
/// A tangent space kind.
364364
enum class Kind {
365365
/// `Builtin.FP<...>`.
366-
BuiltinRealScalar,
367-
/// A type that conforms to `FloatingPoint`.
368-
RealScalar,
369-
/// A type that conforms to `VectorNumeric` where the associated
370-
/// `ScalarElement` conforms to `FloatingPoint`.
371-
RealVector,
372-
/// A product of tangent spaces as a struct.
373-
ProductStruct,
374-
/// A product of tangent spaces as a tuple.
375-
ProductTuple,
376-
/// A sum of tangent spaces.
377-
Sum
366+
BuiltinFloat,
367+
/// A type that conforms to `VectorNumeric`.
368+
Vector,
369+
/// A product of vector spaces as a struct.
370+
Struct,
371+
/// A product of vector spaces as a tuple.
372+
Tuple,
373+
/// A union of vector spaces.
374+
Enum
378375
};
379376

380377
private:
381378
Kind kind;
382379
union Value {
383380
// BuiltinRealScalar
384381
BuiltinFloatType *builtinFPType;
385-
// RealScalar or RealVector
382+
// RealVector
386383
NominalTypeDecl *realNominalType;
387384
// ProductStruct
388385
StructDecl *structDecl;
@@ -398,67 +395,55 @@ class TangentSpace {
398395
Value(EnumDecl *enumDecl) : enumDecl(enumDecl) {}
399396
} value;
400397

401-
TangentSpace(Kind kind, Value value)
398+
VectorSpace(Kind kind, Value value)
402399
: kind(kind), value(value) {}
403400

404401
public:
405-
TangentSpace() = delete;
402+
VectorSpace() = delete;
406403

407-
static TangentSpace
404+
static VectorSpace
408405
getBuiltinRealScalarSpace(BuiltinFloatType *builtinFP) {
409-
return {Kind::BuiltinRealScalar, builtinFP};
406+
return {Kind::BuiltinFloat, builtinFP};
410407
}
411-
static TangentSpace getRealScalarSpace(NominalTypeDecl *typeDecl) {
412-
return {Kind::RealScalar, typeDecl};
408+
static VectorSpace getRealVectorSpace(NominalTypeDecl *typeDecl) {
409+
return {Kind::Vector, typeDecl};
413410
}
414-
static TangentSpace getRealVectorSpace(NominalTypeDecl *typeDecl) {
415-
return {Kind::RealVector, typeDecl};
411+
static VectorSpace getStruct(StructDecl *structDecl) {
412+
return {Kind::Struct, structDecl};
416413
}
417-
static TangentSpace getProductStruct(StructDecl *structDecl) {
418-
return {Kind::ProductStruct, structDecl};
414+
static VectorSpace getTuple(TupleType *tupleTy) {
415+
return {Kind::Tuple, tupleTy};
419416
}
420-
static TangentSpace getProductTuple(TupleType *tupleTy) {
421-
return {Kind::ProductTuple, tupleTy};
422-
}
423-
static TangentSpace getSum(EnumDecl *enumDecl) {
424-
return {Kind::Sum, enumDecl};
417+
static VectorSpace getEnum(EnumDecl *enumDecl) {
418+
return {Kind::Enum, enumDecl};
425419
}
426420

427-
bool isBuiltinRealScalarSpace() const {
428-
return kind == Kind::BuiltinRealScalar;
421+
bool isBuiltinFloat() const {
422+
return kind == Kind::BuiltinFloat;
429423
}
430-
bool isRealScalarSpace() const { return kind == Kind::RealScalar; }
431-
bool isRealVectorSpace() const { return kind == Kind::RealVector; }
432-
bool isProductStruct() const { return kind == Kind::ProductStruct; }
433-
bool isProductTuple() const { return kind == Kind::ProductTuple; }
424+
bool isVector() const { return kind == Kind::Vector; }
425+
bool isStruct() const { return kind == Kind::Struct; }
426+
bool isTuple() const { return kind == Kind::Tuple; }
434427

435428
Kind getKind() const { return kind; }
436-
BuiltinFloatType *getBuiltinRealScalarSpace() const {
437-
assert(kind == Kind::BuiltinRealScalar);
429+
BuiltinFloatType *getBuiltinFloat() const {
430+
assert(kind == Kind::BuiltinFloat);
438431
return value.builtinFPType;
439432
}
440-
NominalTypeDecl *getRealScalarSpace() const {
441-
assert(kind == Kind::RealScalar);
442-
return value.realNominalType;
443-
}
444-
NominalTypeDecl *getRealVectorSpace() const {
445-
assert(kind == Kind::RealVector);
446-
return value.realNominalType;
447-
}
448-
NominalTypeDecl *getRealScalarOrVectorSpace() const {
449-
assert(kind == Kind::RealScalar || kind == Kind::RealVector);
433+
NominalTypeDecl *getVector() const {
434+
assert(kind == Kind::Vector);
450435
return value.realNominalType;
451436
}
452-
StructDecl *getProductStruct() const {
453-
assert(kind == Kind::ProductStruct);
437+
StructDecl *getStruct() const {
438+
assert(kind == Kind::Struct);
454439
return value.structDecl;
455440
}
456-
TupleType *getProductTuple() const {
457-
assert(kind == Kind::ProductTuple);
441+
TupleType *getTuple() const {
442+
assert(kind == Kind::Tuple);
458443
return value.tupleType;
459444
}
460-
EnumDecl *getSum() const {
461-
assert(kind == Kind::Sum);
445+
EnumDecl *getEnum() const {
446+
assert(kind == Kind::Enum);
462447
return value.enumDecl;
463448
}
464449
};

lib/AST/ASTContext.cpp

Lines changed: 20 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -382,7 +382,7 @@ FOR_KNOWN_FOUNDATION_TYPES(CACHE_FOUNDATION_DECL)
382382

383383
// SWIFT_ENABLE_TENSORFLOW
384384
/// A cache of tangent spaces per type.
385-
llvm::DenseMap<CanType, Optional<TangentSpace>> TangentSpaces;
385+
llvm::DenseMap<CanType, Optional<VectorSpace>> VectorSpaces;
386386

387387
/// For uniquifying `AutoDiffParameterIndices` allocations.
388388
llvm::FoldingSet<AutoDiffParameterIndices> AutoDiffParameterIndicesSet;
@@ -5195,39 +5195,30 @@ LayoutConstraint LayoutConstraint::getLayoutConstraint(LayoutConstraintKind Kind
51955195
}
51965196

51975197
// SWIFT_ENABLE_TENSORFLOW
5198-
Optional<TangentSpace> ASTContext::getTangentSpace(CanType type,
5199-
ModuleDecl *module) {
5200-
auto lookup = getImpl().TangentSpaces.find(type);
5201-
if (lookup != getImpl().TangentSpaces.end())
5198+
Optional<VectorSpace> ASTContext::getTangentSpace(CanType type,
5199+
ModuleDecl *module) {
5200+
auto lookup = getImpl().VectorSpaces.find(type);
5201+
if (lookup != getImpl().VectorSpaces.end())
52025202
return lookup->getSecond();
52035203
// A helper that is used to cache the computed tangent space for the
52045204
// specified type and retuns the same tangent space.
5205-
auto cache = [&](Optional<TangentSpace> tangentSpace) {
5206-
getImpl().TangentSpaces.insert({type, tangentSpace});
5207-
return tangentSpace;
5205+
auto cache = [&](Optional<VectorSpace> space) {
5206+
getImpl().VectorSpaces.insert({type, space});
5207+
return space;
52085208
};
52095209
// `Builtin.FP<...>` is a builtin real scalar space.
52105210
if (auto *fpType = type->getAs<BuiltinFloatType>())
5211-
return cache(TangentSpace::getBuiltinRealScalarSpace(fpType));
5212-
// Look up conformance to `FloatingPoint`.
5213-
auto *fpProto = getProtocol(KnownProtocolKind::FloatingPoint);
5214-
if (auto maybeFPConf = module->lookupConformance(type, fpProto)) {
5215-
auto *typeDecl = type->getAnyNominal();
5216-
assert(typeDecl);
5217-
return cache(TangentSpace::getRealScalarSpace(typeDecl));
5218-
}
5211+
return cache(VectorSpace::getBuiltinRealScalarSpace(fpType));
52195212
// Look up conformance to `Differentiable`.
52205213
auto *diffableProto = getProtocol(KnownProtocolKind::Differentiable);
5221-
if (auto maybeDiffableConf = module->lookupConformance(type, diffableProto)) {
5222-
auto tangentLookup =
5223-
diffableProto->lookupDirect(getIdentifier("TangentVector"));
5224-
auto *tangentAssocDecl = cast<AssociatedTypeDecl>(tangentLookup[0]);
5225-
auto subMap = type->getMemberSubstitutionMap(module, tangentAssocDecl);
5226-
auto tangent = tangentAssocDecl->getDeclaredInterfaceType().subst(subMap);
5227-
auto *tangentDecl = tangent->getAnyNominal();
5228-
assert(tangentDecl &&
5229-
"Tangent must be a nominal type because it has protocol contraints");
5230-
return cache(TangentSpace::getRealVectorSpace(tangentDecl));
5214+
if (module->lookupConformance(type, diffableProto).hasValue()) {
5215+
auto tangentType = type->getAutoDiffAssociatedType(
5216+
AutoDiffAssociatedTypeKind::TangentVector,
5217+
LookUpConformanceInModule(module));
5218+
assert(tangentType);
5219+
auto *nomTypeDecl = tangentType->getAnyNominal();
5220+
assert(nomTypeDecl);
5221+
return VectorSpace::getRealVectorSpace(nomTypeDecl);
52315222
}
52325223
// Nominal types can be either a struct or an enum.
52335224
if (auto *nominal = type->getAnyNominal()) {
@@ -5243,7 +5234,7 @@ Optional<TangentSpace> ASTContext::getTangentSpace(CanType type,
52435234
module);
52445235
});
52455236
if (allMembersHaveTangentSpace)
5246-
return cache(TangentSpace::getProductStruct(structDecl));
5237+
return cache(VectorSpace::getStruct(structDecl));
52475238
}
52485239
// Frozen enum types, all of whose payloads have a tangent space, are a
52495240
// sum of the product of payloads in each case.
@@ -5263,15 +5254,15 @@ Optional<TangentSpace> ASTContext::getTangentSpace(CanType type,
52635254
});
52645255
});
52655256
if (allMembersHaveTangentSpace)
5266-
return cache(TangentSpace::getSum(enumDecl));
5257+
return cache(VectorSpace::getEnum(enumDecl));
52675258
}
52685259
}
52695260
// Tuple types, each of whose elements has a tangent space, are a product of
52705261
// those tangent space.
52715262
if (TupleType *tupleType = type->getAs<TupleType>())
52725263
if (llvm::all_of(tupleType->getElementTypes(), [&](Type t) {
52735264
return (bool)getTangentSpace(t->getCanonicalType(), module); }))
5274-
return cache(TangentSpace::getProductTuple(tupleType));
5265+
return cache(VectorSpace::getTuple(tupleType));
52755266
// Otherwise, the type does not have a tangent space. That is, it does not
52765267
// support differentiation.
52775268
return cache(None);

0 commit comments

Comments
 (0)