
Commit 2920135

Add TensorBoardStatsHandler for 2d_classification (#1650)

Fixes #1649

### Description

Add TensorBoardStatsHandler for 2d_classification

### Checks

- [ ] Avoid including large-size files in the PR.
- [ ] Clean up long text outputs from code cells in the notebook.
- [ ] For security purposes, please check the contents and remove any sensitive info such as user names and private key.
- [ ] Ensure (1) hyperlinks and markdown anchors are working (2) use relative paths for tutorial repo files (3) put figure and graphs in the `./figure` folder
- [ ] Notebook runs automatically `./runner.sh -t <path to .ipynb file>`

---------

Signed-off-by: YunLiu <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 5d561cb commit 2920135

File tree

2 files changed: +444, -45 lines changed

2d_classification/mednist_tutorial.ipynb

Lines changed: 53 additions & 45 deletions
Large diffs are not rendered by default.

2d_classification/monai_201.ipynb

Lines changed: 391 additions & 0 deletions
@@ -0,0 +1,391 @@
{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Copyright (c) MONAI Consortium \n",
    "Licensed under the Apache License, Version 2.0 (the \"License\"); \n",
    "you may not use this file except in compliance with the License. \n",
    "You may obtain a copy of the License at \n",
    "&nbsp;&nbsp;&nbsp;&nbsp;http://www.apache.org/licenses/LICENSE-2.0 \n",
    "Unless required by applicable law or agreed to in writing, software \n",
    "distributed under the License is distributed on an \"AS IS\" BASIS, \n",
    "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. \n",
    "See the License for the specific language governing permissions and \n",
    "limitations under the License.\n",
    "\n",
    "# MONAI 201 tutorial\n",
    "\n",
    "In this tutorial we'll revisit the [MONAI 101 notebook](https://github.com/Project-MONAI/tutorials/blob/main/2d_classification/monai_101.ipynb) and add features that reflect best practices, including evaluation during training and TensorBoard logging handlers.\n",
    "\n",
    "This tutorial covers the following steps, each taking only a few lines of code:\n",
    "- Downloading the dataset and pre-processing the data\n",
    "- Defining a DenseNet-121 and running training\n",
    "- Running inference with `SupervisedEvaluator`\n",
    "\n",
    "This tutorial uses about 7GB of GPU memory and takes about 10 minutes to run.\n",
    "\n",
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Project-MONAI/tutorials/blob/main/2d_classification/monai_201.ipynb)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup environment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "!python -c \"import monai\" || pip install -q \"monai-weekly[ignite, tqdm, tensorboard]\""
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import logging\n",
    "import numpy as np\n",
    "import os\n",
    "from pathlib import Path\n",
    "import sys\n",
    "import tempfile\n",
    "import torch\n",
    "import ignite\n",
    "\n",
    "from monai.apps import MedNISTDataset\n",
    "from monai.config import print_config\n",
    "from monai.data import DataLoader\n",
    "from monai.engines import SupervisedTrainer, SupervisedEvaluator\n",
    "from monai.handlers import (\n",
    "    StatsHandler,\n",
    "    TensorBoardStatsHandler,\n",
    "    ValidationHandler,\n",
    "    CheckpointSaver,\n",
    "    CheckpointLoader,\n",
    "    ClassificationSaver,\n",
    ")\n",
    "from monai.handlers.utils import from_engine\n",
    "from monai.inferers import SimpleInferer\n",
    "from monai.networks.nets import densenet121\n",
    "from monai.transforms import LoadImageD, EnsureChannelFirstD, ScaleIntensityD, Compose, AsDiscreted\n",
    "\n",
    "print_config()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup data directory\n",
    "\n",
    "You can specify a directory with the `MONAI_DATA_DIRECTORY` environment variable. \n",
    "This allows you to save results and reuse downloads. \n",
    "If not specified, a temporary directory will be used."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/workspace/Data\n"
     ]
    }
   ],
   "source": [
    "directory = os.environ.get(\"MONAI_DATA_DIRECTORY\")\n",
    "root_dir = tempfile.mkdtemp() if directory is None else directory\n",
    "print(root_dir)"
   ]
  },
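  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you want downloads and results to persist across sessions, you can also set the variable from Python and re-run the cell above. The commented-out cell below is only a minimal sketch; the path is a placeholder to adjust for your system."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional sketch: point MONAI_DATA_DIRECTORY at a writable path (placeholder path below).\n",
    "# os.environ[\"MONAI_DATA_DIRECTORY\"] = \"/path/to/monai_data\"\n",
    "# root_dir = os.environ[\"MONAI_DATA_DIRECTORY\"]"
   ]
  },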
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Use MONAI transforms to preprocess data\n",
    "\n",
    "We'll first prepare the data much as in the previous tutorial, using the same transforms and dataset:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2024-02-27 08:31:31,955 - INFO - Verified 'MedNIST.tar.gz', md5: 0bc7306e7427e00ad1c5526a6677552d.\n",
      "2024-02-27 08:31:31,955 - INFO - File exists: /workspace/Data/MedNIST.tar.gz, skipped downloading.\n",
      "2024-02-27 08:31:31,956 - INFO - Non-empty folder exists in /workspace/Data/MedNIST, skipped extracting.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading dataset: 0%| | 0/47164 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading dataset: 100%|██████████| 47164/47164 [00:19<00:00, 2393.21it/s]\n",
      "Loading dataset: 100%|██████████| 5895/5895 [00:02<00:00, 2465.05it/s]\n"
     ]
    }
   ],
   "source": [
    "transform = Compose(\n",
    "    [\n",
    "        LoadImageD(keys=\"image\", image_only=True),\n",
    "        EnsureChannelFirstD(keys=\"image\"),\n",
    "        ScaleIntensityD(keys=\"image\"),\n",
    "    ]\n",
    ")\n",
    "\n",
    "# If you use the MedNIST dataset, please acknowledge the source.\n",
    "dataset = MedNISTDataset(root_dir=root_dir, transform=transform, section=\"training\", download=True)\n",
    "valdata = MedNISTDataset(root_dir=root_dir, transform=transform, section=\"validation\", download=False)"
   ]
  },
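  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick check of the preprocessing, we can look at one training sample. This is only a minimal sketch and assumes `matplotlib` is installed in your environment."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: display one preprocessed sample (assumes matplotlib is available).\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "sample = dataset[0]  # a dictionary with an \"image\" tensor and an integer \"label\"\n",
    "plt.imshow(sample[\"image\"][0], cmap=\"gray\")\n",
    "plt.title(f\"label index: {sample['label']}\")\n",
    "plt.show()"
   ]
  },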
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define a network and a supervised trainer\n",
    "\n",
    "For training we use the same elements as before, but extend the `SupervisedTrainer`'s `train_handlers` so that training statistics are written to TensorBoard.\n",
    "We also introduce a `SupervisedEvaluator` object that runs validation during training; together with a `TensorBoardStatsHandler`, it logs the validation metrics to TensorBoard."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "max_epochs = 5\n",
    "save_interval = 2\n",
    "out_dir = \"./eval\"\n",
    "model = densenet121(spatial_dims=2, in_channels=1, out_channels=6).to(\"cuda:0\")\n",
    "\n",
    "logging.basicConfig(stream=sys.stdout, level=logging.INFO)\n",
    "\n",
    "# validation engine: computes accuracy on the validation set and logs it to the console and TensorBoard\n",
    "evaluator = SupervisedEvaluator(\n",
    "    device=torch.device(\"cuda:0\"),\n",
    "    val_data_loader=DataLoader(valdata, batch_size=512, shuffle=False, num_workers=4),\n",
    "    network=model,\n",
    "    inferer=SimpleInferer(),\n",
    "    key_val_metric={\"val_acc\": ignite.metrics.Accuracy(from_engine([\"pred\", \"label\"]))},\n",
    "    val_handlers=[StatsHandler(iteration_log=False), TensorBoardStatsHandler(iteration_log=False)],\n",
    ")\n",
    "\n",
    "# training engine: ValidationHandler runs the evaluator once per epoch, CheckpointSaver saves the model to out_dir,\n",
    "# and StatsHandler/TensorBoardStatsHandler log the training loss\n",
    "trainer = SupervisedTrainer(\n",
    "    device=torch.device(\"cuda:0\"),\n",
    "    max_epochs=max_epochs,\n",
    "    train_data_loader=DataLoader(dataset, batch_size=512, shuffle=True, num_workers=4),\n",
    "    network=model,\n",
    "    optimizer=torch.optim.Adam(model.parameters(), lr=1e-5),\n",
    "    loss_function=torch.nn.CrossEntropyLoss(),\n",
    "    inferer=SimpleInferer(),\n",
    "    train_handlers=[\n",
    "        ValidationHandler(validator=evaluator, epoch_level=True, interval=1),\n",
    "        CheckpointSaver(\n",
    "            save_dir=out_dir,\n",
    "            save_dict={\"model\": model},\n",
    "            save_interval=save_interval,\n",
    "            save_final=True,\n",
    "            final_filename=\"checkpoint.pt\",\n",
    "        ),\n",
    "        StatsHandler(),\n",
    "        TensorBoardStatsHandler(tag_name=\"train_loss\", output_transform=from_engine([\"loss\"], first=True)),\n",
    "    ],\n",
    ")"
   ]
  },
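  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "By default the TensorBoard handlers write their event files to `./runs`, which is the directory we point TensorBoard at below. The commented-out cell that follows is only a sketch of how the log location could be customized, assuming the `summary_writer`/`log_dir` arguments of `TensorBoardStatsHandler` in your MONAI version."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch (assumed arguments): share one SummaryWriter between the training and validation handlers\n",
    "# so all curves are written to a single custom log directory.\n",
    "# from torch.utils.tensorboard import SummaryWriter\n",
    "#\n",
    "# writer = SummaryWriter(log_dir=\"./runs/monai_201\")\n",
    "# train_tb_handler = TensorBoardStatsHandler(\n",
    "#     summary_writer=writer, tag_name=\"train_loss\", output_transform=from_engine([\"loss\"], first=True)\n",
    "# )\n",
    "# val_tb_handler = TensorBoardStatsHandler(summary_writer=writer, iteration_log=False)"
   ]
  },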
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run the training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.run()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## View training in TensorBoard\n",
    "\n",
    "Uncomment the following cell to load the TensorBoard results."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %load_ext tensorboard\n",
    "# %tensorboard --logdir ./runs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Inference\n",
    "\n",
    "The first thing to do is to prepare the test dataset:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset_dir = Path(root_dir, \"MedNIST\")\n",
    "# MedNIST stores each class in its own subfolder, so the folder names double as class names\n",
    "class_names = sorted(f\"{x.name}\" for x in dataset_dir.iterdir() if x.is_dir())\n",
    "testdata = MedNISTDataset(root_dir=root_dir, transform=transform, section=\"test\", download=False, runtime_cache=True)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we create a `SupervisedEvaluator` that runs inference over the test dataset and persists the results to a CSV file. Its `val_handlers` load the saved checkpoint (raising an error if the file is missing) and save the classification results."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:ignite.engine.engine.SupervisedEvaluator:Engine run resuming from iteration 0, epoch 0 until 1 epochs\n",
      "INFO:ignite.engine.engine.SupervisedEvaluator:Restored all variables from ./eval/checkpoint.pt\n",
      "INFO:ignite.engine.engine.SupervisedEvaluator:Epoch[1] Complete. Time taken: 00:01:24.338\n",
      "INFO:ignite.engine.engine.SupervisedEvaluator:Engine run complete. Time taken: 00:01:24.390\n"
     ]
    }
   ],
   "source": [
    "evaluator = SupervisedEvaluator(\n",
    "    device=torch.device(\"cuda:0\"),\n",
    "    val_data_loader=DataLoader(testdata, batch_size=1, num_workers=0),\n",
    "    network=model,\n",
    "    inferer=SimpleInferer(),\n",
    "    # argmax converts the network output into a discrete class index\n",
    "    postprocessing=AsDiscreted(keys=\"pred\", argmax=True),\n",
    "    val_handlers=[\n",
    "        # restore the weights saved by CheckpointSaver during training\n",
    "        CheckpointLoader(load_path=f\"{out_dir}/checkpoint.pt\", load_dict={\"model\": model}),\n",
    "        # write the predicted class index of each image to a CSV file\n",
    "        ClassificationSaver(\n",
    "            batch_transform=lambda batch: batch[0][\"image\"].meta, output_transform=from_engine([\"pred\"])\n",
    "        ),\n",
    "    ],\n",
    ")\n",
    "\n",
    "evaluator.run()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "By default, the inference results are stored in a file named \"predictions.csv\"; the output location and filename can be customized in the `ClassificationSaver` handler, as sketched below.\n",
    "In the output, the second column is the predicted class index. For a more readable interpretation, we use these values as indices into the list of class names defined earlier."
   ]
  },
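  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The commented-out cell below is only a sketch of that customization; the `output_dir`/`filename` argument names are assumed from the handler's options and the chosen names are placeholders."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch (assumed argument names, placeholder filename):\n",
    "# ClassificationSaver(\n",
    "#     output_dir=out_dir,\n",
    "#     filename=\"test_predictions.csv\",\n",
    "#     batch_transform=lambda batch: batch[0][\"image\"].meta,\n",
    "#     output_transform=from_engine([\"pred\"]),\n",
    "# )"
   ]
  },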
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/workspace/Data/MedNIST/AbdomenCT/006070.jpeg AbdomenCT\n",
      "/workspace/Data/MedNIST/BreastMRI/006574.jpeg BreastMRI\n",
      "/workspace/Data/MedNIST/ChestCT/009858.jpeg ChestCT\n",
      "/workspace/Data/MedNIST/CXR/007398.jpeg CXR\n",
      "/workspace/Data/MedNIST/Hand/005663.jpeg Hand\n",
      "/workspace/Data/MedNIST/HeadCT/006896.jpeg HeadCT\n",
      "/workspace/Data/MedNIST/HeadCT/007179.jpeg HeadCT\n",
      "/workspace/Data/MedNIST/CXR/001190.jpeg CXR\n",
      "/workspace/Data/MedNIST/ChestCT/005138.jpeg ChestCT\n",
      "/workspace/Data/MedNIST/BreastMRI/000023.jpeg BreastMRI\n"
     ]
    }
   ],
   "source": [
    "max_items_to_print = 10\n",
    "for fn, idx in np.loadtxt(\"./predictions.csv\", delimiter=\",\", dtype=str):\n",
    "    print(fn, class_names[int(float(idx))])\n",
    "    max_items_to_print -= 1\n",
    "    if max_items_to_print == 0:\n",
    "        break"
   ]
  },
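  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As an optional sanity check, we can estimate the overall test accuracy from \"predictions.csv\". This is only a minimal sketch; it relies on the MedNIST layout used above, where each image sits in a folder named after its class."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: compare each prediction with the class implied by the file's parent folder name.\n",
    "correct = total = 0\n",
    "for fn, idx in np.loadtxt(\"./predictions.csv\", delimiter=\",\", dtype=str):\n",
    "    true_class = Path(fn).parent.name  # per-class folder name\n",
    "    pred_class = class_names[int(float(idx))]  # second column holds the predicted class index\n",
    "    correct += int(pred_class == true_class)\n",
    "    total += 1\n",
    "print(f\"Estimated test accuracy: {correct / total:.3f}\")"
   ]
  }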
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
