|
39 | 39 | " \n",
|
40 | 40 | " (Replace X with the major version of the CUDA toolkit, and Y with the minor version.)\n",
|
41 | 41 | "\n",
|
42 | | - "- `GDSDataset` inherited from `PersistentDataset`.\n",
43 | | - "\n",
44 | | - "  In this tutorial, we have implemented a `GDSDataset` that inherits from `PersistentDataset`. We have re-implemented the `_cachecheck` method to create and save cache using GDS.\n",
45 | | - "\n",
46 | 42 | "- A simple demo comparing the time taken with and without GDS.\n",
|
47 | 43 | "\n",
|
48 | 44 | " In this tutorial, we are creating a conda environment to install `kvikio`, which provides a Python API for GDS. To install `kvikio` using other methods, refer to https://github.com/rapidsai/kvikio#install.\n",
|
|
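The setup text above relies on `kvikio` being installed in the active conda environment. A minimal sanity check before running the notebook cells in this diff (this only verifies that the Python API is importable; it does not confirm that GDS itself is enabled on the machine):

```python
# check that the kvikio Python API referenced above is importable in this environment
import kvikio

print("kvikio version:", kvikio.__version__)
```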
79 | 75 | },
|
80 | 76 | {
|
81 | 77 | "cell_type": "code",
|
82 | | - "execution_count": 1,
| 78 | + "execution_count": null,
83 | 79 | "metadata": {},
|
84 | 80 | "outputs": [],
|
85 | 81 | "source": [
|
86 | 82 | "import os\n",
|
87 | 83 | "import time\n",
|
88 | | - "import cupy\n",
89 | 84 | "import torch\n",
|
90 | 85 | "import shutil\n",
|
91 | 86 | "import tempfile\n",
|
92 | | - "import numpy as np\n",
93 | | - "from typing import Any\n",
94 | | - "from pathlib import Path\n",
95 | | - "from copy import deepcopy\n",
96 | | - "from collections.abc import Callable, Sequence\n",
97 | | - "from kvikio.numpy import fromfile, tofile\n",
98 | 87 | "\n",
|
99 | 88 | "import monai\n",
|
100 | 89 | "import monai.transforms as mt\n",
|
101 | 90 | "from monai.config import print_config\n",
|
102 | | - "from monai.data.utils import pickle_hashing, SUPPORTED_PICKLE_MOD\n",
103 | | - "from monai.utils import convert_to_tensor, set_determinism, look_up_option\n",
| 91 | + "from monai.data.dataset import GDSDataset\n",
| 92 | + "from monai.utils import set_determinism\n",
104 | 93 | "\n",
|
105 | 94 | "print_config()"
|
106 | 95 | ]
|
|
135 | 124 | "print(root_dir)"
|
136 | 125 | ]
|
137 | 126 | },
|
138 | | - {
139 | | - "cell_type": "markdown",
140 | | - "metadata": {},
141 | | - "source": [
142 | | - "## GDSDataset"
143 | | - ]
144 | | - },
145 | | - {
146 | | - "cell_type": "code",
147 | | - "execution_count": 3,
148 | | - "metadata": {},
149 | | - "outputs": [],
150 | | - "source": [
151 | | - "class GDSDataset(monai.data.PersistentDataset):\n",
152 | | - "    def __init__(\n",
153 | | - "        self,\n",
154 | | - "        data: Sequence,\n",
155 | | - "        transform: Sequence[Callable] | Callable,\n",
156 | | - "        cache_dir: Path | str | None,\n",
157 | | - "        hash_func: Callable[..., bytes] = pickle_hashing,\n",
158 | | - "        hash_transform: Callable[..., bytes] | None = None,\n",
159 | | - "        reset_ops_id: bool = True,\n",
160 | | - "        device: int = None,\n",
161 | | - "        **kwargs: Any,\n",
162 | | - "    ) -> None:\n",
163 | | - "        super().__init__(\n",
164 | | - "            data=data,\n",
165 | | - "            transform=transform,\n",
166 | | - "            cache_dir=cache_dir,\n",
167 | | - "            hash_func=hash_func,\n",
168 | | - "            hash_transform=hash_transform,\n",
169 | | - "            reset_ops_id=reset_ops_id,\n",
170 | | - "            **kwargs,\n",
171 | | - "        )\n",
172 | | - "        self.device = device\n",
173 | | - "\n",
174 | | - "    def _cachecheck(self, item_transformed):\n",
175 | | - "        \"\"\"given the input dictionary ``item_transformed``, return a transformed version of it\"\"\"\n",
176 | | - "        hashfile = None\n",
177 | | - "        # compute a cache id\n",
178 | | - "        if self.cache_dir is not None:\n",
179 | | - "            data_item_md5 = self.hash_func(item_transformed).decode(\"utf-8\")\n",
180 | | - "            data_item_md5 += self.transform_hash\n",
181 | | - "            hashfile = self.cache_dir / f\"{data_item_md5}.pt\"\n",
182 | | - "\n",
183 | | - "        if hashfile is not None and hashfile.is_file():  # cache hit\n",
184 | | - "            with cupy.cuda.Device(self.device):\n",
185 | | - "                item = {}\n",
186 | | - "                for k in item_transformed:\n",
187 | | - "                    meta_k = torch.load(self.cache_dir / f\"{hashfile.name}-{k}-meta\")\n",
188 | | - "                    item[k] = fromfile(f\"{hashfile}-{k}\", dtype=np.float32, like=cupy.empty(()))\n",
189 | | - "                    item[k] = convert_to_tensor(item[k].reshape(meta_k[\"shape\"]), device=f\"cuda:{self.device}\")\n",
190 | | - "                    item[f\"{k}_meta_dict\"] = meta_k\n",
191 | | - "                return item\n",
192 | | - "\n",
193 | | - "        # create new cache\n",
194 | | - "        _item_transformed = self._pre_transform(deepcopy(item_transformed))  # keep the original hashed\n",
195 | | - "        if hashfile is None:\n",
196 | | - "            return _item_transformed\n",
197 | | - "\n",
198 | | - "        for k in _item_transformed:  # {'image': ..., 'label': ...}\n",
199 | | - "            _item_transformed_meta = _item_transformed[k].meta\n",
200 | | - "            _item_transformed_data = _item_transformed[k].array\n",
201 | | - "            _item_transformed_meta[\"shape\"] = _item_transformed_data.shape\n",
202 | | - "            tofile(_item_transformed_data, f\"{hashfile}-{k}\")\n",
203 | | - "            try:\n",
204 | | - "                # NOTE: Writing to a temporary directory and then using a nearly atomic rename operation\n",
205 | | - "                # to make the cache more robust to manual killing of parent process\n",
206 | | - "                # which may leave partially written cache files in an incomplete state\n",
207 | | - "                with tempfile.TemporaryDirectory() as tmpdirname:\n",
208 | | - "                    meta_hash_file_name = f\"{hashfile.name}-{k}-meta\"\n",
209 | | - "                    meta_hash_file = self.cache_dir / meta_hash_file_name\n",
210 | | - "                    temp_hash_file = Path(tmpdirname) / meta_hash_file_name\n",
211 | | - "                    torch.save(\n",
212 | | - "                        obj=_item_transformed_meta,\n",
213 | | - "                        f=temp_hash_file,\n",
214 | | - "                        pickle_module=look_up_option(self.pickle_module, SUPPORTED_PICKLE_MOD),\n",
215 | | - "                        pickle_protocol=self.pickle_protocol,\n",
216 | | - "                    )\n",
217 | | - "                    if temp_hash_file.is_file() and not meta_hash_file.is_file():\n",
218 | | - "                        # On Unix, if target exists and is a file, it will be replaced silently if the\n",
219 | | - "                        # user has permission.\n",
220 | | - "                        # for more details: https://docs.python.org/3/library/shutil.html#shutil.move.\n",
221 | | - "                        try:\n",
222 | | - "                            shutil.move(str(temp_hash_file), meta_hash_file)\n",
223 | | - "                        except FileExistsError:\n",
224 | | - "                            pass\n",
225 | | - "            except PermissionError:  # project-monai/monai issue #3613\n",
226 | | - "                pass\n",
227 | | - "        open(hashfile, \"a\").close()  # store cacheid\n",
228 | | - "\n",
229 | | - "        return _item_transformed"
230 | | - ]
231 | | - },
232 | 127 | {
|
233 | 128 | "cell_type": "markdown",
|
234 | 129 | "metadata": {},
|
|
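The removed cells above contained the notebook's inline `GDSDataset` implementation; with this change the class comes from MONAI itself via the `from monai.data.dataset import GDSDataset` import added earlier in the diff. A minimal usage sketch of the imported class, using hypothetical file names and cache directory and only the constructor arguments visible in the removed code:

```python
import monai.transforms as mt
from monai.data.dataset import GDSDataset

# hypothetical data dicts and cache location, for illustration only
train_files = [{"image": "sample_0.nii.gz"}, {"image": "sample_1.nii.gz"}]
transform = mt.Compose([mt.LoadImaged(keys="image"), mt.EnsureChannelFirstd(keys="image")])

# device selects the GPU that the cached arrays are read onto via kvikio
train_ds = GDSDataset(data=train_files, transform=transform, cache_dir="./gds_cache", device=0)
```

The first pass over such a dataset writes the transformed arrays to `cache_dir`; later passes read them back directly into GPU memory, which is what the timing cells below measure.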
245 | 140 | },
|
246 | 141 | {
|
247 | 142 | "cell_type": "code",
|
248 | | - "execution_count": 4,
| 143 | + "execution_count": 3,
249 | 144 | "metadata": {},
|
250 | 145 | "outputs": [
|
251 | 146 | {
|
252 | 147 | "name": "stdout",
|
253 | 148 | "output_type": "stream",
|
254 | 149 | "text": [
|
255 | | - "2023-07-12 09:26:17,878 - INFO - Expected md5 is None, skip md5 check for file samples.zip.\n",
256 | | - "2023-07-12 09:26:17,878 - INFO - File exists: samples.zip, skipped downloading.\n",
257 | | - "2023-07-12 09:26:17,879 - INFO - Writing into directory: /raid/yliu/test_tutorial.\n"
| 150 | + "2023-07-27 07:59:12,054 - INFO - Expected md5 is None, skip md5 check for file samples.zip.\n",
| 151 | + "2023-07-27 07:59:12,055 - INFO - File exists: samples.zip, skipped downloading.\n",
| 152 | + "2023-07-27 07:59:12,056 - INFO - Writing into directory: /raid/yliu/test_tutorial.\n"
258 | 153 | ]
|
259 | 154 | }
|
260 | 155 | ],
|
|
283 | 178 | },
|
284 | 179 | {
|
285 | 180 | "cell_type": "code",
|
286 | | - "execution_count": 5,
| 181 | + "execution_count": 4,
287 | 182 | "metadata": {},
|
288 | 183 | "outputs": [],
|
289 | 184 | "source": [
|
|
299 | 194 | },
|
300 | 195 | {
|
301 | 196 | "cell_type": "code",
|
302 | | - "execution_count": 6,
| 197 | + "execution_count": 5,
303 | 198 | "metadata": {},
|
304 | 199 | "outputs": [],
|
305 | 200 | "source": [
|
|
332 | 227 | },
|
333 | 228 | {
|
334 | 229 | "cell_type": "code",
|
335 | | - "execution_count": 7,
| 230 | + "execution_count": 6,
336 | 231 | "metadata": {},
|
337 | 232 | "outputs": [
|
338 | 233 | {
|
339 | 234 | "name": "stdout",
|
340 | 235 | "output_type": "stream",
|
341 | 236 | "text": [
|
342 | | - "epoch0 time 19.746733903884888\n",
343 | | - "epoch1 time 0.9976603984832764\n",
344 | | - "epoch2 time 0.982248067855835\n",
345 | | - "epoch3 time 0.9838874340057373\n",
346 | | - "epoch4 time 0.9793403148651123\n",
347 | | - "total time 23.69102692604065\n"
| 237 | + "epoch0 time 20.148560762405396\n",
| 238 | + "epoch1 time 0.9835140705108643\n",
| 239 | + "epoch2 time 0.9708101749420166\n",
| 240 | + "epoch3 time 0.9711742401123047\n",
| 241 | + "epoch4 time 0.9711296558380127\n",
| 242 | + "total time 24.04619812965393\n"
348 | 243 | ]
|
349 | 244 | }
|
350 | 245 | ],
|
|
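The replaced numbers above come from one of the two timing runs the notebook compares (with and without GDS; the second set appears further down in this diff). A rough sketch of the per-epoch timing pattern that produces output in this format, using the hypothetical `train_ds` from the earlier sketch (the first epoch is slow because it also builds the cache):

```python
import time

from monai.data import DataLoader

loader = DataLoader(train_ds, batch_size=1, num_workers=0)  # hypothetical loader

total_start = time.time()
for epoch in range(5):
    epoch_start = time.time()
    for _ in loader:
        pass  # iterate once per epoch so every item is transformed or read from cache
    print(f"epoch{epoch} time", time.time() - epoch_start)
print("total time", time.time() - total_start)
```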
372 | 267 | },
|
373 | 268 | {
|
374 | 269 | "cell_type": "code",
|
375 | | - "execution_count": 8,
| 270 | + "execution_count": 7,
376 | 271 | "metadata": {},
|
377 | 272 | "outputs": [
|
378 | 273 | {
|
379 | 274 | "name": "stdout",
|
380 | 275 | "output_type": "stream",
|
381 | 276 | "text": [
|
382 | | - "epoch0 time 21.206729650497437\n",
383 | | - "epoch1 time 1.510526180267334\n",
384 | | - "epoch2 time 1.588256597518921\n",
385 | | - "epoch3 time 1.4431262016296387\n",
386 | | - "epoch4 time 1.4594802856445312\n",
387 | | - "total time 27.20927882194519\n"
| 277 | + "epoch0 time 21.170511722564697\n",
| 278 | + "epoch1 time 1.482978105545044\n",
| 279 | + "epoch2 time 1.5378782749176025\n",
| 280 | + "epoch3 time 1.4499244689941406\n",
| 281 | + "epoch4 time 1.4379286766052246\n",
| 282 | + "total time 27.08065962791443\n"
388 | 283 | ]
|
389 | 284 | }
|
390 | 285 | ],
|
|