Skip to content

Commit a095223

Browse files
committed
---
yaml --- r: 262127 b: refs/heads/tensorflow c: 78c9f57 h: refs/heads/master i: 262125: 406f4db 262123: 56059af 262119: a05fe24 262111: 16bc367
1 parent 258eb36 commit a095223

22 files changed

+189
-570
lines changed

[refs]

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -818,7 +818,7 @@ refs/tags/swift-DEVELOPMENT-SNAPSHOT-2018-04-25-a: 22f738a831d43aff2b9c9773bcb65
818818
refs/tags/swift-DEVELOPMENT-SNAPSHOT-2018-05-08-a: 7d98cc16689baba5c8a3b90a9329bdcc1a12b4e9
819819
refs/heads/cherr42: a566ad54b073c2c56ac0a705d0a5bed9743135a5
820820
"refs/heads/codable_test_comment_fix": fc8f6824f7f347e1e8db55bff62db385c5728b5a
821-
refs/heads/tensorflow: ba9b37e342043993abbb3b2c566870ee56b3f4bc
821+
refs/heads/tensorflow: 78c9f57e76df3bd298db5767854a52776f284d09
822822
refs/tags/swift-4.1-DEVELOPMENT-SNAPSHOT-2018-05-11-a: 8126fd7a652e2f70ad6d76505239e34fb2ef3e1a
823823
refs/tags/swift-4.1-DEVELOPMENT-SNAPSHOT-2018-05-12-a: b3fd3dd84df6717f2e2e9df58c6d7e99fed57086
824824
refs/tags/swift-4.1-DEVELOPMENT-SNAPSHOT-2018-05-13-a: 71135119579039dc321c5f65d870050fe36efda2

