
Commit c38fe18

Rough training code added
Signed-off-by: vnath <[email protected]>
1 parent a74ef4c commit c38fe18

File tree: 1 file changed, +347 -0 lines

nuclick/nuclick_training.py

Lines changed: 347 additions & 0 deletions
@@ -0,0 +1,347 @@
import json
import logging
import copy
import math
import os
import random
import cv2
import numpy as np
import skimage
import torch
from tqdm import tqdm
from skimage.measure import regionprops
#from lib.handlers import TensorBoardImageHandler
#from lib.transforms import FilterImaged
#from lib.utils import split_dataset, split_nuclei_dataset
from monai.config import KeysCollection
from monai.engines import SupervisedTrainer, SupervisedEvaluator
from monai.handlers import (
    CheckpointSaver,
    EarlyStopHandler,
    LrScheduleHandler,
    MeanDice,
    StatsHandler,
    TensorBoardImageHandler,
    TensorBoardStatsHandler,
    ValidationHandler,
    from_engine
)
from monai.inferers import SimpleInferer
from monai.losses import DiceLoss
from monai.networks.nets import BasicUNet
from monai.data import (
    CacheDataset,
    Dataset,
    DataLoader,
)
from monai.transforms import (
    Activationsd,
    AddChanneld,
    AsChannelFirstd,
    AsDiscreted,
    Compose,
    EnsureTyped,
    LoadImaged,
    LoadImage,
    MapTransform,
    RandomizableTransform,
    RandRotate90d,
    ScaleIntensityRangeD,
    ToNumpyd,
    TorchVisiond,
    ToTensord,
    Transform,
)

#from monai.apps.nuclick.dataset_prep import split_pannuke_dataset, split_nuclei_dataset
from monai.apps.nuclick.transforms import (
    FlattenLabeld,
    ExtractPatchd,
    SplitLabeld,
    AddPointGuidanceSignald,
    FilterImaged
)

#from monailabel.interfaces.datastore import Datastore
#from monailabel.tasks.train.basic_train import BasicTrainTask, Context

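# Convert the PanNuke fold arrays (images.npy / masks.npy) into one .npy image/label
# pair per sample and return a MONAI-style dataset list of {"image": ..., "label": ...} entries.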
def split_pannuke_dataset(image, label, output_dir, groups):
    groups = groups if groups else dict()
    groups = [groups] if isinstance(groups, str) else groups
    if not isinstance(groups, dict):
        groups = {v: k + 1 for k, v in enumerate(groups)}

    label_channels = {
        0: "Neoplastic cells",
        1: "Inflammatory",
        2: "Connective/Soft tissue cells",
        3: "Dead Cells",
        4: "Epithelial",
    }

    print(f"++ Using Groups: {groups}")
    print(f"++ Using Label Channels: {label_channels}")
    #logger.info(f"++ Using Groups: {groups}")
    #logger.info(f"++ Using Label Channels: {label_channels}")

    images = np.load(image)
    labels = np.load(label)
    print(f"Image Shape: {images.shape}")
    print(f"Labels Shape: {labels.shape}")
    #logger.info(f"Image Shape: {images.shape}")
    #logger.info(f"Labels Shape: {labels.shape}")

    images_dir = output_dir
    labels_dir = os.path.join(output_dir, "labels", "final")
    os.makedirs(images_dir, exist_ok=True)
    os.makedirs(labels_dir, exist_ok=True)

    dataset_json = []
    for i in tqdm(range(images.shape[0])):
        name = f"img_{str(i).zfill(4)}.npy"
        image_file = os.path.join(images_dir, name)
        label_file = os.path.join(labels_dir, name)

        image_np = images[i]
        mask = labels[i]
        label_np = np.zeros(shape=mask.shape[:2])

        # use a distinct loop variable so the filename `name` above is not shadowed
        for idx, channel_name in label_channels.items():
            if idx < mask.shape[2]:
                m = mask[:, :, idx]
                if np.count_nonzero(m):
                    m[m > 0] = groups.get(channel_name, 1)
                    label_np = np.where(m > 0, m, label_np)

        np.save(image_file, image_np)
        np.save(label_file, label_np)
        dataset_json.append({"image": image_file, "label": label_file})

    return dataset_json

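# Expand one image/label pair into per-nucleus items: every connected component with
# area >= min_area contributes its centroid and instance value, which the NuClick
# patch-extraction transforms use downstream.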
def split_nuclei_dataset(d, centroid_key="centroid", mask_value_key="mask_value", min_area=5):
    dataset_json = []

    mask = LoadImage(image_only=True, dtype=np.uint8)(d["label"])
    _, labels, _, _ = cv2.connectedComponentsWithStats(mask, 4, cv2.CV_32S)

    stats = regionprops(labels)
    for stat in stats:
        if stat.area < min_area:
            #logger.debug(f"++++ Ignored label with smaller area => ( {stat.area} < {min_area})")
            print(f"++++ Ignored label with smaller area => ( {stat.area} < {min_area})")
            continue

        x, y = stat.centroid
        x = int(math.floor(x))
        y = int(math.floor(y))

        item = copy.deepcopy(d)
        item[centroid_key] = (x, y)
        item[mask_value_key] = stat.label

        # logger.info(f"{d['label']} => {len(stats)} => {mask.shape} => {stat.label}")
        dataset_json.append(item)
    return dataset_json

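# End-to-end training: prepare the per-nucleus PanNuke dataset, build the NuClick-style
# transforms (patch extraction + point guidance signals), and train a 2D BasicUNet with
# MONAI's SupervisedTrainer / SupervisedEvaluator.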
def main():

    # Paths
    img_data_path = os.path.normpath('/scratch/pan_nuke_data/fold_1/Fold_1/images/fold1/images.npy')
    label_data_path = os.path.normpath('/scratch/pan_nuke_data/fold_1/Fold_1/masks/fold1/masks.npy')
    dataset_path = os.path.normpath('/home/vishwesh/nuclick_experiments/try_1/data')
    json_path = os.path.normpath('/home/vishwesh/nuclick_experiments/try_1/data_list.json')
    logging_dir = os.path.normpath('/home/vishwesh/nuclick_experiments/try_6/')
    groups = [
        "Neoplastic cells",
        "Inflammatory",
        "Connective/Soft tissue cells",
        "Dead Cells",
        "Epithelial",
    ]

    # Hyper-params
    patch_size = 128
    min_area = 5

    # Create Dataset
    if not os.path.isfile(json_path):
        dataset_json = split_pannuke_dataset(image=img_data_path,
                                             label=label_data_path,
                                             output_dir=dataset_path,
                                             groups=groups)

        with open(json_path, 'w') as j_file:
            json.dump(dataset_json, j_file)
    else:
        with open(json_path, 'r') as j_file:
            dataset_json = json.load(j_file)

    ds_json_new = []
    for d in tqdm(dataset_json):
        ds_json_new.extend(split_nuclei_dataset(d, min_area=min_area))

    print(f'Total DataSize is {len(ds_json_new)}')
    # 80/20 train/validation split over the per-nucleus items
    val_split = round(len(ds_json_new) * 0.8)
    train_ds_json_new = ds_json_new[:val_split]
    val_ds_json_new = ds_json_new[val_split:]

    # Transforms
    train_pre_transforms = Compose(
        [
            LoadImaged(keys=("image", "label"), dtype=np.uint8),
            FilterImaged(keys="image", min_size=5),
            FlattenLabeld(keys="label"),
            AsChannelFirstd(keys="image"),
            AddChanneld(keys="label"),
            ExtractPatchd(keys=("image", "label"), patch_size=patch_size),
            SplitLabeld(label="label", others="others", mask_value="mask_value", min_area=min_area),
            ToTensord(keys="image"),
            TorchVisiond(
                keys="image", name="ColorJitter", brightness=64.0 / 255.0, contrast=0.75, saturation=0.25, hue=0.04
            ),
            ToNumpyd(keys="image"),
            RandRotate90d(keys=("image", "label", "others"), prob=0.5, spatial_axes=(0, 1)),
            ScaleIntensityRangeD(keys="image", a_min=0.0, a_max=255.0, b_min=-1.0, b_max=1.0),
            AddPointGuidanceSignald(image="image", label="label", others="others"),
            EnsureTyped(keys=("image", "label"))
        ]
    )

    train_post_transforms = Compose(
        [
            Activationsd(keys="pred", sigmoid=True),
            AsDiscreted(keys="pred", threshold_values=True, logit_thresh=0.5),
        ]
    )

    val_transforms = Compose(
        [
            LoadImaged(keys=("image", "label"), dtype=np.uint8),
            FilterImaged(keys="image", min_size=5),
            FlattenLabeld(keys="label"),
            AsChannelFirstd(keys="image"),
            AddChanneld(keys="label"),
            ExtractPatchd(keys=("image", "label"), patch_size=patch_size),
            SplitLabeld(label="label", others="others", mask_value="mask_value", min_area=min_area),
            ToTensord(keys="image"),
            TorchVisiond(
                keys="image", name="ColorJitter", brightness=64.0 / 255.0, contrast=0.75, saturation=0.25, hue=0.04
            ),
            ToNumpyd(keys="image"),
            RandRotate90d(keys=("image", "label", "others"), prob=0.5, spatial_axes=(0, 1)),
            ScaleIntensityRangeD(keys="image", a_min=0.0, a_max=255.0, b_min=-1.0, b_max=1.0),
            AddPointGuidanceSignald(image="image", label="label", others="others", drop_rate=1.0),
            EnsureTyped(keys=("image", "label"))
        ]
    )

    train_key_metric = {"train_dice": MeanDice(include_background=False, output_transform=from_engine(["pred", "label"]))}
    val_key_metric = {"val_dice": MeanDice(include_background=False, output_transform=from_engine(["pred", "label"]))}
    val_inferer = SimpleInferer()

    # Define Dataset & Loading
    train_data_set = Dataset(train_ds_json_new, transform=train_pre_transforms)
    train_data_loader = DataLoader(
        dataset=train_data_set,
        batch_size=32,
        shuffle=True,
        num_workers=2
    )

    val_data_set = Dataset(val_ds_json_new, transform=val_transforms)
    val_data_loader = DataLoader(
        dataset=val_data_set,
        batch_size=32,
        shuffle=True,
        num_workers=2
    )

    # Network Definition, Optimizer etc
    device = torch.device("cuda")

    network = BasicUNet(
        spatial_dims=2,
        in_channels=5,
        out_channels=1,
        features=(32, 64, 128, 256, 512, 32),
    )

    network.to(device)
    optimizer = torch.optim.Adam(network.parameters(), 0.0001)
    dice_loss = DiceLoss(sigmoid=True, squared_pred=True)

    # Training Process
    #TODO Consider using the SupervisedTrainer from MONAI over here
    #network.train()
    #TODO Refer here for how to set up validation when using a SupervisedTrainer. In short, a SupervisedEvaluator
    # needs to be created as a training handler
    #TODO https://github.com/Project-MONAI/tutorials/blob/bc342633bd8e50be7b4a67b723006bb03285f6ba/acceleration/distributed_training/unet_training_workflows.py#L187

    val_handlers = [
        # use the logger "train_log" defined at the beginning of this program
        StatsHandler(name="train_log", output_transform=lambda x: None),
        TensorBoardStatsHandler(log_dir=logging_dir, output_transform=lambda x: None),
        TensorBoardImageHandler(
            log_dir=logging_dir,
            batch_transform=from_engine(["image", "label"]),
            output_transform=from_engine(["pred"]),
        ),
        CheckpointSaver(save_dir=logging_dir, save_dict={"network": network}, save_key_metric=True),
    ]

    evaluator = SupervisedEvaluator(
        device=device,
        val_data_loader=val_data_loader,
        network=network,
        inferer=val_inferer,
        postprocessing=train_post_transforms,
        key_val_metric=val_key_metric,
        #additional_metrics={"val_acc": Accuracy(output_transform=from_engine(["pred", "label"]))},
        val_handlers=val_handlers,
        # if no FP16 support in GPU or PyTorch version < 1.6, will not enable AMP evaluation
        amp=False,
    )

    train_handlers = [
        # apply "EarlyStop" logic based on the loss value, use "-" negative value because smaller loss is better
        #EarlyStopHandler(trainer=None, patience=20, score_function=lambda x: -x.state.output[0]["loss"], epoch_level=False),
        #LrScheduleHandler(lr_scheduler=lr_scheduler, print_lr=True),
        ValidationHandler(validator=evaluator, interval=1, epoch_level=True),
        # use the logger "train_log" defined at the beginning of this program
        StatsHandler(name="train_log",
                     tag_name="train_loss",
                     output_transform=from_engine(["loss"], first=True)),
        TensorBoardStatsHandler(log_dir=logging_dir,
                                tag_name="train_loss",
                                output_transform=from_engine(["loss"], first=True)
                                ),
        CheckpointSaver(save_dir=logging_dir,
                        save_dict={"net": network, "opt": optimizer},
                        save_interval=1,
                        epoch_level=True),
    ]

    trainer = SupervisedTrainer(
        device=device,
        max_epochs=30,
        train_data_loader=train_data_loader,
        network=network,
        optimizer=optimizer,
        loss_function=dice_loss,
        inferer=SimpleInferer(),
        # if no FP16 support in GPU or PyTorch version < 1.6, will not enable AMP evaluation
        amp=False,
        postprocessing=train_post_transforms,
        key_train_metric=train_key_metric,
        train_handlers=train_handlers,
    )
    trainer.run()

    # End ...
    return None


if __name__ == "__main__":
    main()
