Skip to content

Commit 2978120

Browse files
author
Yuchen Xu
committed
updated AMP tutorial for 0.9.1rc2
1 parent eb3746c commit 2978120

File tree

1 file changed

+22
-67
lines changed

1 file changed

+22
-67
lines changed

acceleration/automatic_mixed_precision.ipynb

Lines changed: 22 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -28,17 +28,9 @@
2828
},
2929
{
3030
"cell_type": "code",
31-
"execution_count": 2,
31+
"execution_count": null,
3232
"metadata": {},
33-
"outputs": [
34-
{
35-
"name": "stdout",
36-
"output_type": "stream",
37-
"text": [
38-
"Note: you may need to restart the kernel to use updated packages.\n"
39-
]
40-
}
41-
],
33+
"outputs": [],
4234
"source": [
4335
"!python -c \"import monai\" || pip install -q \"monai-weekly[nibabel, tqdm]\"\n",
4436
"!python -c \"import matplotlib\" || pip install -q matplotlib\n",
@@ -54,40 +46,9 @@
5446
},
5547
{
5648
"cell_type": "code",
57-
"execution_count": 1,
49+
"execution_count": null,
5850
"metadata": {},
59-
"outputs": [
60-
{
61-
"name": "stdout",
62-
"output_type": "stream",
63-
"text": [
64-
"MONAI version: 0.6.0rc1+15.gf3d436a0\n",
65-
"Numpy version: 1.20.3\n",
66-
"Pytorch version: 1.9.0a0+c3d40fd\n",
67-
"MONAI flags: HAS_EXT = True, USE_COMPILED = False\n",
68-
"MONAI rev id: f3d436a09deefcf905ece2faeec37f55ab030003\n",
69-
"\n",
70-
"Optional dependencies:\n",
71-
"Pytorch Ignite version: 0.4.5\n",
72-
"Nibabel version: 3.2.1\n",
73-
"scikit-image version: 0.15.0\n",
74-
"Pillow version: 8.2.0\n",
75-
"Tensorboard version: 2.5.0\n",
76-
"gdown version: 3.13.0\n",
77-
"TorchVision version: 0.10.0a0\n",
78-
"ITK version: 5.1.2\n",
79-
"tqdm version: 4.53.0\n",
80-
"lmdb version: 1.2.1\n",
81-
"psutil version: 5.8.0\n",
82-
"pandas version: 1.1.4\n",
83-
"einops version: 0.3.0\n",
84-
"\n",
85-
"For details about installing the optional dependencies, please visit:\n",
86-
" https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies\n",
87-
"\n"
88-
]
89-
}
90-
],
51+
"outputs": [],
9152
"source": [
9253
"# Copyright 2020 MONAI Consortium\n",
9354
"# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
@@ -154,17 +115,11 @@
154115
},
155116
{
156117
"cell_type": "code",
157-
"execution_count": 2,
158-
"metadata": {},
159-
"outputs": [
160-
{
161-
"name": "stdout",
162-
"output_type": "stream",
163-
"text": [
164-
"root dir is: /workspace/data/medical\n"
165-
]
166-
}
167-
],
118+
"execution_count": null,
119+
"metadata": {
120+
"scrolled": true
121+
},
122+
"outputs": [],
168123
"source": [
169124
"directory = os.environ.get(\"MONAI_DATA_DIRECTORY\")\n",
170125
"root_dir = tempfile.mkdtemp() if directory is None else directory\n",
@@ -182,7 +137,7 @@
182137
},
183138
{
184139
"cell_type": "code",
185-
"execution_count": 3,
140+
"execution_count": null,
186141
"metadata": {},
187142
"outputs": [],
188143
"source": [
@@ -204,7 +159,7 @@
204159
},
205160
{
206161
"cell_type": "code",
207-
"execution_count": 4,
162+
"execution_count": null,
208163
"metadata": {},
209164
"outputs": [],
210165
"source": [
@@ -230,7 +185,7 @@
230185
},
231186
{
232187
"cell_type": "code",
233-
"execution_count": 5,
188+
"execution_count": null,
234189
"metadata": {},
235190
"outputs": [],
236191
"source": [
@@ -275,7 +230,6 @@
275230
" fg_indices_key=\"label_fg\",\n",
276231
" bg_indices_key=\"label_bg\",\n",
277232
" ),\n",
278-
" EnsureTyped(keys=[\"image\", \"label\"]),\n",
279233
" ]\n",
280234
" )\n",
281235
" val_transforms = Compose(\n",
@@ -297,7 +251,6 @@
297251
" clip=True,\n",
298252
" ),\n",
299253
" CropForegroundd(keys=[\"image\", \"label\"], source_key=\"image\"),\n",
300-
" EnsureTyped(keys=[\"image\", \"label\"]),\n",
301254
" ]\n",
302255
" )\n",
303256
" return train_transforms, val_transforms"
@@ -313,7 +266,7 @@
313266
},
314267
{
315268
"cell_type": "code",
316-
"execution_count": 6,
269+
"execution_count": null,
317270
"metadata": {
318271
"scrolled": true
319272
},
@@ -352,8 +305,8 @@
352305
" optimizer = torch.optim.Adam(model.parameters(), 1e-4)\n",
353306
" scaler = torch.cuda.amp.GradScaler() if amp else None\n",
354307
"\n",
355-
" post_pred = Compose([EnsureType(), AsDiscrete(argmax=True, to_onehot=2)])\n",
356-
" post_label = Compose([EnsureType(), AsDiscrete(to_onehot=2)])\n",
308+
" post_pred = Compose([AsDiscrete(argmax=True, to_onehot=2)])\n",
309+
" post_label = Compose([AsDiscrete(to_onehot=2)])\n",
357310
"\n",
358311
" dice_metric = DiceMetric(include_background=False, reduction=\"mean\", get_not_nans=False)\n",
359312
"\n",
@@ -422,8 +375,8 @@
422375
" val_outputs = sliding_window_inference(\n",
423376
" val_inputs, roi_size, sw_batch_size, model\n",
424377
" )\n",
425-
" val_outputs = [post_pred(i) for i in decollate_batch(val_outputs)]\n",
426-
" val_labels = [post_label(i) for i in decollate_batch(val_labels)]\n",
378+
" val_outputs = [post_pred(i) for i in decollate_batch(val_outputs.as_tensor())]\n",
379+
" val_labels = [post_label(i) for i in decollate_batch(val_labels.as_tensor())]\n",
427380
" dice_metric(y_pred=val_outputs, y=val_labels)\n",
428381
"\n",
429382
" metric = dice_metric.aggregate().item()\n",
@@ -474,7 +427,9 @@
474427
{
475428
"cell_type": "code",
476429
"execution_count": null,
477-
"metadata": {},
430+
"metadata": {
431+
"scrolled": true
432+
},
478433
"outputs": [],
479434
"source": [
480435
"set_determinism(seed=0)\n",
@@ -856,7 +811,7 @@
856811
],
857812
"metadata": {
858813
"kernelspec": {
859-
"display_name": "Python 3",
814+
"display_name": "Python 3 (ipykernel)",
860815
"language": "python",
861816
"name": "python3"
862817
},
@@ -870,7 +825,7 @@
870825
"name": "python",
871826
"nbconvert_exporter": "python",
872827
"pygments_lexer": "ipython3",
873-
"version": "3.7.10"
828+
"version": "3.8.13"
874829
}
875830
},
876831
"nbformat": 4,

0 commit comments

Comments
 (0)