
Optimize Swin UNETR #774


Merged: 21 commits, Jul 28, 2022

Conversation

yuchen-xu
Contributor

Signed-off-by: Yuchen Xu [email protected]

Fixes #773 .

Description

Acceleration optimizations for the Swin UNETR tutorial.

Status

Work in progress

Checks


@yuchen-xu yuchen-xu changed the title Optimize Swin UNETR [WIP] Optimize Swin UNETR Jul 7, 2022
@yuchen-xu
Contributor Author

In preliminary experiments, training was run for 5 epochs (120 steps) with one validation pass. All experiments were run on a single GPU with 32 GB of memory.

As seen below, AMP seems to provide the biggest speed boost. Note that while the effect of ToDeviced and ThreadDataLoader is small here, it could be amplified considerably over a full run, since training is designed for 30,000 steps.

  • Original: 144.86s
  • ToDeviced + ThreadDataLoader: 137.97s
  • ToDeviced + ThreadDataLoader + AMP: 75.17s
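
For context, here is a minimal sketch of how these three pieces fit together in a MONAI pipeline; the transform list and train_files are placeholders rather than the tutorial's exact chain:

    import torch
    from monai.data import CacheDataset, ThreadDataLoader
    from monai.transforms import Compose, LoadImaged, EnsureChannelFirstd, ToDeviced

    # deterministic loading transforms, then move the data to GPU once so the
    # subsequent (random) transforms run on the device
    train_transforms = Compose([
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        ToDeviced(keys=["image", "label"], device="cuda:0"),
        # ... random transforms would follow here ...
    ])

    train_ds = CacheDataset(data=train_files, transform=train_transforms, cache_rate=1.0)
    # ThreadDataLoader uses threads rather than worker processes, so it can safely
    # iterate over cached data that already lives on the GPU
    train_loader = ThreadDataLoader(train_ds, batch_size=1, num_workers=0, shuffle=True)

    scaler = torch.cuda.amp.GradScaler()  # used together with autocast for AMP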

Here are the results of running larger batch sizes in combination with thread workers; each configuration was run only once. Consistent with the findings in #757, using thread workers does not seem to improve speed, but a larger batch size does seem to help (at batch size 4), although further investigation is needed. A hypothetical timing sketch for such a sweep follows the table.

| Batch \ Threads | 1 | 2 | 3 | 4 |
| --- | --- | --- | --- | --- |
| 1 | 72.67s | 72.62s | 72.97s | 78.30s |
| 2 | 75.70s | 76.26s | 76.58s | 76.82s |
| 3 | 79.41s | 77.04s | 78.56s | 77.14s |
| 4 | 71.51s | 71.48s | 71.97s | 72.36s |
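
For reference, a hypothetical sketch of how such a sweep could be timed; train_ds and run_n_steps are placeholders, and the use_thread_workers flag is assumed to be available in this MONAI version:

    import time
    from monai.data import ThreadDataLoader

    for batch_size in (1, 2, 3, 4):
        for num_workers in (1, 2, 3, 4):
            loader = ThreadDataLoader(
                train_ds,
                batch_size=batch_size,
                num_workers=num_workers,
                use_thread_workers=True,  # run the workers as threads instead of processes
                shuffle=True,
            )
            start = time.time()
            run_n_steps(loader, steps=120)  # placeholder for the 120-step training run
            print(f"batch={batch_size} threads={num_workers}: {time.time() - start:.2f}s")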

Further profiling analysis will be conducted in the following week.

@yuchen-xu yuchen-xu self-assigned this Jul 9, 2022
@Nic-Ma
Contributor

Nic-Ma commented Jul 10, 2022

Hi @yuchen-xu ,

Thanks for sharing the experiments summary, progress looks good.
Please note that changing the batch size may affect the training curve, so we must evaluate training speed in terms of reaching the target validation metric.

Thanks.

@yuchen-xu
Contributor Author

Just making a note in case it comes up later: the data listing function here is load_decathlon_datalist, which returns a list of dicts (not MetaTensors). Hence "image_meta_dict" still works fine for the slice visualization in 0.9.1rc2.
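
For reference, a minimal sketch of that data listing step (the JSON path is a placeholder):

    from monai.data import load_decathlon_datalist

    datalist = load_decathlon_datalist("dataset_0.json", True, "training")
    # a plain list of dicts of file paths, e.g. {"image": "...", "label": "..."};
    # metadata such as "image_meta_dict" only appears after LoadImaged runs
    print(type(datalist), datalist[0].keys())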

@Nic-Ma
Contributor

Nic-Ma commented Jul 13, 2022

Hi @yuchen-xu ,

That's because we tried to avoid breaking changes in the new release, so it remains compatible at the API level.
But please switch to the MetaTensor features, since the tutorial should demonstrate the latest approach.
d["image_meta_dict"] is the same as d["image"].meta.

Thanks.

@yuchen-xu
Contributor Author

I have encountered two major issues when optimizing this tutorial:

  1. Training with AMP (torch.cuda.amp.autocast()) causes training to diverge. All else being equal, the training curve looks like this:
    [training-curve image]
    Validating with AMP is fine. AMP encloses these two lines of code, which is consistent with its intended use (see the official documentation):

        with torch.cuda.amp.autocast():
            logit_map = model(x)
            loss = loss_function(logit_map, y)

I suspect this might have to do with the use of transformer blocks and/or certain floating-point operations, but I have not been able to confirm this.

  2. Introducing ToDeviced() before the random operations, combined with ThreadDataLoader (no thread workers), seems to slow training down significantly compared with using ToTensor() and DataLoader. Curiously, this is the opposite of what I observed when running on 0.9.0. Here are some numbers from the V100 GPU (120 steps, one validation); the two ToDeviced placements are sketched after the timings:

  • Val AMP, ToDeviced before the random transforms ("middle"), ThreadDataLoader: 358.67s total, 63s for a typical pass through the data (24 steps)
  • Val AMP, ToDeviced at the end, ThreadDataLoader: 525.75s total, 95s per pass
  • Val AMP only: 282.16s total, 47s per pass
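
A rough sketch of the two ToDeviced placements compared above (the surrounding transforms and crop parameters are placeholders for the tutorial's own chain):

    from monai.transforms import (
        Compose, LoadImaged, EnsureChannelFirstd, RandCropByPosNegLabeld, ToDeviced,
    )

    # "middle" placement: move to GPU before the random transforms,
    # so the random crops also run on the device
    middle = Compose([
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        ToDeviced(keys=["image", "label"], device="cuda:0"),
        RandCropByPosNegLabeld(keys=["image", "label"], label_key="label",
                               spatial_size=(96, 96, 96), pos=1, neg=1, num_samples=4),
    ])

    # "end" placement: run the random transforms on CPU, move to GPU last
    end = Compose([
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        RandCropByPosNegLabeld(keys=["image", "label"], label_key="label",
                               spatial_size=(96, 96, 96), pos=1, neg=1, num_samples=4),
        ToDeviced(keys=["image", "label"], device="cuda:0"),
    ])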

@tangy5 @ahatamiz - please advise, and we can work together to figure out possible reasons. I can provide more information on the code I used (the 091 one in the commit above). Thanks!

@yuchen-xu
Contributor Author

Update: the second issue does not seem to be isolated to this tutorial; it appears to be a broader issue in 0.9.1 (see comments in another PR).

The AMP training issue is baffling and has been observed in both 0.9.0 and 0.9.1.

@yuchen-xu
Contributor Author

After 0.9.1rc4, we are bringing back the EnsureTyped transform with track_meta=False, which speeds up processing considerably. There is a design choice here for the validation data, though: we show illustrations of an example validation image both before and after training, and those illustrations need the metadata. So we have two options (sketched below): either set track_meta=True for the validation data, or skip EnsureTyped for the validation data entirely. Over 30,000 training steps, both would slow overall training down by less than 1%, and both take essentially the same time. Suggestions?
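
To make the two options concrete, a minimal sketch (assuming the tutorial's key names):

    from monai.transforms import EnsureTyped

    # option 1: keep EnsureTyped for the validation data but track the metadata,
    # so the before/after illustrations can still read it
    val_ensure = EnsureTyped(keys=["image", "label"], track_meta=True)

    # option 2: simply leave EnsureTyped out of the validation chain, so the data
    # stays as MetaTensor by default and the metadata is preserved anyway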

@wyli
Contributor

wyli commented Jul 23, 2022

'Explicit is better than implicit', so I think setting track_meta to True is better...

@yuchen-xu
Contributor Author

Thanks @wyli , I agree. I will also include a blurb explaining this choice.

@yuchen-xu
Contributor Author

Lessons learned about AMP:

It seems that all of these are required for AMP to work (at least for Swin UNETR); see the sketch below:

  1. GradScaler (recommended in the documentation, included in the fast training tutorial)
  2. optimizer.zero_grad (not mentioned in the documentation, but included in the fast training tutorial)
  3. scaler.unscale_ (mentioned in the documentation as optional, for gradient clipping; not included in the fast training tutorial)
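
Putting the three together, a minimal AMP training step (a sketch only; model, optimizer, loss_function and train_loader are the tutorial's existing objects, and the gradient-clipping value is a placeholder):

    import torch

    scaler = torch.cuda.amp.GradScaler()              # 1. gradient scaler

    for batch in train_loader:
        x, y = batch["image"], batch["label"]
        optimizer.zero_grad()                          # 2. clear gradients each step
        with torch.cuda.amp.autocast():
            logit_map = model(x)
            loss = loss_function(logit_map, y)
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)                     # 3. unscale before (optional) clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()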

@yuchen-xu
Contributor Author

Speedup performance in 0.9.1rc4 (30,000 steps, validation every 500):

  • Original: 63,634s
  • AMP + ThreadDataLoader + EnsureTyped (track_meta = False for training, True for validation): 30,535s

The new training curve doesn't differ much from the original and reaches the 0.84 target validation accuracy.
[training-curve image]

@yuchen-xu yuchen-xu requested review from Nic-Ma and wyli July 24, 2022 16:01
@Nic-Ma
Contributor

Nic-Ma commented Jul 25, 2022

Hi @yuchen-xu ,

Thanks very much for the optimization of the Swin UNETR tutorial. I have several comments:

  1. First of all, I think you may not be aware of the functionality of the get_track_meta() and set_track_meta() APIs: https://github.com/Project-MONAI/MONAI/blob/dev/monai/data/meta_obj.py#L30
    These are global APIs: without calling set_track_meta(False), your random transforms still run with MetaTensor.
    The speedup you saw is just because we fixed some bugs in 0.9.1rc4.
    Please take a look at how set_track_meta() is used in the fast training tutorial:
    https://github.com/Project-MONAI/tutorials/blob/main/acceleration/fast_training_tutorial.ipynb
  2. About your description "we have illustrations of an example validation image both before and after training, and those illustrations need the metadata": you can set EnsureTyped(keys=["image", "label"], track_meta=get_track_meta()) in the val transforms and call set_track_meta(False) only during training, as in the sketch below.
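
That is, something like this sketch (key names follow the tutorial; a sketch only, not the exact tutorial code):

    from monai.data import set_track_meta, get_track_meta
    from monai.transforms import EnsureTyped

    # validation transform: follow whatever the global setting currently is
    val_ensure = EnsureTyped(keys=["image", "label"], track_meta=get_track_meta())

    # disable meta tracking only for the training loop, re-enable it afterwards
    set_track_meta(False)
    # ... training steps ...
    set_track_meta(True)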

What do you think?

Thanks.

@Nic-Ma
Contributor

Nic-Ma commented Jul 25, 2022

So for now, let me clarify:
(1) set_track_meta() controls whether every transform converts its input data to MetaTensor and computes the metadata;
(2) the EnsureTyped transform used in the fast training tutorial just ensures that the cached data is Tensor instead of MetaTensor, to avoid a MetaTensor -> Tensor conversion (because get_track_meta() is False) in every epoch.
My suggestion is:

  1. Try to compare the training speed with set_track_meta(True) vs. set_track_meta(False); if there is no big difference, we can remove all the track_meta related code to make the tutorial easier to understand.
  2. If set_track_meta(False) obviously speeds up training, check whether adding EnsureTyped(track_meta=False) speeds it up further; if the gain is not obvious, let's remove the EnsureTyped transform to make the tutorial easier to understand.
  3. Please remove the XXX_orig notebook and the profiling Python code; I think they are only for temporary testing?

Thanks.

@yuchen-xu
Contributor Author

Running 10,000 steps:

  • set_track_meta(False) with EnsureType: 9,845s
  • set_track_meta(True) with EnsureType: 10,252s
  • set_track_meta(False) without EnsureType: 11,809s

So I think we should keep both set_track_meta and EnsureType for now.

@Nic-Ma
Contributor

Nic-Ma commented Jul 27, 2022

Hi @yuchen-xu ,

Thanks for the update.
May I know the speed of set_track_meta(True) without EnsureType?
Also, I think you can test with MONAI 0.9.1 now, as it has already been released.

Thanks.

@yuchen-xu
Contributor Author

Yes, I am using 0.9.1.

May I know the speed of "Set_track_meta = True without EnsureType"?

Given what set_track_meta does, I assumed this would be the slowest, but I will try and run that.

@Nic-Ma
Contributor

Nic-Ma commented Jul 27, 2022

Because I don't quite understand why set_track_meta(True) with EnsureType (10,252s) is faster than set_track_meta(False) without EnsureType (11,809s).

Thanks.

@yuchen-xu
Contributor Author

Following the usage in the fast training tutorial, set_track_meta is applied after CacheDataset and only affects the random transforms, while EnsureTyped(track_meta=False) affects only the non-random transforms, so the two handle different sections of the transform chain (see the sketch below).

I think trying the last combination will confirm/disprove this idea.
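
To illustrate (a rough sketch; the transform names and train_files are placeholders for the tutorial's own chain):

    from monai.data import CacheDataset
    from monai.transforms import Compose, LoadImaged, EnsureChannelFirstd, EnsureTyped, RandFlipd

    train_transforms = Compose([
        # deterministic part: executed once and cached by CacheDataset;
        # EnsureTyped(track_meta=False) makes the cached items plain Tensors
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        EnsureTyped(keys=["image", "label"], track_meta=False),
        # random part: re-executed every epoch after the cache; whether it
        # produces MetaTensor is governed by the global set_track_meta()
        RandFlipd(keys=["image", "label"], prob=0.5, spatial_axis=0),
    ])
    train_ds = CacheDataset(data=train_files, transform=train_transforms, cache_rate=1.0)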

@Nic-Ma
Contributor

Nic-Ma commented Jul 27, 2022

EnsureType also affects training, because it determines whether the cached content is Tensor or MetaTensor, and the random transforms will then convert it to Tensor or MetaTensor depending on set_track_meta; this conversion step may be time-consuming.
@wyli I am afraid the as_tensor() operation of MetaTensor is a deep copy? That would explain why set_track_meta(False) without EnsureType is the slowest one. Based on the doc-string:
https://github.com/Project-MONAI/MONAI/blob/dev/monai/data/meta_tensor.py#L317
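
For reference, a tiny illustration of the conversion in question (a sketch; whether it copies memory is exactly the point above):

    import torch
    from monai.data import MetaTensor

    mt = MetaTensor(torch.zeros(2, 2), meta={"filename_or_obj": "example.nii.gz"})
    t = mt.as_tensor()  # plain torch.Tensor; per the doc-string it may or may not copy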

Thanks.

@yuchen-xu
Contributor Author

set_track_meta(True) without EnsureType takes 14,442 seconds, much slower than the others, as expected.

@yuchen-xu yuchen-xu marked this pull request as ready for review July 27, 2022 21:29
@yuchen-xu yuchen-xu changed the title [WIP] Optimize Swin UNETR Optimize Swin UNETR Jul 27, 2022

@Nic-Ma Nic-Ma left a comment


Overall it looks good to me; I put some minor comments inline.

Thanks.

@Nic-Ma
Contributor

Nic-Ma commented Jul 28, 2022

Please update the description at the beginning to add the best training time and memory usage on your GPU.

Thanks.

@Nic-Ma Nic-Ma merged commit bf06be2 into Project-MONAI:main Jul 28, 2022
@yuchen-xu
Contributor Author

Sorry, I forgot to put in the final comparison.

For 30,000 training steps (1,250 epochs), with validation every 500 steps:

  • Original: 63,634s
  • Optimized: 30,264s (-52%)

These numbers differ from the ones I used on the poster; the poster numbers

  1. used validation every epoch (24 steps), and
  2. used an original that doesn't use CacheDataset (i.e., without any MONAI optimizations).

With that setup, 30,000 steps gave 148,895s (original) vs. 58,281s (optimized) (-61%).

boneseva pushed a commit to boneseva/MONAI-tutorials that referenced this pull request Apr 21, 2024
* begin work

* checkpoint, added amp, todeviced, threaddataloader

* checkpoint, debug and profiling code

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* checkpoint experiment script

* ready, pending final checks

* applied autofixes

* done

* removed unnecessary files

* updated with set_track_meta

* tmp backup for profiling code

* removed test files

* updated json download

* updated explanation

* final profiling code backup

* remove test files

* updated json link and comments

Co-authored-by: Yuchen Xu <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>