|
29 | 29 | import numpy as np
|
30 | 30 | import monai
|
31 | 31 | from monai.apps.utils import download_and_extract
|
32 |
| -from monai.data import CacheDataset, DataLoader, png_writer |
| 32 | +from monai.data import CacheDataset, DataLoader |
33 | 33 | from monai.engines import GanTrainer
|
34 | 34 | from monai.engines.utils import GanKeys as Keys
|
35 | 35 | from monai.engines.utils import default_make_latent as make_latent
|
|
47 | 47 | EnsureTypeD,
|
48 | 48 | )
|
49 | 49 | from monai.utils.misc import set_determinism
|
| 50 | +from monai.data.image_writer import PILWriter |
50 | 51 |
|
51 | 52 |
|
52 | 53 | def main():
|
@@ -193,11 +194,15 @@ def generator_loss(gen_images):
|
193 | 194 | test_img_count = 10
|
194 | 195 | test_latents = make_latent(test_img_count, latent_size).to(device)
|
195 | 196 | fakes = gen_net(test_latents)
|
| 197 | + |
| 198 | + writer_obj = PILWriter(output_dtype=np.uint8) |
| 199 | + |
196 | 200 | for i, image in enumerate(fakes):
|
197 |
| - filename = "gen-fake-final-%d.png" % i |
| 201 | + filename = f"gen-fake-final-{i}.png" |
198 | 202 | save_path = os.path.join(run_dir, filename)
|
199 |
| - img_array = image[0].cpu().data.numpy() |
200 |
| - png_writer.write_png(img_array, save_path, scale=255) |
| 203 | + img_array = monai.transforms.utils.rescale_array(image[0].cpu().data.numpy()) |
| 204 | + writer_obj.set_data_array(img_array, channel_dim=None) |
| 205 | + writer_obj.write(save_path, format="PNG") |
201 | 206 |
|
202 | 207 |
|
203 | 208 | if __name__ == "__main__":
|
|
0 commit comments