Skip to content
This repository was archived by the owner on Apr 23, 2025. It is now read-only.

tweak gradient calls to match api change #235

Merged
merged 1 commit into from
Dec 4, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Autoencoder/main.swift
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ for epoch in 1...epochCount {
for batch in trainingShuffled.batched(batchSize) {
let x = batch.data

let 𝛁model = autoencoder.gradient { autoencoder -> Tensor<Float> in
let 𝛁model = TensorFlow.gradient(at: autoencoder) { autoencoder -> Tensor<Float> in
let image = autoencoder(x)
return meanSquaredError(predicted: image, expected: x)
}
Expand Down
2 changes: 1 addition & 1 deletion Benchmarks/Models/ImageClassificationTraining.swift
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ where
sampleCount: dataset.trainingExampleCount, randomSeed: Int64(epoch))
for batch in trainingShuffled.batched(batchSize) {
let (labels, images) = (batch.label, batch.data)
let 𝛁model = model.gradient { model -> Tensor<Float> in
let 𝛁model = TensorFlow.gradient(at: model) { model -> Tensor<Float> in
let logits = model(images)
return softmaxCrossEntropy(logits: logits, labels: labels)
}
Expand Down
2 changes: 1 addition & 1 deletion Examples/LeNet-MNIST/main.swift
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ for epoch in 1...epochCount {
for batch in trainingShuffled.batched(batchSize) {
let (labels, images) = (batch.label, batch.data)
// Compute the gradient with respect to the model.
let 𝛁model = classifier.gradient { classifier -> Tensor<Float> in
let 𝛁model = TensorFlow.gradient(at: classifier) { classifier -> Tensor<Float> in
let ŷ = classifier(images)
let correctPredictions = ŷ.argmax(squeezingAxis: 1) .== labels
trainStats.correctGuessCount += Int(
Expand Down
4 changes: 2 additions & 2 deletions GAN/main.swift
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ for epoch in 1...epochCount {
// Update generator.
let vec1 = sampleVector(size: batchSize)

let 𝛁generator = generator.gradient { generator -> Tensor<Float> in
let 𝛁generator = TensorFlow.gradient(at: generator) { generator -> Tensor<Float> in
let fakeImages = generator(vec1)
let fakeLogits = discriminator(fakeImages)
let loss = generatorLoss(fakeLogits: fakeLogits)
Expand All @@ -166,7 +166,7 @@ for epoch in 1...epochCount {
let vec2 = sampleVector(size: batchSize)
let fakeImages = generator(vec2)

let 𝛁discriminator = discriminator.gradient { discriminator -> Tensor<Float> in
let 𝛁discriminator = TensorFlow.gradient(at: discriminator) { discriminator -> Tensor<Float> in
let realLogits = discriminator(realImages)
let fakeLogits = discriminator(fakeImages)
let loss = discriminatorLoss(realLogits: realLogits, fakeLogits: fakeLogits)
Expand Down