|
53 | 53 | },
|
54 | 54 | {
|
55 | 55 | "cell_type": "code",
|
56 |
| - "execution_count": null, |
| 56 | + "execution_count": 1, |
57 | 57 | "metadata": {},
|
58 | 58 | "outputs": [],
|
59 | 59 | "source": [
|
|
69 | 69 | },
|
70 | 70 | {
|
71 | 71 | "cell_type": "code",
|
72 |
| - "execution_count": null, |
| 72 | + "execution_count": 2, |
73 | 73 | "metadata": {},
|
74 | 74 | "outputs": [],
|
75 | 75 | "source": [
|
|
78 | 78 | },
|
79 | 79 | {
|
80 | 80 | "cell_type": "code",
|
81 |
| - "execution_count": null, |
| 81 | + "execution_count": 3, |
82 | 82 | "metadata": {
|
83 | 83 | "id": "KvbbZuhmquRR"
|
84 | 84 | },
|
|
92 | 92 | },
|
93 | 93 | {
|
94 | 94 | "cell_type": "code",
|
95 |
| - "execution_count": null, |
| 95 | + "execution_count": 4, |
96 | 96 | "metadata": {
|
97 | 97 | "id": "gduPdIturUIB"
|
98 | 98 | },
|
99 | 99 | "outputs": [],
|
100 | 100 | "source": [
|
101 |
| - "from pathlib import Path\n", |
102 | 101 | "from datetime import datetime\n",
|
| 102 | + "import os\n", |
| 103 | + "import tempfile\n", |
| 104 | + "from glob import glob\n", |
103 | 105 | "\n",
|
104 | 106 | "import torch\n",
|
105 | 107 | "from torch.utils.data import random_split, DataLoader\n",
|
|
117 | 119 | "%load_ext tensorboard"
|
118 | 120 | ]
|
119 | 121 | },
|
| 122 | + { |
| 123 | + "cell_type": "markdown", |
| 124 | + "metadata": {}, |
| 125 | + "source": [ |
| 126 | + "## Setup data directory\n", |
| 127 | + "\n", |
| 128 | + "You can specify a directory with the `MONAI_DATA_DIRECTORY` environment variable. \n", |
| 129 | + "This allows you to save results and reuse downloads. \n", |
| 130 | + "If not specified a temporary directory will be used." |
| 131 | + ] |
| 132 | + }, |
| 133 | + { |
| 134 | + "cell_type": "code", |
| 135 | + "execution_count": 5, |
| 136 | + "metadata": {}, |
| 137 | + "outputs": [ |
| 138 | + { |
| 139 | + "name": "stdout", |
| 140 | + "output_type": "stream", |
| 141 | + "text": [ |
| 142 | + "/mnt/data/rbrown/Documents/Data/MONAI\n" |
| 143 | + ] |
| 144 | + } |
| 145 | + ], |
| 146 | + "source": [ |
| 147 | + "directory = os.environ.get(\"MONAI_DATA_DIRECTORY\")\n", |
| 148 | + "root_dir = tempfile.mkdtemp() if directory is None else directory\n", |
| 149 | + "print(root_dir)" |
| 150 | + ] |
| 151 | + }, |
120 | 152 | {
|
121 | 153 | "cell_type": "markdown",
|
122 | 154 | "metadata": {
|
|
145 | 177 | },
|
146 | 178 | {
|
147 | 179 | "cell_type": "code",
|
148 |
| - "execution_count": null, |
| 180 | + "execution_count": 6, |
149 | 181 | "metadata": {
|
150 | 182 | "id": "KuhTaRl3vf37"
|
151 | 183 | },
|
152 | 184 | "outputs": [],
|
153 | 185 | "source": [
|
154 |
| - "\n", |
155 |
| - "\n", |
156 | 186 | "class MedicalDecathlonDataModule(pl.LightningDataModule):\n",
|
157 | 187 | " def __init__(self, task, batch_size, train_val_ratio):\n",
|
158 | 188 | " super().__init__()\n",
|
159 | 189 | " self.task = task\n",
|
160 | 190 | " self.batch_size = batch_size\n",
|
161 |
| - " self.dataset_dir = Path(task)\n", |
| 191 | + " self.base_dir = root_dir\n", |
| 192 | + " self.dataset_dir = os.path.join(root_dir, task)\n", |
162 | 193 | " self.train_val_ratio = train_val_ratio\n",
|
163 | 194 | " self.subjects = None\n",
|
164 | 195 | " self.test_subjects = None\n",
|
|
175 | 206 | " return shapes.max(axis=0)\n",
|
176 | 207 | "\n",
|
177 | 208 | " def download_data(self):\n",
|
178 |
| - " if not self.dataset_dir.is_dir():\n", |
179 |
| - " url = 'https://msd-for-monai.s3-us-west-2.amazonaws.com/Task04_Hippocampus.tar'\n", |
180 |
| - " monai.apps.download_and_extract(url=url, output_dir=\".\")\n", |
| 209 | + " if not os.path.isdir(self.dataset_dir):\n", |
| 210 | + " url = f'https://msd-for-monai.s3-us-west-2.amazonaws.com/{self.task}.tar'\n", |
| 211 | + " monai.apps.download_and_extract(url=url, output_dir=self.base_dir)\n", |
181 | 212 | "\n",
|
182 |
| - " def get_niis(d):\n", |
183 |
| - " return sorted(p for p in d.glob('*.nii*') if not p.name.startswith('.'))\n", |
184 |
| - "\n", |
185 |
| - " image_training_paths = get_niis(self.dataset_dir / 'imagesTr')\n", |
186 |
| - " label_training_paths = get_niis(self.dataset_dir / 'labelsTr')\n", |
187 |
| - " image_test_paths = get_niis(self.dataset_dir / 'imagesTs')\n", |
| 213 | + " image_training_paths = sorted(glob(os.path.join(self.dataset_dir, 'imagesTr', \"*.nii*\")))\n", |
| 214 | + " label_training_paths = sorted(glob(os.path.join(self.dataset_dir, 'labelsTr', \"*.nii*\")))\n", |
| 215 | + " image_test_paths = sorted(glob(os.path.join(self.dataset_dir, 'imagesTs', \"*.nii*\")))\n", |
188 | 216 | " return image_training_paths, label_training_paths, image_test_paths\n",
|
189 | 217 | "\n",
|
190 | 218 | " def prepare_data(self):\n",
|
|
260 | 288 | },
|
261 | 289 | {
|
262 | 290 | "cell_type": "code",
|
263 |
| - "execution_count": null, |
| 291 | + "execution_count": 7, |
264 | 292 | "metadata": {
|
265 | 293 | "id": "hcHf9w2nLfyC"
|
266 | 294 | },
|
|
284 | 312 | },
|
285 | 313 | {
|
286 | 314 | "cell_type": "code",
|
287 |
| - "execution_count": null, |
| 315 | + "execution_count": 8, |
288 | 316 | "metadata": {
|
289 | 317 | "colab": {
|
290 | 318 | "base_uri": "https://localhost:8080/"
|
|
293 | 321 | "outputId": "7cb39051-4c26-4811-b838-8a5e938e53a3"
|
294 | 322 | },
|
295 | 323 | "outputs": [
|
296 |
| - { |
297 |
| - "name": "stderr", |
298 |
| - "output_type": "stream", |
299 |
| - "text": [ |
300 |
| - "Downloading...\n", |
301 |
| - "From: https://drive.google.com/uc?id=1RzPB1_bqzQhlWvU-YGvZzhx2omcDh38C\n", |
302 |
| - "To: /content/Task04_Hippocampus.tar\n", |
303 |
| - "28.4MB [00:00, 82.8MB/s]\n" |
304 |
| - ] |
305 |
| - }, |
306 | 324 | {
|
307 | 325 | "name": "stdout",
|
308 | 326 | "output_type": "stream",
|
|
341 | 359 | },
|
342 | 360 | {
|
343 | 361 | "cell_type": "code",
|
344 |
| - "execution_count": null, |
| 362 | + "execution_count": 9, |
345 | 363 | "metadata": {
|
346 | 364 | "id": "1Ov3H12p6Qx1"
|
347 | 365 | },
|
|
395 | 413 | },
|
396 | 414 | {
|
397 | 415 | "cell_type": "code",
|
398 |
| - "execution_count": null, |
| 416 | + "execution_count": 10, |
399 | 417 | "metadata": {
|
400 | 418 | "colab": {
|
401 | 419 | "base_uri": "https://localhost:8080/"
|
|
0 commit comments