Skip to content

add surgtoolloc preprocess notebooks #973

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 27 commits into from
Nov 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
52a201f
add surgtoolloc preprocess notebooks
boliu61 Oct 4, 2022
b3427cb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 4, 2022
10a4597
Merge branch 'main' into main
yiheng-wang-nv Oct 13, 2022
2b5ec3d
fix pep8 error and add det
yiheng-wang-nv Oct 13, 2022
34e83cc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 13, 2022
c77887a
fix pep8 remove yaml
yiheng-wang-nv Oct 13, 2022
21f7d8f
modify runner.sh
yiheng-wang-nv Oct 13, 2022
4605d80
exclude max epochs
yiheng-wang-nv Oct 13, 2022
471c89b
add code to process labels.csv
yiheng-wang-nv Oct 13, 2022
8f43429
Merge branch 'main' into main
yiheng-wang-nv Oct 13, 2022
8ccf22c
Merge branch 'main' into main
yiheng-wang-nv Nov 10, 2022
647050a
modify wrongly changed runner.sh
yiheng-wang-nv Nov 10, 2022
c7b53e3
add clf training pipeline
yiheng-wang-nv Nov 17, 2022
3ff3917
Merge branch 'main' into main
yiheng-wang-nv Nov 17, 2022
f49dd72
change runner.sh
yiheng-wang-nv Nov 17, 2022
dbf81ee
remove deleted get started in runner.sh
yiheng-wang-nv Nov 17, 2022
bcd8dd3
rename changed notebooks
yiheng-wang-nv Nov 17, 2022
da4fdd1
Merge branch 'main' into main
yiheng-wang-nv Nov 19, 2022
369e93c
use bundle for config
yiheng-wang-nv Nov 21, 2022
4bc3638
Merge branch 'main' of https://github.com/boliu61/tutorials into main
yiheng-wang-nv Nov 21, 2022
a649225
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 21, 2022
800bec2
simply cfg yaml
yiheng-wang-nv Nov 21, 2022
3d58911
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 21, 2022
b3f5c99
Update test-modified.yml
wyli Nov 21, 2022
d670862
remove simple name space
yiheng-wang-nv Nov 21, 2022
03e1f7c
remove simple name space
yiheng-wang-nv Nov 21, 2022
7cc8b11
Merge branch 'main' of https://github.com/boliu61/tutorials into main
yiheng-wang-nv Nov 21, 2022
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test-modified.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,5 +37,5 @@ jobs:
python -m pip install -r requirements.txt; python -m pip list
python -c "import monai; monai.config.print_debug_info()"
export CUDA_VISIBLE_DEVICES=0
git diff --name-only origin/main | while read line; do if [[ $line == *.ipynb ]]; then ./runner.sh -p "-and -wholename './${line}'"; fi; done;
git diff --name-only origin/main | while read line; do if [[ $line == *.ipynb ]]; then ./runner.sh -p " -and -wholename './${line}'"; fi; done;
# [[ $line == *.ipynb ]] && ./runner.sh --file "$line"
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,8 @@ Tutorial to show the pipeline of fine tuning an endoscopic inbody classification
**modules**
#### [bundle](./bundle)
Get started tutorial and concrete training / inference examples for MONAI bundle features.
#### [competitions](./competitions)
MONAI-based solutions for competitions in healthcare imaging.
#### [engines](./modules/engines)
Training and evaluation examples of 3D segmentation based on UNet3D and a synthetic dataset with MONAI workflows, which contain engines, event handlers, and post-transforms. Also includes a GAN training and evaluation example for a medical image generative adversarial network: an easy-to-run training script uses `GanTrainer` to train a 2D CT scan reconstruction network, and an evaluation script generates random samples from a trained network.

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
---
# Configuration for the classification training script (train.py).
# The script reads the "cfg" section through monai.bundle.ConfigParser and
# exposes its keys as attributes (cfg.output_dir, cfg.lr, ...).
cfg:
# Root output folder; train.py writes TensorBoard logs and the best checkpoint
# under "<output_dir>/fold<fold>/".
output_dir: "./output/"
# NOTE(review): not referenced in train.py itself; presumably the image root
# used by SurgDataset - confirm in utils.
data_dir: "/raid/surg/image640_blur/"
# Passed as model_name to monai.networks.nets.EfficientNetBN.
backbone: "efficientnet-b4"
# CSV loaded with pandas.read_csv; train.py expects "fold" and "clip_name"
# columns plus the per-tool label columns.
train_df: "cleaned_clf_train_data.csv"
# Target size referenced by the Resized transform below via "@cfg#img_size".
img_size: [640, 640]
# Used by train.py to derive the LR scheduler's steps per epoch.
# NOTE(review): presumably also the dataloader batch size - confirm in utils.
batch_size: 196
# Number of classifier outputs; also used in the class-weight formula.
num_classes: 14
# Learning rate for AdamW and max_lr of the OneCycleLR schedule.
lr: 0.001
# Number of training epochs (also the OneCycleLR epoch count).
epochs: 5
# How many extra copies of the rows from clips containing "suction irrigator"
# or "tip-up fenestrated grasper" are appended to the training set.
oversample_rate: 4
# Sigmoid probability above which a class counts as predicted during validation.
clf_threshold: 0.4
# NOTE(review): num_workers and gpu are not read in train.py; presumably
# consumed by the dataloader helpers in utils - confirm.
num_workers: 8
gpu: 0
# Device the model, loss weights and batches are moved to.
device: "cuda:0"
# Deterministic loading pipeline applied to the "input" key:
# load -> channel first -> resize to img_size -> divide by 255.
image_load:
- _target_: LoadImaged
keys: "input"
image_only: true
- _target_: EnsureChannelFirstd
keys: "input"
- _target_: Resized
keys: "input"
spatial_size: "@cfg#img_size"
mode: "bilinear"
align_corners: false
- _target_: Lambdad
keys: "input"
func: "$lambda x: x / 255.0"
# Random augmentations (flips along both spatial axes, 90-degree rotations),
# each applied with probability 0.5.
image_aug:
- _target_: RandFlipd
keys: "input"
prob: 0.5
spatial_axis: 0
- _target_: RandFlipd
keys: "input"
prob: 0.5
spatial_axis: 1
- _target_: RandRotate90d
keys: "input"
prob: 0.5
# Training transform: loading pipeline followed by the random augmentations.
train_aug:
_target_: Compose
transforms: "$@cfg#image_load + @cfg#image_aug"
# Validation transform: loading pipeline only, no augmentation.
val_aug:
_target_: Compose
transforms: "@cfg#image_load"
245 changes: 245 additions & 0 deletions competitions/MICCAI/surgtoolloc/classification_files/train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,245 @@
import argparse
import gc
import os
import sys
from types import SimpleNamespace

