Skip to content

Commit 1c6f610

Browse files
committed
Gardening.
- Remove unused SWIFT_ENABLE_TENSORFLOW known protocols. - Readd AutoDiff test, previously accidentally removed.
1 parent c1211a3 commit 1c6f610

File tree

4 files changed

+54
-24
lines changed

4 files changed

+54
-24
lines changed

include/swift/AST/KnownProtocols.def

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,6 @@ PROTOCOL(Encodable)
7878
PROTOCOL(Decodable)
7979
// SWIFT_ENABLE_TENSORFLOW
8080
PROTOCOL(AdditiveArithmetic)
81-
PROTOCOL(Numeric)
82-
PROTOCOL(FloatingPoint)
8381
PROTOCOL(KeyPathIterable)
8482
PROTOCOL(TensorArrayProtocol)
8583
PROTOCOL(TensorGroup)

lib/IRGen/GenMeta.cpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4189,7 +4189,6 @@ SpecialProtocol irgen::getSpecialProtocolID(ProtocolDecl *P) {
41894189
case KnownProtocolKind::ExpressibleByColorLiteral:
41904190
case KnownProtocolKind::ExpressibleByImageLiteral:
41914191
case KnownProtocolKind::ExpressibleByFileReferenceLiteral:
4192-
// SWIFT_ENABLE_TENSORFLOW
41934192
case KnownProtocolKind::ExpressibleByBuiltinBooleanLiteral:
41944193
case KnownProtocolKind::ExpressibleByBuiltinExtendedGraphemeClusterLiteral:
41954194
case KnownProtocolKind::ExpressibleByBuiltinFloatLiteral:
@@ -4206,9 +4205,7 @@ SpecialProtocol irgen::getSpecialProtocolID(ProtocolDecl *P) {
42064205
case KnownProtocolKind::Decodable:
42074206
case KnownProtocolKind::StringInterpolationProtocol:
42084207
// SWIFT_ENABLE_TENSORFLOW
4209-
case KnownProtocolKind::FloatingPoint:
42104208
case KnownProtocolKind::AdditiveArithmetic:
4211-
case KnownProtocolKind::Numeric:
42124209
case KnownProtocolKind::KeyPathIterable:
42134210
case KnownProtocolKind::TensorArrayProtocol:
42144211
case KnownProtocolKind::TensorGroup:

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 5 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -865,18 +865,12 @@ class ADContext {
865865
/// Saved for deletion during cleanup.
866866
SmallVector<SILValue, 32> generatedAssociatedFunctionReferences;
867867

868+
/// The AdditiveArithmetic protocol in the standard library.
869+
ProtocolDecl *additiveArithmeticProtocol =
870+
astCtx.getProtocol(KnownProtocolKind::AdditiveArithmetic);
868871
/// The VectorProtocol protocol in the standard library.
869872
ProtocolDecl *vectorProtocolProtocol =
870873
astCtx.getProtocol(KnownProtocolKind::VectorProtocol);
871-
/// The Numeric protocol in the standard library.
872-
ProtocolDecl *numericProtocol =
873-
astCtx.getProtocol(KnownProtocolKind::Numeric);
874-
/// The AdditiveArithmetic protocol in the standard library.
875-
ProtocolDecl *additiveArithmeticProtocol =
876-
astCtx.getProtocol(KnownProtocolKind::AdditiveArithmetic);
877-
/// The FloatingPoint protocol in the stanard library.
878-
ProtocolDecl *floatingPointProtocol =
879-
astCtx.getProtocol(KnownProtocolKind::FloatingPoint);
880874

881875
/// `AdditiveArithmetic.+` declaration.
882876
mutable FuncDecl *cachedPlusFn = nullptr;
@@ -926,20 +920,12 @@ class ADContext {
926920
return generatedAssociatedFunctionReferences;
927921
}
928922

929-
ProtocolDecl *getVectorProtocolProtocol() const {
930-
return vectorProtocolProtocol;
931-
}
932-
933-
ProtocolDecl *getNumericProtocol() const {
934-
return numericProtocol;
935-
}
936-
937923
ProtocolDecl *getAdditiveArithmeticProtocol() const {
938924
return additiveArithmeticProtocol;
939925
}
940926

941-
ProtocolDecl *getFloatingPointProtocol() const {
942-
return floatingPointProtocol;
927+
ProtocolDecl *getVectorProtocolProtocol() const {
928+
return vectorProtocolProtocol;
943929
}
944930

945931
FuncDecl *getPlusDecl() const {
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
// RUN: %target-run-simple-swift
2+
// REQUIRES: executable_test
3+
4+
import StdlibUnittest
5+
#if os(macOS)
6+
import Darwin.C
7+
#else
8+
import Glibc
9+
#endif
10+
11+
var SeparateTangentTypeTests = TestSuite("SeparateTangentType")
12+
13+
struct DifferentiableSubset : Differentiable {
14+
@differentiable(wrt: self)
15+
var w: Float
16+
@differentiable(wrt: self)
17+
var b: Float
18+
@noDerivative var flag: Bool
19+
20+
struct TangentVector : Differentiable, VectorProtocol {
21+
typealias TangentVector = DifferentiableSubset.TangentVector
22+
var w: Float
23+
var b: Float
24+
}
25+
mutating func move(along v: TangentVector) {
26+
w.move(along: v.w)
27+
b.move(along: v.b)
28+
}
29+
}
30+
31+
SeparateTangentTypeTests.test("Trivial") {
32+
let x = DifferentiableSubset(w: 0, b: 1, flag: false)
33+
let pb = pullback(at: x) { x in x }
34+
expectEqual(pb(DifferentiableSubset.TangentVector.zero), DifferentiableSubset.TangentVector.zero)
35+
}
36+
37+
SeparateTangentTypeTests.test("Initialization") {
38+
let x = DifferentiableSubset(w: 0, b: 1, flag: false)
39+
let pb = pullback(at: x) { x in DifferentiableSubset(w: 1, b: 2, flag: true) }
40+
expectEqual(pb(DifferentiableSubset.TangentVector.zero), DifferentiableSubset.TangentVector.zero)
41+
}
42+
43+
SeparateTangentTypeTests.test("SomeArithmetics") {
44+
let x = DifferentiableSubset(w: 0, b: 1, flag: false)
45+
let pb = pullback(at: x) { x in DifferentiableSubset(w: x.w * x.w, b: x.b * x.b, flag: true) }
46+
expectEqual(pb(DifferentiableSubset.TangentVector.zero), DifferentiableSubset.TangentVector.zero)
47+
}
48+
49+
runAllTests()

0 commit comments

Comments
 (0)