Skip to content

Commit 9945357

Browse files
ericspodpre-commit-ci[bot]KumoLiu
authored andcommitted
Updates to GAN script examples (Project-MONAI#1727)
### Description This updates the two example scripts which used old code MONAI no long has. ### Checks <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [ ] Avoid including large-size files in the PR. - [ ] Clean up long text outputs from code cells in the notebook. - [ ] For security purposes, please check the contents and remove any sensitive info such as user names and private key. - [ ] Ensure (1) hyperlinks and markdown anchors are working (2) use relative paths for tutorial repo files (3) put figure and graphs in the `./figure` folder - [ ] Notebook runs automatically `./runner.sh -t <path to .ipynb file>` --------- Signed-off-by: Eric Kerfoot <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: YunLiu <[email protected]>
1 parent 96302a4 commit 9945357

File tree

2 files changed

+16
-8
lines changed

2 files changed

+16
-8
lines changed

modules/engines/gan_evaluation.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,18 +23,21 @@
2323
import torch
2424

2525
import monai
26-
from monai.data import png_writer
2726
from monai.engines.utils import default_make_latent as make_latent
2827
from monai.networks.nets import Generator
2928
from monai.utils.misc import set_determinism
29+
from monai.data.image_writer import PILWriter
3030

3131

3232
def save_generator_fakes(run_folder, g_output_tensor):
33+
writer_obj = PILWriter(output_dtype=np.uint8)
34+
3335
for i, image in enumerate(g_output_tensor):
34-
filename = "gen-fake-%d.png" % i
36+
filename = f"gen-fake-{i}.png"
3537
save_path = os.path.join(run_folder, filename)
36-
img_array = image[0].cpu().data.numpy()
37-
png_writer.write_png(img_array, save_path, scale=255)
38+
img_array = monai.transforms.utils.rescale_array(image[0].cpu().data.numpy())
39+
writer_obj.set_data_array(img_array, channel_dim=None)
40+
writer_obj.write(save_path, format="PNG")
3841

3942

4043
def main():

modules/engines/gan_training.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
import numpy as np
3030
import monai
3131
from monai.apps.utils import download_and_extract
32-
from monai.data import CacheDataset, DataLoader, png_writer
32+
from monai.data import CacheDataset, DataLoader
3333
from monai.engines import GanTrainer
3434
from monai.engines.utils import GanKeys as Keys
3535
from monai.engines.utils import default_make_latent as make_latent
@@ -47,6 +47,7 @@
4747
EnsureTypeD,
4848
)
4949
from monai.utils.misc import set_determinism
50+
from monai.data.image_writer import PILWriter
5051

5152

5253
def main():
@@ -193,11 +194,15 @@ def generator_loss(gen_images):
193194
test_img_count = 10
194195
test_latents = make_latent(test_img_count, latent_size).to(device)
195196
fakes = gen_net(test_latents)
197+
198+
writer_obj = PILWriter(output_dtype=np.uint8)
199+
196200
for i, image in enumerate(fakes):
197-
filename = "gen-fake-final-%d.png" % i
201+
filename = f"gen-fake-final-{i}.png"
198202
save_path = os.path.join(run_dir, filename)
199-
img_array = image[0].cpu().data.numpy()
200-
png_writer.write_png(img_array, save_path, scale=255)
203+
img_array = monai.transforms.utils.rescale_array(image[0].cpu().data.numpy())
204+
writer_obj.set_data_array(img_array, channel_dim=None)
205+
writer_obj.write(save_path, format="PNG")
201206

202207

203208
if __name__ == "__main__":

0 commit comments

Comments
 (0)