Skip to content

Commit 9040108

Browse files
authored
[NFC] Gardening. (#25498)
- Remove unused SWIFT_ENABLE_TENSORFLOW known protocols. - Re-add AutoDiff test, previously accidentally removed. - Garden derived conformances.
1 parent 5ce378a commit 9040108

File tree

6 files changed

+77
-45
lines changed

6 files changed

+77
-45
lines changed

include/swift/AST/KnownProtocols.def

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,6 @@ PROTOCOL(Encodable)
7878
PROTOCOL(Decodable)
7979
// SWIFT_ENABLE_TENSORFLOW
8080
PROTOCOL(AdditiveArithmetic)
81-
PROTOCOL(Numeric)
82-
PROTOCOL(FloatingPoint)
8381
PROTOCOL(KeyPathIterable)
8482
PROTOCOL(TensorArrayProtocol)
8583
PROTOCOL(TensorGroup)

lib/IRGen/GenMeta.cpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4189,7 +4189,6 @@ SpecialProtocol irgen::getSpecialProtocolID(ProtocolDecl *P) {
41894189
case KnownProtocolKind::ExpressibleByColorLiteral:
41904190
case KnownProtocolKind::ExpressibleByImageLiteral:
41914191
case KnownProtocolKind::ExpressibleByFileReferenceLiteral:
4192-
// SWIFT_ENABLE_TENSORFLOW
41934192
case KnownProtocolKind::ExpressibleByBuiltinBooleanLiteral:
41944193
case KnownProtocolKind::ExpressibleByBuiltinExtendedGraphemeClusterLiteral:
41954194
case KnownProtocolKind::ExpressibleByBuiltinFloatLiteral:
@@ -4206,9 +4205,7 @@ SpecialProtocol irgen::getSpecialProtocolID(ProtocolDecl *P) {
42064205
case KnownProtocolKind::Decodable:
42074206
case KnownProtocolKind::StringInterpolationProtocol:
42084207
// SWIFT_ENABLE_TENSORFLOW
4209-
case KnownProtocolKind::FloatingPoint:
42104208
case KnownProtocolKind::AdditiveArithmetic:
4211-
case KnownProtocolKind::Numeric:
42124209
case KnownProtocolKind::KeyPathIterable:
42134210
case KnownProtocolKind::TensorArrayProtocol:
42144211
case KnownProtocolKind::TensorGroup:

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 5 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -865,18 +865,12 @@ class ADContext {
865865
/// Saved for deletion during cleanup.
866866
SmallVector<SILValue, 32> generatedAssociatedFunctionReferences;
867867

868+
/// The AdditiveArithmetic protocol in the standard library.
869+
ProtocolDecl *additiveArithmeticProtocol =
870+
astCtx.getProtocol(KnownProtocolKind::AdditiveArithmetic);
868871
/// The VectorProtocol protocol in the standard library.
869872
ProtocolDecl *vectorProtocolProtocol =
870873
astCtx.getProtocol(KnownProtocolKind::VectorProtocol);
871-
/// The Numeric protocol in the standard library.
872-
ProtocolDecl *numericProtocol =
873-
astCtx.getProtocol(KnownProtocolKind::Numeric);
874-
/// The AdditiveArithmetic protocol in the standard library.
875-
ProtocolDecl *additiveArithmeticProtocol =
876-
astCtx.getProtocol(KnownProtocolKind::AdditiveArithmetic);
877-
/// The FloatingPoint protocol in the stanard library.
878-
ProtocolDecl *floatingPointProtocol =
879-
astCtx.getProtocol(KnownProtocolKind::FloatingPoint);
880874

881875
/// `AdditiveArithmetic.+` declaration.
882876
mutable FuncDecl *cachedPlusFn = nullptr;
@@ -926,20 +920,12 @@ class ADContext {
926920
return generatedAssociatedFunctionReferences;
927921
}
928922

929-
ProtocolDecl *getVectorProtocolProtocol() const {
930-
return vectorProtocolProtocol;
931-
}
932-
933-
ProtocolDecl *getNumericProtocol() const {
934-
return numericProtocol;
935-
}
936-
937923
ProtocolDecl *getAdditiveArithmeticProtocol() const {
938924
return additiveArithmeticProtocol;
939925
}
940926

941-
ProtocolDecl *getFloatingPointProtocol() const {
942-
return floatingPointProtocol;
927+
ProtocolDecl *getVectorProtocolProtocol() const {
928+
return vectorProtocolProtocol;
943929
}
944930

945931
FuncDecl *getPlusDecl() const {

lib/Sema/DerivedConformanceAdditiveArithmetic.cpp

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -48,15 +48,15 @@ static StringRef getMathOperatorName(MathOperator op) {
4848
}
4949

5050
// Return the protocol requirement with the specified name.
51+
// TODO: Move function to shared place for use with other derived conformances.
5152
static ValueDecl *getProtocolRequirement(ProtocolDecl *proto, Identifier name) {
5253
auto lookup = proto->lookupDirect(name);
53-
lookup.erase(std::remove_if(lookup.begin(), lookup.end(),
54-
[](ValueDecl *v) {
55-
return !isa<ProtocolDecl>(
56-
v->getDeclContext()) ||
57-
!v->isProtocolRequirement();
58-
}),
59-
lookup.end());
54+
// Erase declarations that are not protocol requirements.
55+
// This is important for removing default implementations of the same name.
56+
llvm::erase_if(lookup, [](ValueDecl *v) {
57+
return !isa<ProtocolDecl>(v->getDeclContext()) ||
58+
!v->isProtocolRequirement();
59+
});
6060
assert(lookup.size() == 1 && "Ambiguous protocol requirement");
6161
return lookup.front();
6262
}
@@ -76,6 +76,7 @@ static ConstructorDecl *getOrCreateEffectiveMemberwiseInitializer(
7676
}
7777

7878
// Return true if given nominal type has a `let` stored with an initial value.
79+
// TODO: Move function to shared place for use with other derived conformances.
7980
static bool hasLetStoredPropertyWithInitialValue(NominalTypeDecl *nominal) {
8081
return llvm::any_of(nominal->getStoredProperties(), [&](VarDecl *v) {
8182
return v->isLet() && v->hasInitialValue();

lib/Sema/DerivedConformanceVectorProtocol.cpp

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -31,19 +31,27 @@
3131
using namespace swift;
3232

3333
// Return the protocol requirement with the specified name.
34+
// TODO: Move function to shared place for use with other derived conformances.
3435
static ValueDecl *getProtocolRequirement(ProtocolDecl *proto, Identifier name) {
3536
auto lookup = proto->lookupDirect(name);
36-
lookup.erase(std::remove_if(lookup.begin(), lookup.end(),
37-
[](ValueDecl *v) {
38-
return !isa<ProtocolDecl>(
39-
v->getDeclContext()) ||
40-
!v->isProtocolRequirement();
41-
}),
42-
lookup.end());
37+
// Erase declarations that are not protocol requirements.
38+
// This is important for removing default implementations of the same name.
39+
llvm::erase_if(lookup, [](ValueDecl *v) {
40+
return !isa<ProtocolDecl>(v->getDeclContext()) ||
41+
!v->isProtocolRequirement();
42+
});
4343
assert(lookup.size() == 1 && "Ambiguous protocol requirement");
4444
return lookup.front();
4545
}
4646

47+
// Return true if given nominal type has a `let` stored with an initial value.
48+
// TODO: Move function to shared place for use with other derived conformances.
49+
static bool hasLetStoredPropertyWithInitialValue(NominalTypeDecl *nominal) {
50+
return llvm::any_of(nominal->getStoredProperties(), [&](VarDecl *v) {
51+
return v->isLet() && v->hasInitialValue();
52+
});
53+
}
54+
4755
// Return the `VectorSpaceScalar` associated type for the given `ValueDecl` if
4856
// it conforms to `VectorProtocol` in the given context. Otherwise, return
4957
// `nullptr`.
@@ -97,13 +105,6 @@ static Type deriveVectorProtocol_VectorSpaceScalar(NominalTypeDecl *nominal,
97105
return sameScalarType;
98106
}
99107

100-
// Return true if given nominal type has a `let` stored with an initial value.
101-
static bool hasLetStoredPropertyWithInitialValue(NominalTypeDecl *nominal) {
102-
return llvm::any_of(nominal->getStoredProperties(), [&](VarDecl *v) {
103-
return v->isLet() && v->hasInitialValue();
104-
});
105-
}
106-
107108
bool DerivedConformance::canDeriveVectorProtocol(NominalTypeDecl *nominal,
108109
DeclContext *DC) {
109110
// Must not have any `let` stored properties with an initial value.
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
// RUN: %target-run-simple-swift
2+
// REQUIRES: executable_test
3+
4+
import StdlibUnittest
5+
#if os(macOS)
6+
import Darwin.C
7+
#else
8+
import Glibc
9+
#endif
10+
11+
var SeparateTangentTypeTests = TestSuite("SeparateTangentType")
12+
13+
struct DifferentiableSubset : Differentiable {
14+
@differentiable(wrt: self)
15+
var w: Float
16+
@differentiable(wrt: self)
17+
var b: Float
18+
@noDerivative var flag: Bool
19+
20+
struct TangentVector : Differentiable, VectorProtocol {
21+
typealias TangentVector = DifferentiableSubset.TangentVector
22+
var w: Float
23+
var b: Float
24+
}
25+
mutating func move(along v: TangentVector) {
26+
w.move(along: v.w)
27+
b.move(along: v.b)
28+
}
29+
}
30+
31+
SeparateTangentTypeTests.test("Trivial") {
32+
let x = DifferentiableSubset(w: 0, b: 1, flag: false)
33+
let pb = pullback(at: x) { x in x }
34+
expectEqual(pb(DifferentiableSubset.TangentVector.zero), DifferentiableSubset.TangentVector.zero)
35+
}
36+
37+
SeparateTangentTypeTests.test("Initialization") {
38+
let x = DifferentiableSubset(w: 0, b: 1, flag: false)
39+
let pb = pullback(at: x) { x in DifferentiableSubset(w: 1, b: 2, flag: true) }
40+
expectEqual(pb(DifferentiableSubset.TangentVector.zero), DifferentiableSubset.TangentVector.zero)
41+
}
42+
43+
SeparateTangentTypeTests.test("SomeArithmetics") {
44+
let x = DifferentiableSubset(w: 0, b: 1, flag: false)
45+
let pb = pullback(at: x) { x in DifferentiableSubset(w: x.w * x.w, b: x.b * x.b, flag: true) }
46+
expectEqual(pb(DifferentiableSubset.TangentVector.zero), DifferentiableSubset.TangentVector.zero)
47+
}
48+
49+
runAllTests()

0 commit comments

Comments
 (0)