Skip to content

Commit deaabb8

Browse files
boliu61pre-commit-ci[bot]yiheng-wang-nv
authored
add surgtoolloc preprocess notebooks (#973)
Fixes #966 cc @yiheng-wang-nv ### Description 1. Add `competitions` directory and move `kaggle` sub-directory under it. Also create `MICCAI` sub-directory. ( Contributed by @boliu61 ) 2. Add preprocessing notebooks to surgtoolloc competition folder (Contributed by @boliu61 ) 3. Add solution of task 2 (Contributed by @yiheng-wang-nv ) 4. Add solution of task 1 (Contributed by David Austin, and cleaned by @yiheng-wang-nv ) ### Checks <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [ ] Notebook runs automatically `./runner [-p <regex_pattern>]` Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Yiheng Wang <[email protected]> Co-authored-by: Yiheng Wang <[email protected]>
1 parent 062a95d commit deaabb8

20 files changed

+1827
-1
lines changed

.github/workflows/test-modified.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,5 +37,5 @@ jobs:
3737
python -m pip install -r requirements.txt; python -m pip list
3838
python -c "import monai; monai.config.print_debug_info()"
3939
export CUDA_VISIBLE_DEVICES=0
40-
git diff --name-only origin/main | while read line; do if [[ $line == *.ipynb ]]; then ./runner.sh -p "-and -wholename './${line}'"; fi; done;
40+
git diff --name-only origin/main | while read line; do if [[ $line == *.ipynb ]]; then ./runner.sh -p " -and -wholename './${line}'"; fi; done;
4141
# [[ $line == *.ipynb ]] && ./runner.sh --file "$line"

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,8 @@ Tutorial to show the pipeline of fine tuning an endoscopic inbody classification
211211
**modules**
212212
#### [bundle](./bundle)
213213
Get started tutorial and concrete training / inference examples for MONAI bundle features.
214+
#### [competitions](./competitions)
215+
MONAI based solutions of competitions in healthcare imaging.
214216
#### [engines](./modules/engines)
215217
Training and evaluation examples of 3D segmentation based on UNet3D and synthetic dataset with MONAI workflows, which contains engines, event-handlers, and post-transforms. And GAN training and evaluation example for a medical image generative adversarial network. Easy run training script uses `GanTrainer` to train a 2D CT scan reconstruction network. Evaluation script generates random samples from a trained network.
216218

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
---
2+
cfg:
3+
output_dir: "./output/"
4+
data_dir: "/raid/surg/image640_blur/"
5+
backbone: "efficientnet-b4"
6+
train_df: "cleaned_clf_train_data.csv"
7+
img_size: [640, 640]
8+
batch_size: 196
9+
num_classes: 14
10+
lr: 0.001
11+
epochs: 5
12+
oversample_rate: 4
13+
clf_threshold: 0.4
14+
num_workers: 8
15+
gpu: 0
16+
device: "cuda:0"
17+
image_load:
18+
- _target_: LoadImaged
19+
keys: "input"
20+
image_only: true
21+
- _target_: EnsureChannelFirstd
22+
keys: "input"
23+
- _target_: Resized
24+
keys: "input"
25+
spatial_size: "@cfg#img_size"
26+
mode: "bilinear"
27+
align_corners: false
28+
- _target_: Lambdad
29+
keys: "input"
30+
func: "$lambda x: x / 255.0"
31+
image_aug:
32+
- _target_: RandFlipd
33+
keys: "input"
34+
prob: 0.5
35+
spatial_axis: 0
36+
- _target_: RandFlipd
37+
keys: "input"
38+
prob: 0.5
39+
spatial_axis: 1
40+
- _target_: RandRotate90d
41+
keys: "input"
42+
prob: 0.5
43+
train_aug:
44+
_target_: Compose
45+
transforms: "$@cfg#image_load + @cfg#image_aug"
46+
val_aug:
47+
_target_: Compose
48+
transforms: "@cfg#image_load"
Lines changed: 245 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,245 @@
1+
import argparse
2+
import gc
3+
import os
4+
import sys
5+
from types import SimpleNamespace
6+
7+
import numpy as np
8+
import pandas as pd
9+
import torch
10+
from monai.bundle import ConfigParser
11+
from monai.metrics import ConfusionMatrixMetric
12+
from monai.networks.nets import EfficientNetBN
13+
from torch.cuda.amp import GradScaler, autocast
14+
from torch.utils.tensorboard import SummaryWriter
15+
from tqdm import tqdm
16+
from utils import (
17+
SurgDataset,
18+
create_checkpoint,
19+
get_train_dataloader,
20+
get_val_dataloader,
21+
mixup_data,
22+
set_seed,
23+
)
24+
25+
26+
def main(cfg):
27+
28+
os.makedirs(str(cfg.output_dir + f"/fold{cfg.fold}/"), exist_ok=True)
29+
set_seed(cfg.seed)
30+
# set dataset, dataloader
31+
df = pd.read_csv(cfg.train_df)
32+
sucirr_videos = df[df["suction irrigator"] > 0].clip_name.unique()
33+
tipup_videos = df[df["tip-up fenestrated grasper"] > 0].clip_name.unique()
34+
df_oversample = df[
35+
df.clip_name.isin(sucirr_videos) | df.clip_name.isin(tipup_videos)
36+
]
37+
38+
cfg.labels = df.columns.values[4:18]
39+
40+
val_df = df[df["fold"] == cfg.fold]
41+
train_df = pd.concat(
42+
[df[df["fold"] != cfg.fold]]
43+
+ [df_oversample[df_oversample["fold"] != cfg.fold]] * cfg.oversample_rate
44+
)
45+
46+
train_dataset = SurgDataset(cfg, df=train_df, mode="train")
47+
val_dataset = SurgDataset(cfg, df=val_df, mode="val")
48+
49+
train_dataloader = get_train_dataloader(train_dataset, cfg)
50+
val_dataloader = get_val_dataloader(val_dataset, cfg)
51+
52+
# set model
53+
model = EfficientNetBN(
54+
model_name=cfg.backbone, pretrained=True, num_classes=cfg.num_classes
55+
)
56+
model = torch.nn.DataParallel(model)
57+
model.to(cfg.device)
58+
59+
if cfg.weights is not None:
60+
model.load_state_dict(
61+
torch.load(os.path.join(f"{cfg.output_dir}/fold{cfg.fold}", cfg.weights))[
62+
"model"
63+
]
64+
)
65+
print(f"weights from: {cfg.weights} are loaded.")
66+
67+
# set optimizer, lr scheduler
68+
optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr)
69+
scheduler = torch.optim.lr_scheduler.OneCycleLR(
70+
optimizer,
71+
max_lr=cfg.lr,
72+
epochs=cfg.epochs,
73+
steps_per_epoch=int(train_df.shape[0] / cfg.batch_size),
74+
pct_start=0.1,
75+
anneal_strategy="cos",
76+
final_div_factor=10**5,
77+
)
78+
# set loss, metric
79+
class_num = list(train_df[cfg.labels].sum())
80+
class_weights = [
81+
train_df.shape[0] / (n * cfg.num_classes) if n > 0 else 1 for n in class_num
82+
]
83+
loss_function = torch.nn.BCEWithLogitsLoss(
84+
weight=torch.as_tensor(class_weights).to(cfg.device)
85+
)
86+
metric = ConfusionMatrixMetric(metric_name="F1", reduction="mean_batch")
87+
88+
# set other tools
89+
scaler = GradScaler()
90+
writer = SummaryWriter(str(cfg.output_dir + f"/fold{cfg.fold}/"))
91+
92+
# train and val loop
93+
step = 0
94+
i = 0
95+
best_metric = run_eval(
96+
model=model,
97+
val_dataloader=val_dataloader,
98+
cfg=cfg,
99+
writer=writer,
100+
epoch=-1,
101+
metric=metric,
102+
)
103+
optimizer.zero_grad()
104+
print("start from: ", best_metric)
105+
for epoch in range(cfg.epochs):
106+
print("EPOCH:", epoch)
107+
gc.collect()
108+
run_train(
109+
model=model,
110+
train_dataloader=train_dataloader,
111+
optimizer=optimizer,
112+
scheduler=scheduler,
113+
cfg=cfg,
114+
scaler=scaler,
115+
writer=writer,
116+
epoch=epoch,
117+
iteration=i,
118+
step=step,
119+
loss_function=loss_function,
120+
)
121+
122+
val_metric = run_eval(
123+
model=model,
124+
val_dataloader=val_dataloader,
125+
cfg=cfg,
126+
writer=writer,
127+
epoch=epoch,
128+
metric=metric,
129+
)
130+
131+
if val_metric > best_metric:
132+
print(f"SAVING CHECKPOINT: val_metric {best_metric:.5} -> {val_metric:.5}")
133+
best_metric = val_metric
134+
135+
checkpoint = create_checkpoint(
136+
model,
137+
optimizer,
138+
epoch,
139+
scheduler=scheduler,
140+
scaler=scaler,
141+
)
142+
torch.save(
143+
checkpoint,
144+
f"{cfg.output_dir}/fold{cfg.fold}/checkpoint_best_metric.pth",
145+
)
146+
147+
148+
def run_train(
149+
model,
150+
train_dataloader,
151+
optimizer,
152+
scheduler,
153+
cfg,
154+
scaler,
155+
writer,
156+
epoch,
157+
iteration,
158+
step,
159+
loss_function,
160+
):
161+
model.train()
162+
losses = []
163+
progress_bar = tqdm(range(len(train_dataloader)))
164+
tr_it = iter(train_dataloader)
165+
166+
for itr in progress_bar:
167+
batch = next(tr_it)
168+
inputs, labels = batch["input"].to(cfg.device), batch["label"].to(cfg.device)
169+
iteration += 1
170+
171+
step += cfg.batch_size
172+
torch.set_grad_enabled(True)
173+
if torch.rand(1) > 0.5:
174+
inputs, labels_a, labels_b, lam = mixup_data(inputs, labels)
175+
with autocast():
176+
outputs = model(inputs)
177+
loss = lam * loss_function(outputs, labels_a) + (
178+
1 - lam
179+
) * loss_function(outputs, labels_b)
180+
else:
181+
with autocast():
182+
outputs = model(inputs)
183+
loss = loss_function(outputs, labels)
184+
losses.append(loss.item())
185+
186+
scaler.scale(loss).backward()
187+
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
188+
scaler.step(optimizer)
189+
scaler.update()
190+
optimizer.zero_grad()
191+
scheduler.step()
192+
193+
progress_bar.set_description(f"loss: {np.mean(losses):.2f}")
194+
195+
196+
def run_eval(model, val_dataloader, cfg, writer, epoch, metric):
197+
198+
model.eval()
199+
torch.set_grad_enabled(False)
200+
201+
progress_bar = tqdm(range(len(val_dataloader)))
202+
tr_it = iter(val_dataloader)
203+
204+
for itr in progress_bar:
205+
batch = next(tr_it)
206+
inputs, labels = batch["input"].to(cfg.device), batch["label"].to(cfg.device)
207+
outputs = model(inputs)
208+
outputs = (torch.sigmoid(outputs) > cfg.clf_threshold).float()
209+
metric(outputs, labels)
210+
score = metric.aggregate()[0]
211+
print(score)
212+
score = torch.mean(score).item()
213+
metric.reset()
214+
writer.add_scalar("F1", score, epoch)
215+
216+
return score
217+
218+
219+
if __name__ == "__main__":
220+
221+
sys.path.append("configs")
222+
sys.path.append("models")
223+
sys.path.append("data")
224+
225+
arg_parser = argparse.ArgumentParser(description="")
226+
227+
arg_parser.add_argument(
228+
"-c", "--config", type=str, default="cfg_efnb4.yaml", help="config filename"
229+
)
230+
arg_parser.add_argument("-f", "--fold", type=int, default=0, help="fold")
231+
arg_parser.add_argument("-s", "--seed", type=int, default=-1, help="seed")
232+
arg_parser.add_argument("-w", "--weights", default=None, help="the path of weights")
233+
234+
input_args, _ = arg_parser.parse_known_args(sys.argv)
235+
236+
config_parser = ConfigParser()
237+
config_parser.read_config(input_args.config)
238+
config_parser.parse()
239+
cfg = SimpleNamespace(**config_parser.get_parsed_content("cfg"))
240+
241+
cfg.fold = input_args.fold
242+
cfg.seed = input_args.seed
243+
cfg.weights = input_args.weights
244+
245+
main(cfg)

0 commit comments

Comments
 (0)