import numpy as np
import pandas as pd
import torch
from monai.bundle import ConfigParser
from monai.metrics import ConfusionMatrixMetric
from monai.networks.nets import EfficientNetBN
from torch.cuda.amp import GradScaler, autocast
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from utils import (
SurgDataset,
create_checkpoint,
get_train_dataloader,
get_val_dataloader,
mixup_data,
set_seed,
)


def main(cfg):
    """Train the classifier on one fold and keep the best-scoring checkpoint.

    The function builds the (oversampled) training set and the validation set
    from ``cfg.train_df``, creates an ``EfficientNetBN`` model wrapped in
    ``DataParallel``, runs an initial validation pass to obtain a baseline
    score, then alternates training and validation for ``cfg.epochs`` epochs.
    Whenever the validation score improves, a checkpoint is written to
    ``<cfg.output_dir>/fold<cfg.fold>/checkpoint_best_metric.pth``.

    Args:
        cfg: attribute-style config object. Reads ``output_dir``, ``fold``,
            ``seed``, ``train_df``, ``oversample_rate``, ``backbone``,
            ``num_classes``, ``device``, ``weights``, ``lr`` and ``epochs``;
            sets ``cfg.labels`` to the label column names. Other attributes
            are consumed by the dataset/dataloader helpers and by
            ``run_train`` / ``run_eval``.
    """
    # Build the per-fold output directory once instead of re-formatting the
    # same path in several places.
    fold_dir = os.path.join(cfg.output_dir, f"fold{cfg.fold}")
    os.makedirs(fold_dir, exist_ok=True)
    set_seed(cfg.seed)

    # set dataset, dataloader
    df = pd.read_csv(cfg.train_df)
    # Clips that contain at least one frame of these two tools are oversampled
    # as a whole (all frames of the clip, not only the positive ones).
    sucirr_videos = df[df["suction irrigator"] > 0].clip_name.unique()
    tipup_videos = df[df["tip-up fenestrated grasper"] > 0].clip_name.unique()
    df_oversample = df[
        df.clip_name.isin(sucirr_videos) | df.clip_name.isin(tipup_videos)
    ]

    # Label columns are taken by position (columns 4..17 of the CSV).
    cfg.labels = df.columns.values[4:18]

    val_df = df[df["fold"] == cfg.fold]
    train_df = pd.concat(
        [df[df["fold"] != cfg.fold]]
        + [df_oversample[df_oversample["fold"] != cfg.fold]] * cfg.oversample_rate
    )

    train_dataset = SurgDataset(cfg, df=train_df, mode="train")
    val_dataset = SurgDataset(cfg, df=val_df, mode="val")

    train_dataloader = get_train_dataloader(train_dataset, cfg)
    val_dataloader = get_val_dataloader(val_dataset, cfg)

    # set model
    model = EfficientNetBN(
        model_name=cfg.backbone, pretrained=True, num_classes=cfg.num_classes
    )
    model = torch.nn.DataParallel(model)
    model.to(cfg.device)

    if cfg.weights is not None:
        # map_location keeps loading independent of the device the checkpoint
        # was saved from.
        state = torch.load(
            os.path.join(fold_dir, cfg.weights), map_location=cfg.device
        )
        model.load_state_dict(state["model"])
        print(f"weights from: {cfg.weights} are loaded.")

    # set optimizer, lr scheduler
    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=cfg.lr,
        epochs=cfg.epochs,
        # run_train steps the scheduler once per batch of train_dataloader, so
        # the schedule length must come from the loader itself. Deriving it
        # from the dataframe size under-counts when the last partial batch is
        # kept, which makes OneCycleLR raise once its total steps are exceeded.
        steps_per_epoch=len(train_dataloader),
        pct_start=0.1,
        anneal_strategy="cos",
        final_div_factor=10**5,
    )

    # set loss, metric
    # Inverse-frequency class weights; classes without positives keep weight 1.
    class_num = list(train_df[cfg.labels].sum())
    class_weights = [
        train_df.shape[0] / (n * cfg.num_classes) if n > 0 else 1 for n in class_num
    ]
    loss_function = torch.nn.BCEWithLogitsLoss(
        weight=torch.as_tensor(class_weights).to(cfg.device)
    )
    metric = ConfusionMatrixMetric(metric_name="F1", reduction="mean_batch")

    # set other tools
    scaler = GradScaler()
    writer = SummaryWriter(fold_dir)

    try:
        # train and val loop
        step = 0
        i = 0
        # Baseline score before any training (logged as epoch -1).
        best_metric = run_eval(
            model=model,
            val_dataloader=val_dataloader,
            cfg=cfg,
            writer=writer,
            epoch=-1,
            metric=metric,
        )
        optimizer.zero_grad()
        print("start from: ", best_metric)
        for epoch in range(cfg.epochs):
            print("EPOCH:", epoch)
            gc.collect()
            run_train(
                model=model,
                train_dataloader=train_dataloader,
                optimizer=optimizer,
                scheduler=scheduler,
                cfg=cfg,
                scaler=scaler,
                writer=writer,
                epoch=epoch,
                iteration=i,
                step=step,
                loss_function=loss_function,
            )

            val_metric = run_eval(
                model=model,
                val_dataloader=val_dataloader,
                cfg=cfg,
                writer=writer,
                epoch=epoch,
                metric=metric,
            )

            if val_metric > best_metric:
                print(
                    f"SAVING CHECKPOINT: val_metric {best_metric:.5} -> {val_metric:.5}"
                )
                best_metric = val_metric

                checkpoint = create_checkpoint(
                    model,
                    optimizer,
                    epoch,
                    scheduler=scheduler,
                    scaler=scaler,
                )
                torch.save(
                    checkpoint,
                    os.path.join(fold_dir, "checkpoint_best_metric.pth"),
                )
    finally:
        # Flush pending TensorBoard events even if training is interrupted.
        writer.close()


