Skip to content

Commit 4120b1c

Browse files
authored
Fix the DCGAN C++ shape warning
1 parent 4a0ae38 commit 4120b1c

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

advanced_source/cpp_frontend.rst

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -969,15 +969,15 @@ the data loader every epoch and then write the GAN training code:
969969
discriminator->zero_grad();
970970
torch::Tensor real_images = batch.data;
971971
torch::Tensor real_labels = torch::empty(batch.data.size(0)).uniform_(0.8, 1.0);
972-
torch::Tensor real_output = discriminator->forward(real_images);
972+
torch::Tensor real_output = discriminator->forward(real_images).reshape(real_labels.sizes());
973973
torch::Tensor d_loss_real = torch::binary_cross_entropy(real_output, real_labels);
974974
d_loss_real.backward();
975975
976976
// Train discriminator with fake images.
977977
torch::Tensor noise = torch::randn({batch.data.size(0), kNoiseSize, 1, 1});
978978
torch::Tensor fake_images = generator->forward(noise);
979979
torch::Tensor fake_labels = torch::zeros(batch.data.size(0));
980-
torch::Tensor fake_output = discriminator->forward(fake_images.detach());
980+
torch::Tensor fake_output = discriminator->forward(fake_images.detach()).reshape(fake_labels.sizes());
981981
torch::Tensor d_loss_fake = torch::binary_cross_entropy(fake_output, fake_labels);
982982
d_loss_fake.backward();
983983
@@ -987,7 +987,7 @@ the data loader every epoch and then write the GAN training code:
987987
// Train generator.
988988
generator->zero_grad();
989989
fake_labels.fill_(1);
990-
fake_output = discriminator->forward(fake_images);
990+
fake_output = discriminator->forward(fake_images).reshape(fake_labels.sizes());
991991
torch::Tensor g_loss = torch::binary_cross_entropy(fake_output, fake_labels);
992992
g_loss.backward();
993993
generator_optimizer.step();

0 commit comments

Comments
 (0)