Skip to content
This repository was archived by the owner on Apr 23, 2025. It is now read-only.

Commit 4275c36

Browse files
brettkooncedan-zheng
authored andcommitted
s/CotangentVector/TangentVector/g (#162)
1 parent f9687bf commit 4275c36

File tree

4 files changed

+12
-12
lines changed

4 files changed

+12
-12
lines changed

CIFAR/Helpers.swift

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,9 @@ extension Array where Element: Differentiable {
2525
func reduceDerivative<Result: Differentiable>(
2626
_ initialResult: Result,
2727
_ nextPartialResult: @differentiable (Result, Element) -> Result
28-
) -> (Result, (Result.CotangentVector) -> (Array.CotangentVector, Result.CotangentVector)) {
29-
var pullbacks: [(Result.CotangentVector)
30-
-> (Result.CotangentVector, Element.CotangentVector)] = []
28+
) -> (Result, (Result.TangentVector) -> (Array.TangentVector, Result.TangentVector)) {
29+
var pullbacks: [(Result.TangentVector)
30+
-> (Result.TangentVector, Element.TangentVector)] = []
3131
let count = self.count
3232
pullbacks.reserveCapacity(count)
3333
var result = initialResult
@@ -38,14 +38,14 @@ extension Array where Element: Differentiable {
3838
}
3939
return (value: result, pullback: { cotangent in
4040
var resultCotangent = cotangent
41-
var elementCotangents = CotangentVector([])
41+
var elementCotangents = TangentVector([])
4242
elementCotangents.base.reserveCapacity(count)
4343
for pullback in pullbacks.reversed() {
4444
let (newResultCotangent, elementCotangent) = pullback(resultCotangent)
4545
resultCotangent = newResultCotangent
4646
elementCotangents.base.append(elementCotangent)
4747
}
48-
return (CotangentVector(elementCotangents.base.reversed()), resultCotangent)
48+
return (TangentVector(elementCotangents.base.reversed()), resultCotangent)
4949
})
5050
}
5151
}

MiniGo/Models/GoModel.swift

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -183,12 +183,12 @@ public struct GoModel: Layer {
183183

184184
@usableFromInline
185185
func _vjpCall(_ input: Tensor<Float>)
186-
-> (GoModelOutput, (GoModelOutput.CotangentVector)
187-
-> (GoModel.CotangentVector, Tensor<Float>)) {
186+
-> (GoModelOutput, (GoModelOutput.TangentVector)
187+
-> (GoModel.TangentVector, Tensor<Float>)) {
188188
// TODO(jekbradbury): add a real VJP
189189
// (we're only interested in inference for now and have control flow in our `call(_:)` method)
190190
return (self(input), {
191-
seed in (GoModel.CotangentVector.zero, Tensor<Float>(0))
191+
seed in (GoModel.TangentVector.zero, Tensor<Float>(0))
192192
})
193193
}
194194
}

MiniGo/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ gsutil cp 'gs://minigo-pub/v15-19x19/models/000939-heron.*' MiniGoCheckpoint/
4040
```sh
4141
# Run inference (self-plays).
4242
cd swift-models
43-
swift run -Xlinker -ltensorflow -c release MiniGo
43+
swift run -Xlinker -ltensorflow -c release MiniGoDemo
4444
```
4545

4646
[Swift for TensorFlow]: https://www.tensorflow.org/swift

Transformer/Model.swift

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ func makeAttentionInput(query: Tensor<Float>, key: Tensor<Float>, value: Tensor<
6363
}
6464

6565
func _vjpMakeAttentionInput(query: Tensor<Float>, key: Tensor<Float>, value: Tensor<Float>)
66-
-> (AttentionInput, (AttentionInput.CotangentVector) -> (Tensor<Float>, Tensor<Float>, Tensor<Float>)) {
66+
-> (AttentionInput, (AttentionInput.TangentVector) -> (Tensor<Float>, Tensor<Float>, Tensor<Float>)) {
6767
let result = AttentionInput(query: query, key: key, value: value)
6868
return (result, { seed in (seed.query, seed.key, seed.value) })
6969
}
@@ -80,7 +80,7 @@ func makeAttentionContext(key: Tensor<Float>, value: Tensor<Float>)
8080
}
8181

8282
func _vjpMakeAttentionContext(key: Tensor<Float>, value: Tensor<Float>)
83-
-> (AttentionContext, (AttentionContext.CotangentVector) -> (Tensor<Float>, Tensor<Float>)) {
83+
-> (AttentionContext, (AttentionContext.TangentVector) -> (Tensor<Float>, Tensor<Float>)) {
8484
let result = AttentionContext(key: key, value: value)
8585
return (result, { seed in (seed.key, seed.value) })
8686
}
@@ -170,7 +170,7 @@ func splitQKV(_ input: Tensor<Float>) -> AttentionInput {
170170
}
171171

172172
func _vjpSplitQKV(_ input: Tensor<Float>)
173-
-> (AttentionInput, (AttentionInput.CotangentVector) -> Tensor<Float>) {
173+
-> (AttentionInput, (AttentionInput.TangentVector) -> Tensor<Float>) {
174174
let value = splitQKV(input)
175175
return (value, { seed in
176176
return Raw.concatV2([seed.query, seed.key, seed.value], axis: Tensor<Int32>(2))

0 commit comments

Comments
 (0)