def run_train(
    model,
    train_dataloader,
    optimizer,
    scheduler,
    cfg,
    scaler,
    writer,
    epoch,
    iteration,
    step,
    loss_function,
):
    """Run one training epoch with mixed precision and random mixup.

    For every batch the model is optimised with AMP (``autocast`` +
    ``GradScaler``); with probability 0.5 the batch is first passed through
    ``mixup_data`` and the loss is the ``lam``-weighted blend of the losses
    against both label sets. The LR scheduler is stepped once per batch.

    Args:
        model: network to train (switched to train mode here).
        train_dataloader: iterable of dict batches with ``"input"`` and
            ``"label"`` tensors.
        optimizer: optimizer updating ``model``'s parameters.
        scheduler: per-iteration learning-rate scheduler.
        cfg: config object; only ``cfg.device`` is read here.
        scaler: ``torch.cuda.amp.GradScaler`` used for loss scaling.
        writer: unused in this function; kept for interface compatibility.
        epoch: unused in this function; kept for interface compatibility.
        iteration: unused; kept for interface compatibility.
        step: unused; kept for interface compatibility.
        loss_function: callable ``(outputs, labels) -> loss tensor``.
    """
    model.train()
    # run_eval may have disabled autograd globally; make sure it is on.
    torch.set_grad_enabled(True)

    # Running sum instead of re-averaging an ever-growing list per iteration.
    loss_sum = 0.0
    progress_bar = tqdm(train_dataloader)

    for batch_idx, batch in enumerate(progress_bar, start=1):
        inputs, labels = batch["input"].to(cfg.device), batch["label"].to(cfg.device)

        use_mixup = torch.rand(1) > 0.5
        if use_mixup:
            inputs, labels_a, labels_b, lam = mixup_data(inputs, labels)

        with autocast():
            outputs = model(inputs)
            if use_mixup:
                loss = lam * loss_function(outputs, labels_a) + (
                    1 - lam
                ) * loss_function(outputs, labels_b)
            else:
                loss = loss_function(outputs, labels)
        loss_sum += loss.item()

        scaler.scale(loss).backward()
        # Gradients are still multiplied by the loss scale at this point; they
        # must be unscaled first, otherwise the 1.0 clipping threshold is
        # applied to scaled values and clips far too aggressively.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()
        scheduler.step()

        progress_bar.set_description(f"loss: {loss_sum / batch_idx:.2f}")


