Commit 9a9106d

SSL Tutorial updated as being notebooks (#1177)
Signed-off-by: Vishwesh Nath <[email protected]>

Converts the python scripts to being jupyter notebooks that are reflected in MONAI Toolkit.

Signed-off-by: Vishwesh Nath <[email protected]>
Co-authored-by: Vishwesh Nath <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 1cb8f69 commit 9a9106d

File tree

11 files changed: +1039 -623 lines changed

README.md

Lines changed: 7 additions & 0 deletions
```diff
@@ -291,3 +291,10 @@ This tutorial shows several visualization approaches for 3D image during transfo
 
 #### [Auto3DSeg](./auto3dseg/)
 This folder shows how to run the comprehensive Auto3DSeg pipeline with minimal inputs and customize the Auto3Dseg modules to meet different user requirements.
+
+#### <ins>**Self-Supervised Learning**</ins>
+##### [self_supervised_pretraining](./self_supervised_pretraining/ssl_train.ipynb)
+This tutorial shows how to construct a training workflow of self-supervised learning where unlabeled data is utilized. It shows how to train a model on the TCIA dataset of unlabeled COVID-19 cases.
+
+##### [self_supervised_pretraining_based_finetuning](./self_supervised_pretraining/ssl_finetune.ipynb)
+This tutorial shows how to utilize pre-trained weights from the self-supervised learning framework where unlabeled data is utilized. It shows how to train a model for multi-class 3D segmentation using the pretrained weights.
```

runner.sh

Lines changed: 4 additions & 0 deletions
```diff
@@ -57,6 +57,8 @@ doesnt_contain_max_epochs=("${doesnt_contain_max_epochs[@]}" monailabel_endoscop
 doesnt_contain_max_epochs=("${doesnt_contain_max_epochs[@]}" monailabel_pathology_nuclei_segmentation_QuPath.ipynb)
 doesnt_contain_max_epochs=("${doesnt_contain_max_epochs[@]}" monailabel_radiology_spleen_segmentation_OHIF.ipynb)
 doesnt_contain_max_epochs=("${doesnt_contain_max_epochs[@]}" example_feature.ipynb)
+doesnt_contain_max_epochs=("${doesnt_contain_max_epochs[@]}" ssl_train.ipynb)
+doesnt_contain_max_epochs=("${doesnt_contain_max_epochs[@]}" ssl_finetune.ipynb)
 
 # Execution of the notebook in these folders / with the filename cannot be automated
 skip_run_papermill=()
@@ -83,6 +85,8 @@ skip_run_papermill=("${skip_run_papermill[@]}" .*monailabel_pancreas_tumor_segme
 skip_run_papermill=("${skip_run_papermill[@]}" .*monailabel_endoscopy_cvat_tooltracking*)
 skip_run_papermill=("${skip_run_papermill[@]}" .*monailabel_pathology_nuclei_segmentation_QuPath*)
 skip_run_papermill=("${skip_run_papermill[@]}" .*monailabel_radiology_spleen_segmentation_OHIF*)
+skip_run_papermill=("${skip_run_papermill[@]}" .*ssl_train*)
+skip_run_papermill=("${skip_run_papermill[@]}" .*ssl_finetune*)
 skip_run_papermill=("${skip_run_papermill[@]}" .*transform_visualization*) # https://github.com/Project-MONAI/tutorials/issues/1155
 
 # output formatting
```
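The `skip_run_papermill` entries above are bash patterns that (presumably, given the `.*` prefixes) are matched as regular expressions against notebook paths before papermill execution. As a minimal sketch under that assumption, the `should_skip` helper below is hypothetical and not actual runner.sh code; only the two pattern entries mirror what this commit adds:

```shell
#!/usr/bin/env bash
# Illustrative only: gate notebook execution on a pattern list, in the
# style of runner.sh. should_skip is a hypothetical helper.
skip_run_papermill=()
skip_run_papermill=("${skip_run_papermill[@]}" .*ssl_train*)
skip_run_papermill=("${skip_run_papermill[@]}" .*ssl_finetune*)

should_skip() {
    local notebook="$1" pattern
    for pattern in "${skip_run_papermill[@]}"; do
        # each entry is treated as an unanchored extended regular expression
        if [[ "$notebook" =~ $pattern ]]; then
            return 0
        fi
    done
    return 1
}

should_skip "self_supervised_pretraining/ssl_train.ipynb" && echo "skipped"
```

Matching with `[[ ... =~ ... ]]` means a pattern like `.*ssl_train*` hits anywhere in the path, so both new notebooks would be excluded from automated papermill runs regardless of their directory.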

self_supervised_pretraining/README.md

Lines changed: 31 additions & 39 deletions
```diff
@@ -1,42 +1,37 @@
 # Self-Supervised Pre-training Tutorial
 
-This directory contains two scripts. The first script 'ssl_script_train.py' generates
+This directory contains two Jupyter notebooks. The first notebook, 'ssl_train.ipynb', generates
 a good set of pre-trained weights using unlabeled data with self-supervised tasks that
-are based on augmentations of different types. The second script 'ssl_finetune_train.py' uses
+are based on augmentations of different types. The second notebook, 'ssl_finetune.ipynb', uses
 the pre-trained weights generated from the first script and performs fine-tuning on a fully supervised
 task.
 
-In case, the user wants to skip the pre-training part, the pre-trained weights can be
-[downloaded from here](https://drive.google.com/file/d/1D7G1FhgZfBhql4djMfiSy0xODVXnLlpd/view?usp=sharing)
-to use for fine-tuning tasks and directly skip to the second part of the tutorial which is using the
-'ssl_finetune_train.py'.
 
 ### Steps to run the tutorial
-1.) Download the two datasets [TCIA-Covid19](https://wiki.cancerimagingarchive.net/display/Public/CT+Images+in+COVID-19)
-& [BTCV](https://www.synapse.org/#!Synapse:syn3193805/wiki/217789) (More detail about them in the Data section)\
-2.) Modify the paths for data_root, json_path & logdir in ssl_script_train.py\
-3.) Run the 'ssl_script_train.py'\
-4.) Modify the paths for data_root, json_path, pre-trained_weights_path from 2.) and
-logdir_path in 'ssl_finetuning_train.py'\
-5.) Run the 'ssl_finetuning_script.py'\
-6.) And that's all folks, use the model to your needs
-
-### 1.Data
+1. Download the two datasets [TCIA-Covid19](https://wiki.cancerimagingarchive.net/display/Public/CT+Images+in+COVID-19)
+& [BTCV](https://www.synapse.org/#!Synapse:syn3193805/wiki/217789) (more detail about them in the Data section)
+2. Modify the paths for data_root, json_path & logdir in ssl_train.ipynb
+3. Run 'ssl_train.ipynb'
+4. Modify the paths for data_root, json_path, pre-trained_weights_path from Step 2 and
+logdir_path in 'ssl_finetune.ipynb'
+5. Run 'ssl_finetune.ipynb'
+6. And that's all folks, use the model to your needs
+
+### 1. Data
 Pre-training Dataset: The TCIA Covid-19 dataset was used for generating the
-[pre-trained weights](https://drive.google.com/file/d/1D7G1FhgZfBhql4djMfiSy0xODVXnLlpd/view?usp=sharing).
-The dataset contains a total of 771 3D CT Volumes. The volumes were split into training and validation sets
+pre-trained weights. The dataset contains a total of 771 3D CT Volumes. The volumes were split into training and validation sets
 of 600 and 171 3D volumes correspondingly. The data is available for download at this
 [link](https://wiki.cancerimagingarchive.net/display/Public/CT+Images+in+COVID-19).
-If this dataset is being used in your work, please use [1] as reference. A json file is provided
-which contains the training and validation splits that were used for the training. The json file can be found in the
+If this dataset is being used in your work, please use [1] as a reference. A JSON file is provided
+which contains the training and validation splits that were used for the training. The JSON file can be found in the
 json_files directory of the self-supervised training tutorial.
 
 Fine-tuning Dataset: The dataset from Beyond the Cranial Vault Challenge
 [(BTCV)](https://www.synapse.org/#!Synapse:syn3193805/wiki/217789)
 2015 hosted at MICCAI, was used as a fully supervised fine-tuning task on the pre-trained weights. The dataset
-consists of 30 3D Volumes with annotated labels of up to 13 different organs [2]. There are 3 json files provided in the
-json_files directory for the dataset. They correspond to having different number of training volumes ranging from
-6, 12 and 24. All 3 json files have the same validation split.
+consists of 30 3D Volumes with annotated labels of up to 13 different organs [2]. There are 3 JSON files provided in the
+json_files directory for the dataset. They correspond to having a different number of training volumes ranging from
+6, 12 and 24. All 3 JSON files have the same validation split.
 
 References:
 
```
```diff
@@ -48,15 +43,14 @@ Medical Image Analysis 69 (2021): 101894.
 
 ### 2. Network Architectures
 
-For pre-training a modified version of ViT [1] has been used, it can be referred
+For pre-training, a modified version of ViT [1] has been used; it can be referred to
 [here](https://docs.monai.io/en/latest/networks.html#vitautoenc)
-from MONAI. The original ViT was modified by attachment of two 3D Convolutional Transpose Layers to achieve a similar
+from MONAI. The original ViT was modified by the attachment of two 3D Convolutional Transpose Layers to achieve a similar
 reconstruction size as that of the input image. The ViT is the backbone for the UNETR [2] network architecture which
-was used for the fine-tuning fully supervised tasks.
+was used for fine-tuning fully supervised tasks.
 
-The pre-trained backbone of ViT weights were loaded to UNETR and the decoder head still relies on random initialization
-for adaptability of the new downstream task. This flexibility also allows the user to adapt the ViT backbone to their
-own custom created network architectures as well.
+The pre-trained ViT backbone weights were loaded into UNETR, and the decoder head still relies on random initialization
+for adaptability to the new downstream task. This flexibility also allows the user to adapt the ViT backbone to their custom-created network architectures as well.
 
 References:
 
```
```diff
@@ -76,7 +70,7 @@ volume. Two augmented views of the same 3D patch are generated for the contrasti
 the two augmented views closer to each other if the views are generated from the same patch, if not then it tries to
 maximize the disagreement. The CL offers this functionality on a mini-batch.
 
-![image](../figures/SSL_Overview_Figure.png)
+![SSL_overview](../figures/SSL_Overview_Figure.png)
 
 The augmentations mutate the 3D patch in various ways, the primary task of the network is to reconstruct
 the original image. The different augmentations used are classical techniques such as in-painting [1], out-painting [1]
```
```diff
@@ -88,13 +82,12 @@ by the reconstruction loss as a dynamic weight itself.
 The below example image depicts the usage of the augmentation pipeline where two augmented views are drawn of the same
 3D patch:
 
-![image](../figures/SSL_Different_Augviews.png)
+![SSL_augs](../figures/SSL_Different_Augviews.png)
 
 Multiple axial slices of a 96x96x96 patch are shown before the augmentation (Ref Original Patch in the above figure).
-Augmented View 1 & 2 are different augmentations generated via the transforms on the same cubic patch. The objective
-of the SSL network is to reconstruct the original top row image from the first view. The contrastive loss
-is driven by maximizing agreement of the reconstruction based on input of the two augmented views.
-`matshow3d` from `monai.visualize` was used for creating this figure, a tutorial for using can be found [here](https://github.com/Project-MONAI/tutorials/blob/main/modules/transform_visualization.ipynb)
+Augmented Views 1 & 2 are different augmentations generated via the transforms on the same cubic patch. The objective
+of the SSL network is to reconstruct the original top-row image from the first view. The contrastive loss
+is driven by maximizing the agreement of the reconstruction based on the input of the two augmented views.
 
 References:
 
```
```diff
@@ -107,7 +100,7 @@ image analysis 58 (2019): 101539.
 3.) Chen, Ting, et al. "A simple framework for contrastive learning of visual representations." International conference
 on machine learning. PMLR, 2020.
 
-### 4. Experiment Hyper-parameters
+### 4. Experiment with Hyper-parameters
 
 Training Hyper-Parameters for SSL: \
 Epochs: 300 \
```
```diff
@@ -118,8 +111,7 @@ Loss Function: L1
 Contrastive Loss Temperature: 0.005
 
 Training Hyper-parameters for Fine-tuning BTCV task (All settings have been kept consistent with prior
-[UNETR 3D
-Segmentation tutorial](https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/unetr_btcv_segmentation_3d.ipynb)): \
+[UNETR 3D Segmentation tutorial](../3d_segmentation/unetr_btcv_segmentation_3d.ipynb)): \
 Number of Steps: 30000 \
 Validation Frequency: 100 steps \
 Batch Size: 1 3D Volume (4 samples are drawn per 3D volume) \
```
```diff
@@ -130,7 +122,7 @@ Loss Function: DiceCELoss
 
 ![image](../figures/ssl_pretrain_losses.png)
 
-L1 error reported for training and validation when performing the SSL training. Please note contrastive loss is not
+L1 error reported for training and validation when performing the SSL training. Please note the contrastive loss is not
 L1.
 
 ### 5. Results of the Fine-tuning vs Random Initialization on BTCV
```
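The README above describes a contrastive loss (temperature 0.005) that pulls the two augmented views of the same patch together and pushes apart views from different patches within a mini-batch. A minimal NumPy sketch of an NT-Xent-style loss in that spirit is shown below; it is illustrative only and is not the tutorial's actual implementation (the notebooks use MONAI/PyTorch losses), and the function name and shapes are assumptions:

```python
import numpy as np

def nt_xent(z1, z2, temperature=0.005):
    """Illustrative NT-Xent-style contrastive loss for N paired embeddings.

    z1[i] and z2[i] are embeddings of two augmented views of the same patch;
    every other row in the 2N-row batch serves as a negative.
    """
    z = np.concatenate([z1, z2], axis=0)
    z = z / np.linalg.norm(z, axis=1, keepdims=True)   # cosine-similarity space
    n = z1.shape[0]
    sim = (z @ z.T) / temperature                      # (2N, 2N) scaled similarities
    np.fill_diagonal(sim, -np.inf)                     # a sample is not its own pair
    # the positive for row i is its augmented twin at i+n (mod 2n)
    pos = np.concatenate([np.arange(n, 2 * n), np.arange(n)])
    # numerically stable log-sum-exp over each row
    row_max = sim.max(axis=1, keepdims=True)
    logsumexp = (row_max + np.log(np.exp(sim - row_max).sum(axis=1, keepdims=True))).ravel()
    # cross-entropy of each row against its positive column
    return float(np.mean(logsumexp - sim[np.arange(2 * n), pos]))
```

When the two views' embeddings agree, the positive similarity term dominates the log-sum-exp and the loss approaches zero; for unrelated views it grows, which is the "maximize agreement / maximize disagreement" behavior described above.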
