 import json
-import logging
 import copy
 import math
 import os
-import random
-import cv2
 import numpy as np
-import skimage
+import cv2
 import torch
 from tqdm import tqdm
+
 from skimage.measure import regionprops
-#from lib.handlers import TensorBoardImageHandler
-#from lib.transforms import FilterImaged
-#from lib.utils import split_dataset, split_nuclei_dataset
-from monai.config import KeysCollection
 from monai.engines import SupervisedTrainer, SupervisedEvaluator
 from monai.handlers import (
     CheckpointSaver,
-    EarlyStopHandler,
-    LrScheduleHandler,
     MeanDice,
     StatsHandler,
     TensorBoardImageHandler,
[...]
 from monai.losses import DiceLoss
 from monai.networks.nets import BasicUNet
 from monai.data import (
-    CacheDataset,
     Dataset,
     DataLoader,
 )
[...]
     EnsureTyped,
     LoadImaged,
     LoadImage,
-    MapTransform,
-    RandomizableTransform,
     RandRotate90d,
     ScaleIntensityRangeD,
     ToNumpyd,
     TorchVisiond,
     ToTensord,
-    Transform,
 )

-#from monai.apps.nuclick.dataset_prep import split_pannuke_dataset, split_nuclei_dataset
 from monai.apps.nuclick.transforms import (
     FlattenLabeld,
     ExtractPatchd,
[...]
     FilterImaged
 )

-#from monailabel.interfaces.datastore import Datastore
-#from monailabel.tasks.train.basic_train import BasicTrainTask, Context

 def split_pannuke_dataset(image, label, output_dir, groups):
     groups = groups if groups else dict()
@@ -81,15 +66,11 @@ def split_pannuke_dataset(image, label, output_dir, groups):

     print(f"++ Using Groups: {groups}")
     print(f"++ Using Label Channels: {label_channels}")
-    #logger.info(f"++ Using Groups: {groups}")
-    #logger.info(f"++ Using Label Channels: {label_channels}")

     images = np.load(image)
     labels = np.load(label)
     print(f"Image Shape: {images.shape}")
     print(f"Labels Shape: {labels.shape}")
-    #logger.info(f"Image Shape: {images.shape}")
-    #logger.info(f"Labels Shape: {labels.shape}")

     images_dir = output_dir
     labels_dir = os.path.join(output_dir, "labels", "final")
@@ -128,7 +109,6 @@ def split_nuclei_dataset(d, centroid_key="centroid", mask_value_key="mask_value"
     stats = regionprops(labels)
     for stat in stats:
         if stat.area < min_area:
-            #logger.debug(f"++++ Ignored label with smaller area => ( {stat.area} < {min_area})")
             print(f"++++ Ignored label with smaller area => ( {stat.area} < {min_area})")
             continue

@@ -140,7 +120,6 @@ def split_nuclei_dataset(d, centroid_key="centroid", mask_value_key="mask_value"
         item[centroid_key] = (x, y)
         item[mask_value_key] = stat.label

-        # logger.info(f"{d['label']} => {len(stats)} => {mask.shape} => {stat.label}")
         dataset_json.append(item)
     return dataset_json

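For reference, a minimal sketch (not part of the diff above) of the regionprops pattern that split_nuclei_dataset builds on: each labelled region becomes one candidate nucleus, regions under a minimum area are skipped, and the rest are recorded by centroid and label value. The toy mask, the min_area value of 80, and the dictionary keys below are illustrative assumptions; in the actual function the threshold and keys come from its arguments (min_area, centroid_key, mask_value_key).

import numpy as np
from skimage.measure import label, regionprops

# Toy instance mask: one 10x10 nucleus and one single-pixel speck.
mask = np.zeros((64, 64), dtype=np.uint8)
mask[5:15, 5:15] = 1
mask[30, 30] = 1

items = []
for stat in regionprops(label(mask)):
    if stat.area < 80:                # assumed min_area: the speck is ignored
        continue
    y, x = map(int, stat.centroid)    # regionprops centroids are (row, col)
    items.append({"centroid": (x, y), "mask_value": stat.label})

print(items)                          # [{'centroid': (9, 9), 'mask_value': 1}]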
@@ -274,12 +253,6 @@ def main():
     dice_loss = DiceLoss(sigmoid=True, squared_pred=True)

     # Training Process
-    #TODO Consider uisng the Supervised Trainer over here from MONAI
-    #network.train()
-    #TODO Refer here for how to fix up a validation when using a SupervisedTrainer. In short a supervisedevaluator needs to be created as a
-    # training handler
-    #TODO https://github.com/Project-MONAI/tutorials/blob/bc342633bd8e50be7b4a67b723006bb03285f6ba/acceleration/distributed_training/unet_training_workflows.py#L187
-
     val_handlers = [
         # use the logger "train_log" defined at the beginning of this program
         StatsHandler(name="train_log", output_transform=lambda x: None),
@@ -299,16 +272,12 @@ def main():
         inferer=val_inferer,
         postprocessing=train_post_transforms,
         key_val_metric=val_key_metric,
-        #additional_metrics={"val_acc": Accuracy(output_transform=from_engine(["pred", "label"]))},
         val_handlers=val_handlers,
         # if no FP16 support in GPU or PyTorch version < 1.6, will not enable AMP evaluation
         amp=False,
     )

     train_handlers = [
-        # apply “EarlyStop” logic based on the loss value, use “-” negative value because smaller loss is better
-        #EarlyStopHandler(trainer=None, patience=20, score_function=lambda x: -x.state.output[0]["loss"], epoch_level=False),
-        #LrScheduleHandler(lr_scheduler=lr_scheduler, print_lr=True),
         ValidationHandler(validator=evaluator, interval=1, epoch_level=True),
         # use the logger "train_log" defined at the beginning of this program
         StatsHandler(name="train_log",