Skip to content

Gardening. #25498

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 15, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions include/swift/AST/KnownProtocols.def
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,6 @@ PROTOCOL(Encodable)
PROTOCOL(Decodable)
// SWIFT_ENABLE_TENSORFLOW
PROTOCOL(AdditiveArithmetic)
PROTOCOL(Numeric)
PROTOCOL(FloatingPoint)
PROTOCOL(KeyPathIterable)
PROTOCOL(TensorArrayProtocol)
PROTOCOL(TensorGroup)
Expand Down
3 changes: 0 additions & 3 deletions lib/IRGen/GenMeta.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4189,7 +4189,6 @@ SpecialProtocol irgen::getSpecialProtocolID(ProtocolDecl *P) {
case KnownProtocolKind::ExpressibleByColorLiteral:
case KnownProtocolKind::ExpressibleByImageLiteral:
case KnownProtocolKind::ExpressibleByFileReferenceLiteral:
// SWIFT_ENABLE_TENSORFLOW
case KnownProtocolKind::ExpressibleByBuiltinBooleanLiteral:
case KnownProtocolKind::ExpressibleByBuiltinExtendedGraphemeClusterLiteral:
case KnownProtocolKind::ExpressibleByBuiltinFloatLiteral:
Expand All @@ -4206,9 +4205,7 @@ SpecialProtocol irgen::getSpecialProtocolID(ProtocolDecl *P) {
case KnownProtocolKind::Decodable:
case KnownProtocolKind::StringInterpolationProtocol:
// SWIFT_ENABLE_TENSORFLOW
case KnownProtocolKind::FloatingPoint:
case KnownProtocolKind::AdditiveArithmetic:
case KnownProtocolKind::Numeric:
case KnownProtocolKind::KeyPathIterable:
case KnownProtocolKind::TensorArrayProtocol:
case KnownProtocolKind::TensorGroup:
Expand Down
24 changes: 5 additions & 19 deletions lib/SILOptimizer/Mandatory/Differentiation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -865,18 +865,12 @@ class ADContext {
/// Saved for deletion during cleanup.
SmallVector<SILValue, 32> generatedAssociatedFunctionReferences;

/// The AdditiveArithmetic protocol in the standard library.
ProtocolDecl *additiveArithmeticProtocol =
astCtx.getProtocol(KnownProtocolKind::AdditiveArithmetic);
/// The VectorProtocol protocol in the standard library.
ProtocolDecl *vectorProtocolProtocol =
astCtx.getProtocol(KnownProtocolKind::VectorProtocol);
/// The Numeric protocol in the standard library.
ProtocolDecl *numericProtocol =
astCtx.getProtocol(KnownProtocolKind::Numeric);
/// The AdditiveArithmetic protocol in the standard library.
ProtocolDecl *additiveArithmeticProtocol =
astCtx.getProtocol(KnownProtocolKind::AdditiveArithmetic);
/// The FloatingPoint protocol in the stanard library.
ProtocolDecl *floatingPointProtocol =
astCtx.getProtocol(KnownProtocolKind::FloatingPoint);

/// `AdditiveArithmetic.+` declaration.
mutable FuncDecl *cachedPlusFn = nullptr;
Expand Down Expand Up @@ -926,20 +920,12 @@ class ADContext {
return generatedAssociatedFunctionReferences;
}

ProtocolDecl *getVectorProtocolProtocol() const {
return vectorProtocolProtocol;
}

ProtocolDecl *getNumericProtocol() const {
return numericProtocol;
}

ProtocolDecl *getAdditiveArithmeticProtocol() const {
return additiveArithmeticProtocol;
}

ProtocolDecl *getFloatingPointProtocol() const {
return floatingPointProtocol;
ProtocolDecl *getVectorProtocolProtocol() const {
return vectorProtocolProtocol;
}

FuncDecl *getPlusDecl() const {
Expand Down
15 changes: 8 additions & 7 deletions lib/Sema/DerivedConformanceAdditiveArithmetic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,20 +48,21 @@ static StringRef getMathOperatorName(MathOperator op) {
}

// Return the protocol requirement with the specified name.
// TODO: Move function to shared place for use with other derived conformances.
static ValueDecl *getProtocolRequirement(ProtocolDecl *proto, Identifier name) {
auto lookup = proto->lookupDirect(name);
lookup.erase(std::remove_if(lookup.begin(), lookup.end(),
[](ValueDecl *v) {
return !isa<ProtocolDecl>(
v->getDeclContext()) ||
!v->isProtocolRequirement();
}),
lookup.end());
// Erase declarations that are not protocol requirements.
// This is important for removing default implementations of the same name.
llvm::erase_if(lookup, [](ValueDecl *v) {
return !isa<ProtocolDecl>(v->getDeclContext()) ||
!v->isProtocolRequirement();
});
assert(lookup.size() == 1 && "Ambiguous protocol requirement");
return lookup.front();
}

// Return true if given nominal type has a `let` stored with an initial value.
// TODO: Move function to shared place for use with other derived conformances.
static bool hasLetStoredPropertyWithInitialValue(NominalTypeDecl *nominal) {
return llvm::any_of(nominal->getStoredProperties(), [&](VarDecl *v) {
return v->isLet() && v->hasInitialValue();
Expand Down
29 changes: 15 additions & 14 deletions lib/Sema/DerivedConformanceVectorProtocol.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,19 +31,27 @@
using namespace swift;

// Return the protocol requirement with the specified name.
// TODO: Move function to shared place for use with other derived conformances.
static ValueDecl *getProtocolRequirement(ProtocolDecl *proto, Identifier name) {
auto lookup = proto->lookupDirect(name);
lookup.erase(std::remove_if(lookup.begin(), lookup.end(),
[](ValueDecl *v) {
return !isa<ProtocolDecl>(
v->getDeclContext()) ||
!v->isProtocolRequirement();
}),
lookup.end());
// Erase declarations that are not protocol requirements.
// This is important for removing default implementations of the same name.
llvm::erase_if(lookup, [](ValueDecl *v) {
return !isa<ProtocolDecl>(v->getDeclContext()) ||
!v->isProtocolRequirement();
});
assert(lookup.size() == 1 && "Ambiguous protocol requirement");
return lookup.front();
}

// Return true if given nominal type has a `let` stored with an initial value.
// TODO: Move function to shared place for use with other derived conformances.
static bool hasLetStoredPropertyWithInitialValue(NominalTypeDecl *nominal) {
return llvm::any_of(nominal->getStoredProperties(), [&](VarDecl *v) {
return v->isLet() && v->hasInitialValue();
});
}

// Return the `VectorSpaceScalar` associated type for the given `ValueDecl` if
// it conforms to `VectorProtocol` in the given context. Otherwise, return
// `nullptr`.
Expand Down Expand Up @@ -97,13 +105,6 @@ static Type deriveVectorProtocol_VectorSpaceScalar(NominalTypeDecl *nominal,
return sameScalarType;
}

// Return true if given nominal type has a `let` stored with an initial value.
static bool hasLetStoredPropertyWithInitialValue(NominalTypeDecl *nominal) {
return llvm::any_of(nominal->getStoredProperties(), [&](VarDecl *v) {
return v->isLet() && v->hasInitialValue();
});
}

bool DerivedConformance::canDeriveVectorProtocol(NominalTypeDecl *nominal,
DeclContext *DC) {
// Must not have any `let` stored properties with an initial value.
Expand Down
49 changes: 49 additions & 0 deletions test/AutoDiff/separate_tangent_type.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
// RUN: %target-run-simple-swift
// REQUIRES: executable_test

import StdlibUnittest
#if os(macOS)
import Darwin.C
#else
import Glibc
#endif

var SeparateTangentTypeTests = TestSuite("SeparateTangentType")

struct DifferentiableSubset : Differentiable {
@differentiable(wrt: self)
var w: Float
@differentiable(wrt: self)
var b: Float
@noDerivative var flag: Bool

struct TangentVector : Differentiable, VectorProtocol {
typealias TangentVector = DifferentiableSubset.TangentVector
var w: Float
var b: Float
}
mutating func move(along v: TangentVector) {
w.move(along: v.w)
b.move(along: v.b)
}
}

SeparateTangentTypeTests.test("Trivial") {
let x = DifferentiableSubset(w: 0, b: 1, flag: false)
let pb = pullback(at: x) { x in x }
expectEqual(pb(DifferentiableSubset.TangentVector.zero), DifferentiableSubset.TangentVector.zero)
}

SeparateTangentTypeTests.test("Initialization") {
let x = DifferentiableSubset(w: 0, b: 1, flag: false)
let pb = pullback(at: x) { x in DifferentiableSubset(w: 1, b: 2, flag: true) }
expectEqual(pb(DifferentiableSubset.TangentVector.zero), DifferentiableSubset.TangentVector.zero)
}

SeparateTangentTypeTests.test("SomeArithmetics") {
let x = DifferentiableSubset(w: 0, b: 1, flag: false)
let pb = pullback(at: x) { x in DifferentiableSubset(w: x.w * x.w, b: x.b * x.b, flag: true) }
expectEqual(pb(DifferentiableSubset.TangentVector.zero), DifferentiableSubset.TangentVector.zero)
}

runAllTests()