Skip to content

Commit e2b429b

Browse files
committed
Training Jupyter Notebook added
Signed-off-by: vnath <[email protected]>
1 parent c38fe18 commit e2b429b

File tree

2 files changed

+1448
-33
lines changed

2 files changed

+1448
-33
lines changed

nuclick/nuclick_training.py

Lines changed: 2 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,16 @@
11
import json
2-
import logging
32
import copy
43
import math
54
import os
6-
import random
7-
import cv2
85
import numpy as np
9-
import skimage
6+
import cv2
107
import torch
118
from tqdm import tqdm
9+
1210
from skimage.measure import regionprops
13-
#from lib.handlers import TensorBoardImageHandler
14-
#from lib.transforms import FilterImaged
15-
#from lib.utils import split_dataset, split_nuclei_dataset
16-
from monai.config import KeysCollection
1711
from monai.engines import SupervisedTrainer, SupervisedEvaluator
1812
from monai.handlers import (
1913
CheckpointSaver,
20-
EarlyStopHandler,
21-
LrScheduleHandler,
2214
MeanDice,
2315
StatsHandler,
2416
TensorBoardImageHandler,
@@ -30,7 +22,6 @@
3022
from monai.losses import DiceLoss
3123
from monai.networks.nets import BasicUNet
3224
from monai.data import (
33-
CacheDataset,
3425
Dataset,
3526
DataLoader,
3627
)
@@ -43,17 +34,13 @@
4334
EnsureTyped,
4435
LoadImaged,
4536
LoadImage,
46-
MapTransform,
47-
RandomizableTransform,
4837
RandRotate90d,
4938
ScaleIntensityRangeD,
5039
ToNumpyd,
5140
TorchVisiond,
5241
ToTensord,
53-
Transform,
5442
)
5543

56-
#from monai.apps.nuclick.dataset_prep import split_pannuke_dataset, split_nuclei_dataset
5744
from monai.apps.nuclick.transforms import (
5845
FlattenLabeld,
5946
ExtractPatchd,
@@ -62,8 +49,6 @@
6249
FilterImaged
6350
)
6451

65-
#from monailabel.interfaces.datastore import Datastore
66-
#from monailabel.tasks.train.basic_train import BasicTrainTask, Context
6752

6853
def split_pannuke_dataset(image, label, output_dir, groups):
6954
groups = groups if groups else dict()
@@ -81,15 +66,11 @@ def split_pannuke_dataset(image, label, output_dir, groups):
8166

8267
print(f"++ Using Groups: {groups}")
8368
print(f"++ Using Label Channels: {label_channels}")
84-
#logger.info(f"++ Using Groups: {groups}")
85-
#logger.info(f"++ Using Label Channels: {label_channels}")
8669

8770
images = np.load(image)
8871
labels = np.load(label)
8972
print(f"Image Shape: {images.shape}")
9073
print(f"Labels Shape: {labels.shape}")
91-
#logger.info(f"Image Shape: {images.shape}")
92-
#logger.info(f"Labels Shape: {labels.shape}")
9374

9475
images_dir = output_dir
9576
labels_dir = os.path.join(output_dir, "labels", "final")
@@ -128,7 +109,6 @@ def split_nuclei_dataset(d, centroid_key="centroid", mask_value_key="mask_value"
128109
stats = regionprops(labels)
129110
for stat in stats:
130111
if stat.area < min_area:
131-
#logger.debug(f"++++ Ignored label with smaller area => ( {stat.area} < {min_area})")
132112
print(f"++++ Ignored label with smaller area => ( {stat.area} < {min_area})")
133113
continue
134114

@@ -140,7 +120,6 @@ def split_nuclei_dataset(d, centroid_key="centroid", mask_value_key="mask_value"
140120
item[centroid_key] = (x, y)
141121
item[mask_value_key] = stat.label
142122

143-
# logger.info(f"{d['label']} => {len(stats)} => {mask.shape} => {stat.label}")
144123
dataset_json.append(item)
145124
return dataset_json
146125

@@ -274,12 +253,6 @@ def main():
274253
dice_loss = DiceLoss(sigmoid=True, squared_pred=True)
275254

276255
# Training Process
277-
#TODO Consider using the Supervised Trainer over here from MONAI
278-
#network.train()
279-
#TODO Refer here for how to fix up a validation when using a SupervisedTrainer. In short a supervisedevaluator needs to be created as a
280-
# training handler
281-
#TODO https://github.com/Project-MONAI/tutorials/blob/bc342633bd8e50be7b4a67b723006bb03285f6ba/acceleration/distributed_training/unet_training_workflows.py#L187
282-
283256
val_handlers = [
284257
# use the logger "train_log" defined at the beginning of this program
285258
StatsHandler(name="train_log", output_transform=lambda x: None),
@@ -299,16 +272,12 @@ def main():
299272
inferer=val_inferer,
300273
postprocessing=train_post_transforms,
301274
key_val_metric=val_key_metric,
302-
#additional_metrics={"val_acc": Accuracy(output_transform=from_engine(["pred", "label"]))},
303275
val_handlers=val_handlers,
304276
# if no FP16 support in GPU or PyTorch version < 1.6, will not enable AMP evaluation
305277
amp=False,
306278
)
307279

308280
train_handlers = [
309-
# apply “EarlyStop” logic based on the loss value, use “-” negative value because smaller loss is better
310-
#EarlyStopHandler(trainer=None, patience=20, score_function=lambda x: -x.state.output[0]["loss"], epoch_level=False),
311-
#LrScheduleHandler(lr_scheduler=lr_scheduler, print_lr=True),
312281
ValidationHandler(validator=evaluator, interval=1, epoch_level=True),
313282
# use the logger "train_log" defined at the beginning of this program
314283
StatsHandler(name="train_log",

0 commit comments

Comments
 (0)