Optimize Swin UNETR #774
Conversation
In preliminary experiments, training was run for 5 epochs/120 steps, with one validation pass. All experiments were run on a single GPU with 32 GB of memory. In these runs, AMP seems to provide the biggest speed boost. Note that while the effect of ToDeviced and ThreadDataLoader is small here, it could be amplified greatly over a full run, since training is designed to run for 30,000 steps. Original: 144.86s
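For context, the sketch below shows one way these three optimizations are typically combined in a MONAI training loop. It is a minimal sketch, not the tutorial's exact code: the transform list is abbreviated, and `datalist` and the hyperparameters are placeholder assumptions.

```python
import torch
from monai.data import CacheDataset, ThreadDataLoader
from monai.losses import DiceCELoss
from monai.networks.nets import SwinUNETR
from monai.transforms import Compose, EnsureTyped, LoadImaged, ToDeviced

device = torch.device("cuda:0")

# ToDeviced moves each cached sample to the GPU once, so the random transforms
# that follow it run on the GPU instead of the CPU
train_transforms = Compose([
    LoadImaged(keys=["image", "label"]),
    # ... other deterministic transforms ...
    EnsureTyped(keys=["image", "label"]),
    ToDeviced(keys=["image", "label"], device=device),
    # ... random transforms ...
])

# datalist is assumed to be a list of {"image": ..., "label": ...} path dicts
train_ds = CacheDataset(data=datalist, transform=train_transforms, cache_rate=1.0)

# ThreadDataLoader avoids multiprocessing overhead, a good fit for a cached dataset
train_loader = ThreadDataLoader(train_ds, batch_size=1, num_workers=0, shuffle=True)

model = SwinUNETR(img_size=(96, 96, 96), in_channels=1, out_channels=14).to(device)
loss_fn = DiceCELoss(to_onehot_y=True, softmax=True)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scaler = torch.cuda.amp.GradScaler()

for batch in train_loader:
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():  # AMP: mixed-precision forward pass and loss
        loss = loss_fn(model(batch["image"]), batch["label"])
    scaler.scale(loss).backward()  # loss scaling keeps fp16 gradients from underflowing
    scaler.step(optimizer)
    scaler.update()
```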
Here are the results of running larger batch sizes in combination with thread workers. All experiments were run only once. Consistent with the findings at #757, the use of thread workers does not seem to improve speed, but larger batch sizes do seem to help (at batch size = 4), although further investigation is needed.
[Results table: timings for 1-4 thread workers at each batch size]
Further profiling analysis will be conducted in the following week.
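A sweep like this can be scripted; below is a hypothetical harness, where `train_ds` and `run_steps` stand in for the cached dataset and a fixed-step training function (neither name is from the tutorial).

```python
import time

from monai.data import ThreadDataLoader

# Hypothetical grid over batch size and worker count; each cell is timed once,
# matching the single-run protocol described above
for batch_size in (1, 2, 4):
    for workers in (0, 1, 2, 4):
        loader = ThreadDataLoader(train_ds, batch_size=batch_size, num_workers=workers, shuffle=True)
        start = time.perf_counter()
        run_steps(loader, num_steps=120)  # same 120-step budget as the preliminary runs
        print(f"batch={batch_size} workers={workers}: {time.perf_counter() - start:.2f}s")
```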
Hi @yuchen-xu , Thanks for sharing the experiment summary; progress looks good. Thanks.
Just making a note in case it comes up later: the data loading function in this tutorial is load_decathlon_datalist, which returns a list of dicts (not MetaTensor). Hence "image_meta_dict" seems to work fine for the slice visualization in 0.9.1rc2.
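For reference, a small sketch of that loading path; the JSON path here is a placeholder, and the meta-dict access mirrors the note above.

```python
from monai.data import load_decathlon_datalist
from monai.transforms import LoadImaged

# Returns a plain list of dicts of file paths, e.g. [{"image": "...", "label": "..."}, ...]
datalist = load_decathlon_datalist("dataset_0.json", is_segmentation=True, data_list_key="validation")

# After loading, the metadata is still reachable under "image_meta_dict" in 0.9.1rc2,
# which is what the slice-visualization cell relies on
sample = LoadImaged(keys=["image", "label"])(datalist[0])
meta = sample["image_meta_dict"]
```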
Hi @yuchen-xu , That's because we tried to avoid breaking changes in the new release; it's compatible at the API level. Thanks.
I have encountered two major issues when optimizing this tutorial:
I suspect this might have to do with the use of transformers and/or some specific floating point operations, but I am not able to confirm.
Val AMP, ToDeviced placed mid-chain before the random transforms, ThreadDataLoader: 358.67s total, 63s for a typical pass through the data (24 steps).
@tangy5 @ahatamiz - please advise, and we can work together to figure out possible reasons. I can provide more information on the code I used (the 0.9.1 one in the commit above). Thanks!
Update: the second issue does not seem to be isolated; it appears to be a broader issue in 0.9.1 (see the comments in another PR). The AMP training issue is baffling and has been observed in both 0.9.0 and 0.9.1.
After 0.9.1rc4, we are bringing back the EnsureTyped transform with track_meta=False, which will speed up processing considerably. We have a design choice to make for the validation data, though: we show an example validation image both before and after training, and those illustrations need the metadata. So we have two choices: either set track_meta=True for the validation data, or skip EnsureTyped for the validation data. Over 30,000 training steps, both would slow overall training down by less than 1%, and both take essentially the same time. Suggestions?
'explicit is better than implicit', so I think setting track_meta to True is better...
Thanks @wyli , I agree. I will also include a blurb explaining this choice. |
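In transform terms, the chosen option looks roughly like this (the keys and the surrounding transforms are placeholders, not the tutorial's exact chain):

```python
from monai.transforms import Compose, EnsureTyped

# Training chain: no metadata tracking, so downstream transforms see plain tensors
train_transforms = Compose([
    # ... deterministic transforms ...
    EnsureTyped(keys=["image", "label"], track_meta=False),
    # ... random transforms ...
])

# Validation chain: keep metadata (MetaTensor) so the before/after illustrations
# of the example validation image can still read it
val_transforms = Compose([
    # ... deterministic transforms ...
    EnsureTyped(keys=["image", "label"], track_meta=True),
])
```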
Lessons learned about AMP: it seems that all of these are required for AMP to work (at least for Swin UNETR):
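As a generic reference (not the tutorial-specific checklist above), the standard PyTorch AMP pattern covers both training and validation, with sliding-window inference wrapped in autocast on the validation side:

```python
import torch
from monai.inferers import sliding_window_inference

scaler = torch.cuda.amp.GradScaler()

def train_step(model, batch, optimizer, loss_fn):
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():  # mixed-precision forward pass and loss
        loss = loss_fn(model(batch["image"]), batch["label"])
    scaler.scale(loss).backward()  # scaled backward avoids fp16 gradient underflow
    scaler.step(optimizer)
    scaler.update()
    return loss.item()

def val_step(model, batch, roi_size=(96, 96, 96)):
    model.eval()
    with torch.no_grad(), torch.cuda.amp.autocast():  # AMP on the validation side too
        return sliding_window_inference(batch["image"], roi_size, 4, model)
```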
Hi @yuchen-xu , Thanks very much for the optimization of the Swin UNETR tutorial. I have several comments:
What do you think? Thanks.
So for now, let me clarify:
Thanks.
…tutorials into 773-optimize-swin-unetr
Running 10,000 steps:
So I think we should keep both set_track_meta and EnsureTyped for now.
Hi @yuchen-xu , Thanks for the update.
Yes, I am using 0.9.1.
Given what set_track_meta does, I assumed this would be the slowest, but I will try running that.
Because I don't quite understand why both are needed. Thanks.
Following the usage in the fast training tutorial, set_track_meta is applied after CacheDataset and only affects the random transforms, while EnsureTyped(track_meta=False) affects only the non-random transforms, so they handle different sections of the transform chain; see the sketch below. I think trying the last combination will confirm or disprove this idea.
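To make the division concrete, here is a sketch of that layout; `datalist` and the exact transform list are placeholders following the fast training tutorial's pattern.

```python
from monai.data import CacheDataset, set_track_meta
from monai.transforms import Compose, EnsureTyped, LoadImaged, RandCropByPosNegLabeld

train_transforms = Compose([
    LoadImaged(keys=["image", "label"]),
    # ... other deterministic transforms, executed once and cached ...
    EnsureTyped(keys=["image", "label"], track_meta=False),  # closes the non-random section
    RandCropByPosNegLabeld(keys=["image", "label"], label_key="label",
                           spatial_size=(96, 96, 96), pos=1, neg=1, num_samples=4),
    # ... other random transforms, executed at every iteration ...
])

# CacheDataset runs and caches the deterministic prefix of the chain
train_ds = CacheDataset(data=datalist, transform=train_transforms, cache_rate=1.0)

# set_track_meta(False) then disables MetaTensor bookkeeping for the random
# transforms that run at iteration time
set_track_meta(False)
```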
Thanks.
set_track_meta=True without EnsureTyped is 14,442 seconds, much slower than the others, as expected.
Overall it looks good to me; I put some minor comments inline.
Thanks.
Please update the description at the beginning to add the best training time and memory usage on your GPU. Thanks.
…tutorials into 773-optimize-swin-unetr
Sorry, I forgot to put in the final comparison. For 30,000 training steps (1,250 epochs), with validation every 500 steps: These numbers are different from the ones I used on the poster; the poster numbers
* begin work
* checkpoint, added amp, todeviced, threaddataloader
* checkpoint, debug and profiling code
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* checkpoint experiment script
* ready, pending final checks
* applied autofixes
* done
* removed unnecessary files
* updated with set_track_meta
* tmp backup for profiling code
* removed test files
* updated json download
* updated explanation
* final profiling code backup
* remove test files
* updated json link and comments

Co-authored-by: Yuchen Xu <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: Yuchen Xu [email protected]
Fixes #773.
Description
Optimizations to accelerate the Swin UNETR tutorial.
Status
Work in progress
Checks