Commit ddf61ab

remove the draft version of the GDSDataset (#1473)
### Description

- After Project-MONAI/MONAI#6778, `GDSDataset` has been added to MONAI core, so the draft version in this tutorial is removed.
- Update the results based on the new version.

### Checks

- [x] Avoid including large-size files in the PR.
- [x] Clean up long text outputs from code cells in the notebook.
- [ ] For security purposes, please check the contents and remove any sensitive info such as user names and private keys.
- [ ] Ensure (1) hyperlinks and markdown anchors are working, (2) relative paths are used for tutorial repo files, and (3) figures and graphs are placed in the `./figure` folder.
- [ ] Notebook runs automatically via `./runner.sh -t <path to .ipynb file>`.

Signed-off-by: KumoLiu <[email protected]>
1 parent d0de28e commit ddf61ab
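
For context, below is a minimal sketch of how the tutorial now uses the core `GDSDataset` instead of the removed draft class. It assumes a MONAI build that includes Project-MONAI/MONAI#6778, `kvikio` installed, and a GDS-capable GPU; the file name, transforms, and cache directory are illustrative placeholders, not taken from the notebook.

```python
# Minimal usage sketch (assumptions: MONAI build includes GDSDataset; kvikio is
# installed; "sample_0.nii.gz" and "./gds_cache" are illustrative placeholders).
import monai.transforms as mt
from monai.data.dataset import GDSDataset  # core implementation, replaces the draft class

# a small dictionary-style transform chain, matching the tutorial's dict-based data
transform = mt.Compose(
    [
        mt.LoadImaged(keys="image"),
        mt.EnsureChannelFirstd(keys="image"),
    ]
)

dataset = GDSDataset(
    data=[{"image": "sample_0.nii.gz"}],  # hypothetical input file
    transform=transform,
    cache_dir="./gds_cache",  # cache written and read via GPUDirect Storage
    device=0,  # GPU index used for the GDS cache
)
first_item = dataset[0]  # first access builds the cache; later epochs read it via GDS
```

After the first epoch builds the cache, subsequent accesses read the cached arrays through GDS, which is what the updated timing cells in the diff below measure.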

File tree: 1 file changed (+23 -128 lines)


modules/GDS_dataset.ipynb

Lines changed: 23 additions & 128 deletions
@@ -39,10 +39,6 @@
  " \n",
  " (Replace X with the major version of the CUDA toolkit, and Y with the minor version.)\n",
  "\n",
- "- `GDSDataset` inherited from `PersistentDataset`.\n",
- "\n",
- " In this tutorial, we have implemented a `GDSDataset` that inherits from `PersistentDataset`. We have re-implemented the `_cachecheck` method to create and save cache using GDS.\n",
- "\n",
  "- A simple demo comparing the time taken with and without GDS.\n",
  "\n",
  " In this tutorial, we are creating a conda environment to install `kvikio`, which provides a Python API for GDS. To install `kvikio` using other methods, refer to https://github.com/rapidsai/kvikio#install.\n",
@@ -79,28 +75,21 @@
  },
  {
  "cell_type": "code",
- "execution_count": 1,
+ "execution_count": null,
  "metadata": {},
  "outputs": [],
  "source": [
  "import os\n",
  "import time\n",
- "import cupy\n",
  "import torch\n",
  "import shutil\n",
  "import tempfile\n",
- "import numpy as np\n",
- "from typing import Any\n",
- "from pathlib import Path\n",
- "from copy import deepcopy\n",
- "from collections.abc import Callable, Sequence\n",
- "from kvikio.numpy import fromfile, tofile\n",
  "\n",
  "import monai\n",
  "import monai.transforms as mt\n",
  "from monai.config import print_config\n",
- "from monai.data.utils import pickle_hashing, SUPPORTED_PICKLE_MOD\n",
- "from monai.utils import convert_to_tensor, set_determinism, look_up_option\n",
+ "from monai.data.dataset import GDSDataset\n",
+ "from monai.utils import set_determinism\n",
  "\n",
  "print_config()"
  ]
@@ -135,100 +124,6 @@
  "print(root_dir)"
  ]
  },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## GDSDataset"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "metadata": {},
- "outputs": [],
- "source": [
- "class GDSDataset(monai.data.PersistentDataset):\n",
- "    def __init__(\n",
- "        self,\n",
- "        data: Sequence,\n",
- "        transform: Sequence[Callable] | Callable,\n",
- "        cache_dir: Path | str | None,\n",
- "        hash_func: Callable[..., bytes] = pickle_hashing,\n",
- "        hash_transform: Callable[..., bytes] | None = None,\n",
- "        reset_ops_id: bool = True,\n",
- "        device: int = None,\n",
- "        **kwargs: Any,\n",
- "    ) -> None:\n",
- "        super().__init__(\n",
- "            data=data,\n",
- "            transform=transform,\n",
- "            cache_dir=cache_dir,\n",
- "            hash_func=hash_func,\n",
- "            hash_transform=hash_transform,\n",
- "            reset_ops_id=reset_ops_id,\n",
- "            **kwargs,\n",
- "        )\n",
- "        self.device = device\n",
- "\n",
- "    def _cachecheck(self, item_transformed):\n",
- "        \"\"\"given the input dictionary ``item_transformed``, return a transformed version of it\"\"\"\n",
- "        hashfile = None\n",
- "        # compute a cache id\n",
- "        if self.cache_dir is not None:\n",
- "            data_item_md5 = self.hash_func(item_transformed).decode(\"utf-8\")\n",
- "            data_item_md5 += self.transform_hash\n",
- "            hashfile = self.cache_dir / f\"{data_item_md5}.pt\"\n",
- "\n",
- "        if hashfile is not None and hashfile.is_file():  # cache hit\n",
- "            with cupy.cuda.Device(self.device):\n",
- "                item = {}\n",
- "                for k in item_transformed:\n",
- "                    meta_k = torch.load(self.cache_dir / f\"{hashfile.name}-{k}-meta\")\n",
- "                    item[k] = fromfile(f\"{hashfile}-{k}\", dtype=np.float32, like=cupy.empty(()))\n",
- "                    item[k] = convert_to_tensor(item[k].reshape(meta_k[\"shape\"]), device=f\"cuda:{self.device}\")\n",
- "                    item[f\"{k}_meta_dict\"] = meta_k\n",
- "                return item\n",
- "\n",
- "        # create new cache\n",
- "        _item_transformed = self._pre_transform(deepcopy(item_transformed))  # keep the original hashed\n",
- "        if hashfile is None:\n",
- "            return _item_transformed\n",
- "\n",
- "        for k in _item_transformed:  # {'image': ..., 'label': ...}\n",
- "            _item_transformed_meta = _item_transformed[k].meta\n",
- "            _item_transformed_data = _item_transformed[k].array\n",
- "            _item_transformed_meta[\"shape\"] = _item_transformed_data.shape\n",
- "            tofile(_item_transformed_data, f\"{hashfile}-{k}\")\n",
- "            try:\n",
- "                # NOTE: Writing to a temporary directory and then using a nearly atomic rename operation\n",
- "                # to make the cache more robust to manual killing of parent process\n",
- "                # which may leave partially written cache files in an incomplete state\n",
- "                with tempfile.TemporaryDirectory() as tmpdirname:\n",
- "                    meta_hash_file_name = f\"{hashfile.name}-{k}-meta\"\n",
- "                    meta_hash_file = self.cache_dir / meta_hash_file_name\n",
- "                    temp_hash_file = Path(tmpdirname) / meta_hash_file_name\n",
- "                    torch.save(\n",
- "                        obj=_item_transformed_meta,\n",
- "                        f=temp_hash_file,\n",
- "                        pickle_module=look_up_option(self.pickle_module, SUPPORTED_PICKLE_MOD),\n",
- "                        pickle_protocol=self.pickle_protocol,\n",
- "                    )\n",
- "                    if temp_hash_file.is_file() and not meta_hash_file.is_file():\n",
- "                        # On Unix, if target exists and is a file, it will be replaced silently if the\n",
- "                        # user has permission.\n",
- "                        # for more details: https://docs.python.org/3/library/shutil.html#shutil.move.\n",
- "                        try:\n",
- "                            shutil.move(str(temp_hash_file), meta_hash_file)\n",
- "                        except FileExistsError:\n",
- "                            pass\n",
- "            except PermissionError:  # project-monai/monai issue #3613\n",
- "                pass\n",
- "        open(hashfile, \"a\").close()  # store cacheid\n",
- "\n",
- "        return _item_transformed"
- ]
- },
  {
  "cell_type": "markdown",
  "metadata": {},
@@ -245,16 +140,16 @@
  },
  {
  "cell_type": "code",
- "execution_count": 4,
+ "execution_count": 3,
  "metadata": {},
  "outputs": [
  {
  "name": "stdout",
  "output_type": "stream",
  "text": [
- "2023-07-12 09:26:17,878 - INFO - Expected md5 is None, skip md5 check for file samples.zip.\n",
- "2023-07-12 09:26:17,878 - INFO - File exists: samples.zip, skipped downloading.\n",
- "2023-07-12 09:26:17,879 - INFO - Writing into directory: /raid/yliu/test_tutorial.\n"
+ "2023-07-27 07:59:12,054 - INFO - Expected md5 is None, skip md5 check for file samples.zip.\n",
+ "2023-07-27 07:59:12,055 - INFO - File exists: samples.zip, skipped downloading.\n",
+ "2023-07-27 07:59:12,056 - INFO - Writing into directory: /raid/yliu/test_tutorial.\n"
  ]
  }
  ],
@@ -283,7 +178,7 @@
  },
  {
  "cell_type": "code",
- "execution_count": 5,
+ "execution_count": 4,
  "metadata": {},
  "outputs": [],
  "source": [
@@ -299,7 +194,7 @@
  },
  {
  "cell_type": "code",
- "execution_count": 6,
+ "execution_count": 5,
  "metadata": {},
  "outputs": [],
  "source": [
@@ -332,19 +227,19 @@
  },
  {
  "cell_type": "code",
- "execution_count": 7,
+ "execution_count": 6,
  "metadata": {},
  "outputs": [
  {
  "name": "stdout",
  "output_type": "stream",
  "text": [
- "epoch0 time 19.746733903884888\n",
- "epoch1 time 0.9976603984832764\n",
- "epoch2 time 0.982248067855835\n",
- "epoch3 time 0.9838874340057373\n",
- "epoch4 time 0.9793403148651123\n",
- "total time 23.69102692604065\n"
+ "epoch0 time 20.148560762405396\n",
+ "epoch1 time 0.9835140705108643\n",
+ "epoch2 time 0.9708101749420166\n",
+ "epoch3 time 0.9711742401123047\n",
+ "epoch4 time 0.9711296558380127\n",
+ "total time 24.04619812965393\n"
  ]
  }
  ],
@@ -372,19 +267,19 @@
  },
  {
  "cell_type": "code",
- "execution_count": 8,
+ "execution_count": 7,
  "metadata": {},
  "outputs": [
  {
  "name": "stdout",
  "output_type": "stream",
  "text": [
- "epoch0 time 21.206729650497437\n",
- "epoch1 time 1.510526180267334\n",
- "epoch2 time 1.588256597518921\n",
- "epoch3 time 1.4431262016296387\n",
- "epoch4 time 1.4594802856445312\n",
- "total time 27.20927882194519\n"
+ "epoch0 time 21.170511722564697\n",
+ "epoch1 time 1.482978105545044\n",
+ "epoch2 time 1.5378782749176025\n",
+ "epoch3 time 1.4499244689941406\n",
+ "epoch4 time 1.4379286766052246\n",
+ "total time 27.08065962791443\n"
  ]
  }
  ],