branches/tensorflow/include/swift/AST/ASTContext.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ namespace swift {
104104
class VarDecl;
105105
class UnifiedStatsReporter;
106106
// SWIFT_ENABLE_TENSORFLOW
107+
enum class AutoDiffAssociatedVectorSpaceKind : unsigned;
107108
class VectorSpace;
108109

109110
enum class KnownProtocolKind : uint8_t;
@@ -269,6 +270,10 @@ class ASTContext final {
269270
/// Cache of remapped types (useful for diagnostics).
270271
llvm::StringMap<Type> RemappedTypes;
271272

273+
/// Cache of autodiff-associated vector spaces.
274+
llvm::DenseMap<std::pair<Type, unsigned>,
275+
Optional<VectorSpace>> AutoDiffVectorSpaces;
276+
272277
private:
273278
/// \brief The current generation number, which reflects the number of
274279
/// times that external modules have been loaded.
@@ -952,11 +957,6 @@ class ASTContext final {
952957
return LangOpts.isSwiftVersionAtLeast(major, minor);
953958
}
954959

955-
// SWIFT_ENABLE_TENSORFLOW
956-
/// Compute the tangent space of this manifold, if the given type represents a
957-
/// differentiable manifold.
958-
Optional<VectorSpace> getTangentSpace(CanType type, ModuleDecl *module);
959-
960960
private:
961961
friend Decl;
962962
Optional<RawComment> getRawComment(const Decl *D);

branches/tensorflow/include/swift/AST/AutoDiff.h

Lines changed: 18 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -315,12 +315,8 @@ class AutoDiffAssociatedFunctionIdentifier : public llvm::FoldingSetNode {
315315
};
316316

317317
/// The kind of an associated type.
318-
struct AutoDiffAssociatedTypeKind {
319-
enum innerty : uint8_t { TangentVector = 0, CotangentVector = 1 } rawValue;
320-
321-
AutoDiffAssociatedTypeKind() = default;
322-
AutoDiffAssociatedTypeKind(innerty rawValue) : rawValue(rawValue) {}
323-
operator innerty() const { return rawValue; }
318+
enum class AutoDiffAssociatedVectorSpaceKind : unsigned {
319+
Tangent = 0, Cotangent = 1
324320
};
325321

326322
/// Automatic differentiation utility namespace.
@@ -352,10 +348,8 @@ bool getBuiltinAutoDiffApplyConfig(StringRef operationName,
352348
} // end namespace autodiff
353349

354350
class BuiltinFloatType;
355-
class NominalTypeDecl;
356-
class StructDecl;
351+
class NominalOrBoundGenericNominalType;
357352
class TupleType;
358-
class EnumDecl;
359353

360354
/// A type that represents a vector space.
361355
class VectorSpace {
@@ -366,33 +360,23 @@ class VectorSpace {
366360
BuiltinFloat,
367361
/// A type that conforms to `VectorNumeric`.
368362
Vector,
369-
/// A product of vector spaces as a struct.
370-
Struct,
371363
/// A product of vector spaces as a tuple.
372364
Tuple,
373-
/// A union of vector spaces.
374-
Enum
375365
};
376366

377367
private:
378368
Kind kind;
379369
union Value {
380-
// BuiltinRealScalar
370+
// Builtin float
381371
BuiltinFloatType *builtinFPType;
382-
// RealVector
383-
NominalTypeDecl *realNominalType;
384-
// ProductStruct
385-
StructDecl *structDecl;
386-
// ProductTuple
372+
// Vector
373+
Type vectorType;
374+
// Tuple
387375
TupleType *tupleType;
388-
// Sum
389-
EnumDecl *enumDecl;
390376

391377
Value(BuiltinFloatType *builtinFP) : builtinFPType(builtinFP) {}
392-
Value(NominalTypeDecl *nominal) : realNominalType(nominal) {}
393-
Value(StructDecl *structDecl) : structDecl(structDecl) {}
378+
Value(Type vectorType) : vectorType(vectorType) {}
394379
Value(TupleType *tupleType) : tupleType(tupleType) {}
395-
Value(EnumDecl *enumDecl) : enumDecl(enumDecl) {}
396380
} value;
397381

398382
VectorSpace(Kind kind, Value value)
@@ -401,51 +385,37 @@ class VectorSpace {
401385
public:
402386
VectorSpace() = delete;
403387

404-
static VectorSpace
405-
getBuiltinRealScalarSpace(BuiltinFloatType *builtinFP) {
388+
static VectorSpace getBuiltinFloat(BuiltinFloatType *builtinFP) {
406389
return {Kind::BuiltinFloat, builtinFP};
407390
}
408-
static VectorSpace getRealVectorSpace(NominalTypeDecl *typeDecl) {
409-
return {Kind::Vector, typeDecl};
410-
}
411-
static VectorSpace getStruct(StructDecl *structDecl) {
412-
return {Kind::Struct, structDecl};
391+
static VectorSpace getVector(Type vectorType) {
392+
return {Kind::Vector, vectorType};
413393
}
414394
static VectorSpace getTuple(TupleType *tupleTy) {
415395
return {Kind::Tuple, tupleTy};
416396
}
417-
static VectorSpace getEnum(EnumDecl *enumDecl) {
418-
return {Kind::Enum, enumDecl};
419-
}
420397

421-
bool isBuiltinFloat() const {
422-
return kind == Kind::BuiltinFloat;
423-
}
398+
bool isBuiltinFloat() const { return kind == Kind::BuiltinFloat; }
424399
bool isVector() const { return kind == Kind::Vector; }
425-
bool isStruct() const { return kind == Kind::Struct; }
426400
bool isTuple() const { return kind == Kind::Tuple; }
427401

428402
Kind getKind() const { return kind; }
429403
BuiltinFloatType *getBuiltinFloat() const {
430404
assert(kind == Kind::BuiltinFloat);
431405
return value.builtinFPType;
432406
}
433-
NominalTypeDecl *getVector() const {
407+
Type getVector() const {
434408
assert(kind == Kind::Vector);
435-
return value.realNominalType;
436-
}
437-
StructDecl *getStruct() const {
438-
assert(kind == Kind::Struct);
439-
return value.structDecl;
409+
return value.vectorType;
440410
}
441411
TupleType *getTuple() const {
442412
assert(kind == Kind::Tuple);
443413
return value.tupleType;
444414
}
445-
EnumDecl *getEnum() const {
446-
assert(kind == Kind::Enum);
447-
return value.enumDecl;
448-
}
415+
416+
Type getType() const;
417+
CanType getCanonicalType() const;
418+
NominalTypeDecl *getNominal() const;
449419
};
450420

451421
} // end namespace swift

branches/tensorflow/include/swift/AST/DiagnosticsSIL.def

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -374,8 +374,6 @@ ERROR(autodiff_unsupported_type,none,
374374
"differentiating '%0' is not supported yet", (Type))
375375
ERROR(autodiff_function_not_differentiable,none,
376376
"function is not differentiable", ())
377-
ERROR(autodiff_property_not_differentiable,none,
378-
"property is not differentiable", ())
379377
NOTE(autodiff_function_generic_functions_unsupported,none,
380378
"differentiating generic functions is not supported yet", ())
381379
NOTE(autodiff_value_defined_here,none,

branches/tensorflow/include/swift/AST/TensorFlow.h

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020

2121
#include "swift/Basic/LLVM.h"
2222
#include "llvm/ADT/DenseMap.h"
23-
#include "llvm/ADT/SetVector.h"
2423

2524
namespace swift {
2625
class CanType;
@@ -93,12 +92,7 @@ namespace tf {
9392
bool containsTensorFlowValue(Type ty, bool checkHigherOrderFunctions);
9493

9594
private:
96-
bool containsTensorFlowValueImpl(
97-
Type ty, bool checkHigherOrderFunctions,
98-
llvm::SetVector<NominalTypeDecl *> &parentDecls);
99-
100-
bool structContainsTensorFlowValue(
101-
StructDecl *decl, llvm::SetVector<NominalTypeDecl *> &parentDecls);
95+
bool structContainsTensorFlowValue(StructDecl *decl);
10296
};
10397

10498
/// This class provides a single source of truth for the set of types that are

branches/tensorflow/include/swift/AST/Types.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1079,7 +1079,7 @@ class alignas(1 << TypeAlignInBits) TypeBase {
10791079
TypeTraitResult canBeClass();
10801080

10811081
// SWIFT_ENABLE_TENSORFLOW
1082-
/// Returns the associated tangent or cotangent type. Returns the null type if
1082+
/// Return the associated tangent or cotangent type. Return the null type if
10831083
/// there is no associated tangent/cotangent type.
10841084
///
10851085
/// `kind` specifies whether to return the tangent or cotangent type.
@@ -1090,8 +1090,9 @@ class alignas(1 << TypeAlignInBits) TypeBase {
10901090
/// associated tangent/cotangent type is the elementwise tangent/cotangent
10911091
/// type of its elements. If the type is a builtin float, then the associated
10921092
/// tangent/cotangent type is itself. Otherwise, there is no associated type.
1093-
Type getAutoDiffAssociatedType(AutoDiffAssociatedTypeKind kind,
1094-
LookupConformanceFn lookupConformance);
1093+
Optional<VectorSpace>
1094+
getAutoDiffAssociatedVectorSpace(AutoDiffAssociatedVectorSpaceKind kind,
1095+
LookupConformanceFn lookupConformance);
10951096

10961097
private:
10971098
// Make vanilla new/delete illegal for Types.

branches/tensorflow/lib/AST/ASTContext.cpp

Lines changed: 0 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -5195,79 +5195,6 @@ LayoutConstraint LayoutConstraint::getLayoutConstraint(LayoutConstraintKind Kind
51955195
}
51965196

51975197
// SWIFT_ENABLE_TENSORFLOW
5198-
Optional<VectorSpace> ASTContext::getTangentSpace(CanType type,
5199-
ModuleDecl *module) {
5200-
auto lookup = getImpl().VectorSpaces.find(type);
5201-
if (lookup != getImpl().VectorSpaces.end())
5202-
return lookup->getSecond();
5203-
// A helper that is used to cache the computed tangent space for the
5204-
// specified type and retuns the same tangent space.
5205-
auto cache = [&](Optional<VectorSpace> space) {
5206-
getImpl().VectorSpaces.insert({type, space});
5207-
return space;
5208-
};
5209-
// `Builtin.FP<...>` is a builtin real scalar space.
5210-
if (auto *fpType = type->getAs<BuiltinFloatType>())
5211-
return cache(VectorSpace::getBuiltinRealScalarSpace(fpType));
5212-
// Look up conformance to `Differentiable`.
5213-
auto *diffableProto = getProtocol(KnownProtocolKind::Differentiable);
5214-
if (module->lookupConformance(type, diffableProto).hasValue()) {
5215-
auto tangentType = type->getAutoDiffAssociatedType(
5216-
AutoDiffAssociatedTypeKind::TangentVector,
5217-
LookUpConformanceInModule(module));
5218-
assert(tangentType);
5219-
auto *nomTypeDecl = tangentType->getAnyNominal();
5220-
assert(nomTypeDecl);
5221-
return VectorSpace::getRealVectorSpace(nomTypeDecl);
5222-
}
5223-
// Nominal types can be either a struct or an enum.
5224-
if (auto *nominal = type->getAnyNominal()) {
5225-
// Fixed-layout struct types, each of whose elements has a tangent space,
5226-
// are a product of those tangent spaces.
5227-
if (auto *structDecl = dyn_cast<StructDecl>(nominal)) {
5228-
if (structDecl->getFormalAccess() >= AccessLevel::Public &&
5229-
!structDecl->getAttrs().hasAttribute<FixedLayoutAttr>())
5230-
return cache(None);
5231-
auto allMembersHaveTangentSpace =
5232-
llvm::all_of(structDecl->getStoredProperties(), [&](VarDecl *v) {
5233-
return (bool)getTangentSpace(v->getType()->getCanonicalType(),
5234-
module);
5235-
});
5236-
if (allMembersHaveTangentSpace)
5237-
return cache(VectorSpace::getStruct(structDecl));
5238-
}
5239-
// Frozen enum types, all of whose payloads have a tangent space, are a
5240-
// sum of the product of payloads in each case.
5241-
if (auto *enumDecl = dyn_cast<EnumDecl>(nominal)) {
5242-
if (enumDecl->getFormalAccess() >= AccessLevel::Public &&
5243-
!enumDecl->getAttrs().hasAttribute<FrozenAttr>())
5244-
return cache(None);
5245-
if (enumDecl->isIndirect())
5246-
return cache(None);
5247-
auto allMembersHaveTangentSpace =
5248-
llvm::all_of(enumDecl->getAllCases(), [&](EnumCaseDecl *cd) {
5249-
return llvm::all_of(cd->getElements(), [&](EnumElementDecl *eed) {
5250-
return llvm::all_of(*eed->getParameterList(), [&](ParamDecl *pd) {
5251-
return (bool)
5252-
getTangentSpace(pd->getType()->getCanonicalType(), module);
5253-
});
5254-
});
5255-
});
5256-
if (allMembersHaveTangentSpace)
5257-
return cache(VectorSpace::getEnum(enumDecl));
5258-
}
5259-
}
5260-
// Tuple types, each of whose elements has a tangent space, are a product of
5261-
// those tangent space.
5262-
if (TupleType *tupleType = type->getAs<TupleType>())
5263-
if (llvm::all_of(tupleType->getElementTypes(), [&](Type t) {
5264-
return (bool)getTangentSpace(t->getCanonicalType(), module); }))
5265-
return cache(VectorSpace::getTuple(tupleType));
5266-
// Otherwise, the type does not have a tangent space. That is, it does not
5267-
// support differentiation.
5268-
return cache(None);
5269-
}
5270-
52715198
AutoDiffParameterIndices *
52725199
AutoDiffParameterIndices::get(llvm::SmallBitVector indices, ASTContext &C) {
52735200
auto &foldingSet = C.getImpl().AutoDiffParameterIndicesSet;

branches/tensorflow/lib/AST/AutoDiff.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,3 +282,22 @@ void AutoDiffParameterIndicesBuilder::setParameter(unsigned paramIndex) {
282282
assert(paramIndex < parameters.size() && "paramIndex out of bounds");
283283
parameters.set(paramIndex);
284284
}
285+
286+
Type VectorSpace::getType() const {
287+
switch (kind) {
288+
case Kind::BuiltinFloat:
289+
return value.builtinFPType;
290+
case Kind::Vector:
291+
return value.vectorType;
292+
case Kind::Tuple:
293+
return value.tupleType;
294+
}
295+
}
296+
297+
CanType VectorSpace::getCanonicalType() const {
298+
return getType()->getCanonicalType();
299+
}
300+
301+
NominalTypeDecl *VectorSpace::getNominal() const {
302+
return getVector()->getNominalOrBoundGenericNominal();
303+
}

0 commit comments

Comments
 (0)