def run_eval(model, val_dataloader, cfg, writer, epoch, metric):
    """Evaluate the model on the validation set and log the mean F1 score.

    Predictions are obtained by applying a sigmoid to the model outputs and
    thresholding at ``cfg.clf_threshold``; they are accumulated into
    ``metric`` batch by batch. After the loop the first aggregated result is
    printed, averaged into a single float, logged to ``writer`` under the tag
    ``"F1"`` and returned. ``metric`` is reset before returning.

    Args:
        model: network to evaluate (switched to eval mode here).
        val_dataloader: iterable of dict batches with ``"input"`` and
            ``"label"`` tensors.
        cfg: config object; reads ``cfg.device`` and ``cfg.clf_threshold``.
        writer: object with ``add_scalar(tag, value, step)``.
        epoch: global step used for the logged scalar.
        metric: callable metric with ``aggregate()`` and ``reset()``.

    Returns:
        float: mean of the first aggregated metric result.
    """
    model.eval()

    # Scoped no_grad instead of torch.set_grad_enabled(False), so autograd is
    # not left globally disabled for whatever runs after evaluation.
    with torch.no_grad():
        for batch in tqdm(val_dataloader):
            inputs = batch["input"].to(cfg.device)
            labels = batch["label"].to(cfg.device)
            logits = model(inputs)
            predictions = (torch.sigmoid(logits) > cfg.clf_threshold).float()
            metric(predictions, labels)

    score = metric.aggregate()[0]
    print(score)
    score = torch.mean(score).item()
    metric.reset()
    writer.add_scalar("F1", score, epoch)

    return score


if __name__ == "__main__":
    # Script entry point: parse the CLI options, load the "cfg" section of the
    # bundle config file and start training on the requested fold.

    sys.path.append("configs")
    sys.path.append("models")
    sys.path.append("data")

    arg_parser = argparse.ArgumentParser(description="")

    arg_parser.add_argument(
        "-c", "--config", type=str, default="cfg_efnb4.yaml", help="config filename"
    )
    arg_parser.add_argument("-f", "--fold", type=int, default=0, help="fold")
    arg_parser.add_argument("-s", "--seed", type=int, default=-1, help="seed")
    arg_parser.add_argument("-w", "--weights", default=None, help="the path of weights")

    # Use the default argument list (sys.argv[1:]). Passing sys.argv itself
    # feeds the program name to the parser as if it were a user argument.
    # Unknown options are still tolerated and ignored.
    input_args, _ = arg_parser.parse_known_args()

    config_parser = ConfigParser()
    config_parser.read_config(input_args.config)
    config_parser.parse()
    cfg = SimpleNamespace(**config_parser.get_parsed_content("cfg"))

    # Command-line values are attached to the config so main() sees one object.
    cfg.fold = input_args.fold
    cfg.seed = input_args.seed
    cfg.weights = input_args.weights

    main(cfg)
Loading