Skip to content
This repository was archived by the owner on Apr 23, 2025. It is now read-only.

Commit acf4c34

Browse files
brettkooncedan-zheng
authored andcommitted
tweak gradient calls to match api change (#235)
1 parent b6593df commit acf4c34

File tree

4 files changed

+5
-5
lines changed

4 files changed

+5
-5
lines changed

Autoencoder/main.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ for epoch in 1...epochCount {
7171
for batch in trainingShuffled.batched(batchSize) {
7272
let x = batch.data
7373

74-
let 𝛁model = autoencoder.gradient { autoencoder -> Tensor<Float> in
74+
let 𝛁model = TensorFlow.gradient(at: autoencoder) { autoencoder -> Tensor<Float> in
7575
let image = autoencoder(x)
7676
return meanSquaredError(predicted: image, expected: x)
7777
}

Benchmarks/Models/ImageClassificationTraining.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ where
4141
sampleCount: dataset.trainingExampleCount, randomSeed: Int64(epoch))
4242
for batch in trainingShuffled.batched(batchSize) {
4343
let (labels, images) = (batch.label, batch.data)
44-
let 𝛁model = model.gradient { model -> Tensor<Float> in
44+
let 𝛁model = TensorFlow.gradient(at: model) { model -> Tensor<Float> in
4545
let logits = model(images)
4646
return softmaxCrossEntropy(logits: logits, labels: labels)
4747
}

Examples/LeNet-MNIST/main.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ for epoch in 1...epochCount {
5555
for batch in trainingShuffled.batched(batchSize) {
5656
let (labels, images) = (batch.label, batch.data)
5757
// Compute the gradient with respect to the model.
58-
let 𝛁model = classifier.gradient { classifier -> Tensor<Float> in
58+
let 𝛁model = TensorFlow.gradient(at: classifier) { classifier -> Tensor<Float> in
5959
let ŷ = classifier(images)
6060
let correctPredictions = ŷ.argmax(squeezingAxis: 1) .== labels
6161
trainStats.correctGuessCount += Int(

GAN/main.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ for epoch in 1...epochCount {
153153
// Update generator.
154154
let vec1 = sampleVector(size: batchSize)
155155

156-
let 𝛁generator = generator.gradient { generator -> Tensor<Float> in
156+
let 𝛁generator = TensorFlow.gradient(at: generator) { generator -> Tensor<Float> in
157157
let fakeImages = generator(vec1)
158158
let fakeLogits = discriminator(fakeImages)
159159
let loss = generatorLoss(fakeLogits: fakeLogits)
@@ -166,7 +166,7 @@ for epoch in 1...epochCount {
166166
let vec2 = sampleVector(size: batchSize)
167167
let fakeImages = generator(vec2)
168168

169-
let 𝛁discriminator = discriminator.gradient { discriminator -> Tensor<Float> in
169+
let 𝛁discriminator = TensorFlow.gradient(at: discriminator) { discriminator -> Tensor<Float> in
170170
let realLogits = discriminator(realImages)
171171
let fakeLogits = discriminator(fakeImages)
172172
let loss = discriminatorLoss(realLogits: realLogits, fakeLogits: fakeLogits)

0 commit comments

Comments
 (0)