Skip to content

Commit d51edc3

Browse files
wyliNic-Ma
andauthored
minor usability enh bundle workflow (#5480)
Signed-off-by: Wenqi Li <[email protected]> ### Description this PR includes a few minor improvements - allow string input, such as `device="cuda:1"` for Workflow - skip stats key metric if `key_metric_name` is None (monai/handlers/stats_handler.py) - print incompatible metric values in the warning message (CheckpointSaver) - adds a `epochs` parameter for `ShuffleBuffer` (#5488) - adds a `output_name_formatter` for `SaveImage` (#5508) - simplify torch version in bundle init (#5529) ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. Signed-off-by: Wenqi Li <[email protected]> Signed-off-by: Nic Ma <[email protected]> Co-authored-by: Nic Ma <[email protected]>
1 parent b36c8db commit d51edc3

17 files changed

+49
-23
lines changed

docs/source/modules.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,7 @@ A typical bundle example can include:
264264
┗━ *license.txt
265265
```
266266
Details about the bundle config definition and syntax & examples are at [config syntax](https://docs.monai.io/en/latest/config_syntax.html).
267-
A step-by-step [get started](https://github.com/Project-MONAI/tutorials/blob/master/modules/bundles/get_started.ipynb) tutorial notebook can help users quickly set up a bundle. [[bundle examples](https://github.com/Project-MONAI/tutorials/tree/main/bundle), [model-zoo](https://github.com/Project-MONAI/model-zoo)]
267+
A step-by-step [get started](https://github.com/Project-MONAI/tutorials/blob/master/bundle/get_started.md) tutorial notebook can help users quickly set up a bundle. [[bundle examples](https://github.com/Project-MONAI/tutorials/tree/main/bundle), [model-zoo](https://github.com/Project-MONAI/model-zoo)]
268268

269269
## Federated Learning
270270

monai/bundle/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
"version": "0.0.1",
3333
"changelog": {"0.0.1": "Initial version"},
3434
"monai_version": _conf_values["MONAI"],
35-
"pytorch_version": _conf_values["Pytorch"],
35+
"pytorch_version": str(_conf_values["Pytorch"]).split("+")[0].split("a")[0], # 1.9.0a0+df837d0 or 1.13.0+cu117
3636
"numpy_version": _conf_values["Numpy"],
3737
"optional_packages_version": {},
3838
"task": "Describe what the network predicts",

monai/data/dataset.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -785,6 +785,7 @@ def __init__(
785785
2. to execute `runtime cache` on GPU memory, must co-work with
786786
`monai.data.DataLoader`, and can't work with `monai.data.DistributedSampler`
787787
as GPU Tensor usually can't be shared in the multiprocessing context.
788+
(try ``cache_dataset.disable_share_memory_cache()`` in case of GPU caching issues.)
788789
789790
"""
790791
if not isinstance(transform, Compose):

monai/data/folder_layout.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,19 @@
99
# See the License for the specific language governing permissions and
1010
# limitations under the License.
1111

12+
import monai
1213
from monai.config import PathLike
1314
from monai.data.utils import create_file_basename
1415

15-
__all__ = ["FolderLayout"]
16+
__all__ = ["FolderLayout", "default_name_formatter"]
17+
18+
19+
def default_name_formatter(metadict, saver):
20+
"""Returns a kwargs dict for :py:meth:`FolderLayout.filename`,
21+
according to the input metadata and SaveImage transform."""
22+
subject = metadict[monai.utils.ImageMetaKey.FILENAME_OR_OBJ] if metadict else getattr(saver, "_data_index", 0)
23+
patch_index = metadict.get(monai.utils.ImageMetaKey.PATCH_INDEX, None) if metadict else None
24+
return {"subject": f"{subject}", "idx": patch_index}
1625

1726

1827
class FolderLayout:

monai/data/iterable_dataset.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ class ShuffleBuffer(Randomizable, IterableDataset):
7171
seed: random seed to initialize the random state of all workers, set `seed += 1` in
7272
every iter() call, refer to the PyTorch idea:
7373
https://github.com/pytorch/pytorch/blob/v1.10.0/torch/utils/data/distributed.py#L98.
74+
epochs: number of epochs to iterate over the dataset, default to 1, -1 means infinite epochs.
7475
7576
Note:
7677
Both ``monai.data.DataLoader`` and ``torch.utils.data.DataLoader`` do not seed this class (as a subclass of
@@ -93,10 +94,11 @@ def run():
9394
9495
"""
9596

96-
def __init__(self, data, transform=None, buffer_size: int = 512, seed: int = 0) -> None:
97+
def __init__(self, data, transform=None, buffer_size: int = 512, seed: int = 0, epochs: int = 1) -> None:
9798
super().__init__(data=data, transform=transform)
9899
self.size = buffer_size
99100
self.seed = seed
101+
self.epochs = epochs
100102
self._idx = 0
101103

102104
def randomized_pop(self, buffer):
@@ -123,7 +125,8 @@ def __iter__(self):
123125
"""
124126
self.seed += 1
125127
super().set_random_state(seed=self.seed) # make all workers in sync
126-
yield from IterableDataset(self.generate_item(), transform=self.transform)
128+
for _ in range(self.epochs) if self.epochs >= 0 else iter(int, 1):
129+
yield from IterableDataset(self.generate_item(), transform=self.transform)
127130

128131
def randomize(self, size: int) -> None:
129132
self._idx = self.R.randint(size)

monai/engines/evaluator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ class Evaluator(Workflow):
8686

8787
def __init__(
8888
self,
89-
device: torch.device,
89+
device: torch.device | str,
9090
val_data_loader: Iterable | DataLoader,
9191
epoch_length: int | None = None,
9292
non_blocking: bool = False,

monai/engines/multi_gpu_supervised_trainer.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
# See the License for the specific language governing permissions and
1010
# limitations under the License.
1111

12-
from typing import TYPE_CHECKING, Callable, Dict, Optional, Sequence, Tuple
12+
from typing import TYPE_CHECKING, Callable, Dict, Optional, Sequence, Tuple, Union
1313

1414
import torch
1515
import torch.nn
@@ -55,7 +55,7 @@ def create_multigpu_supervised_trainer(
5555
net: torch.nn.Module,
5656
optimizer: Optimizer,
5757
loss_fn: Callable,
58-
devices: Optional[Sequence[torch.device]] = None,
58+
devices: Optional[Sequence[Union[str, torch.device]]] = None,
5959
non_blocking: bool = False,
6060
prepare_batch: Callable = _prepare_batch,
6161
output_transform: Callable = _default_transform,
@@ -105,7 +105,7 @@ def create_multigpu_supervised_trainer(
105105
def create_multigpu_supervised_evaluator(
106106
net: torch.nn.Module,
107107
metrics: Optional[Dict[str, Metric]] = None,
108-
devices: Optional[Sequence[torch.device]] = None,
108+
devices: Optional[Sequence[Union[str, torch.device]]] = None,
109109
non_blocking: bool = False,
110110
prepare_batch: Callable = _prepare_batch,
111111
output_transform: Callable = _default_eval_transform,

monai/engines/trainer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ class SupervisedTrainer(Trainer):
134134

135135
def __init__(
136136
self,
137-
device: torch.device,
137+
device: str | torch.device,
138138
max_epochs: int,
139139
train_data_loader: Iterable | DataLoader,
140140
network: torch.nn.Module,
@@ -304,7 +304,7 @@ class GanTrainer(Trainer):
304304

305305
def __init__(
306306
self,
307-
device: torch.device,
307+
device: str | torch.device,
308308
max_epochs: int,
309309
train_data_loader: DataLoader,
310310
g_network: torch.nn.Module,

monai/engines/utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ class IterationEvents(EventEnum):
6060
INNER_ITERATION_COMPLETED = "inner_iteration_completed"
6161

6262

63-
def get_devices_spec(devices: Optional[Sequence[torch.device]] = None) -> List[torch.device]:
63+
def get_devices_spec(devices: Optional[Sequence[Union[torch.device, str]]] = None) -> List[torch.device]:
6464
"""
6565
Get a valid specification for one or more devices. If `devices` is None get devices for all CUDA devices available.
6666
If `devices` is and zero-length structure a single CPU compute device is returned. In any other cases `devices` is
@@ -88,7 +88,8 @@ def get_devices_spec(devices: Optional[Sequence[torch.device]] = None) -> List[t
8888
else:
8989
devices = list(devices)
9090

91-
return devices
91+
devices = [torch.device(d) if isinstance(d, str) else d for d in devices]
92+
return devices # type: ignore
9293

9394

9495
def default_prepare_batch(

monai/engines/workflow.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,6 @@ class Workflow(IgniteEngine): # type: ignore[valid-type, misc] # due to optiona
9393
https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast.
9494
9595
Raises:
96-
TypeError: When ``device`` is not a ``torch.Device``.
9796
TypeError: When ``data_loader`` is not a ``torch.utils.data.DataLoader``.
9897
TypeError: When ``key_metric`` is not a ``Optional[dict]``.
9998
TypeError: When ``additional_metrics`` is not a ``Optional[dict]``.
@@ -102,7 +101,7 @@ class Workflow(IgniteEngine): # type: ignore[valid-type, misc] # due to optiona
102101

103102
def __init__(
104103
self,
105-
device: torch.device,
104+
device: Union[torch.device, str],
106105
max_epochs: int,
107106
data_loader: Union[Iterable, DataLoader],
108107
epoch_length: Optional[int] = None,
@@ -125,8 +124,6 @@ def __init__(
125124
super().__init__(iteration_update)
126125
else:
127126
super().__init__(self._iteration)
128-
if not isinstance(device, torch.device):
129-
raise TypeError(f"Device must be a torch.device but is {type(device).__name__}.")
130127

131128
if isinstance(data_loader, DataLoader):
132129
sampler = data_loader.__dict__["sampler"]
@@ -155,7 +152,7 @@ def set_sampler_epoch(engine: Engine):
155152
metrics={},
156153
metric_details={},
157154
dataloader=None,
158-
device=device,
155+
device=device if isinstance(device, torch.device) or device is None else torch.device(device),
159156
key_metric_name=None, # we can set many metrics, only use key_metric to compare and save the best model
160157
best_metric=-1,
161158
best_metric_epoch=-1,

monai/handlers/checkpoint_saver.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,7 @@ def _score_func(engine: Engine):
165165
warnings.warn(
166166
"key metric is not a scalar value, skip metric comparison and don't save a model."
167167
"please use other metrics as key metric, or change the `reduction` mode to 'mean'."
168+
f"got metric: {metric_name}={metric}."
168169
)
169170
return -1
170171
return (-1 if key_metric_negative_sign else 1) * metric

monai/handlers/stats_handler.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,7 @@ def _default_epoch_print(self, engine: Engine) -> None:
204204
hasattr(engine.state, "key_metric_name")
205205
and hasattr(engine.state, "best_metric")
206206
and hasattr(engine.state, "best_metric_epoch")
207+
and engine.state.key_metric_name is not None # type: ignore
207208
):
208209
out_str = f"Key metric: {engine.state.key_metric_name} "
209210
out_str += f"best value: {engine.state.best_metric} "

monai/transforms/io/array.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727

2828
from monai.config import DtypeLike, NdarrayOrTensor, PathLike
2929
from monai.data import image_writer
30-
from monai.data.folder_layout import FolderLayout
30+
from monai.data.folder_layout import FolderLayout, default_name_formatter
3131
from monai.data.image_reader import (
3232
ImageReader,
3333
ITKReader,
@@ -340,6 +340,8 @@ class SaveImage(Transform):
340340
the supported built-in writer classes are ``"NibabelWriter"``, ``"ITKWriter"``, ``"PILWriter"``.
341341
channel_dim: the index of the channel dimension. Default to `0`.
342342
`None` to indicate no channel dimension.
343+
output_name_formatter: a callable function (returning a kwargs dict) to format the output file name.
344+
see also: :py:func:`monai.data.folder_layout.default_name_formatter`.
343345
"""
344346

345347
def __init__(
@@ -360,6 +362,7 @@ def __init__(
360362
output_format: str = "",
361363
writer: Union[Type[image_writer.ImageWriter], str, None] = None,
362364
channel_dim: Optional[int] = 0,
365+
output_name_formatter=None,
363366
) -> None:
364367
self.folder_layout = FolderLayout(
365368
output_dir=output_dir,
@@ -390,6 +393,7 @@ def __init__(
390393
self.data_kwargs = {"squeeze_end_dims": squeeze_end_dims, "channel_dim": channel_dim}
391394
self.meta_kwargs = {"resample": resample, "mode": mode, "padding_mode": padding_mode, "dtype": dtype}
392395
self.write_kwargs = {"verbose": print_log}
396+
self.fname_formatter = default_name_formatter if output_name_formatter is None else output_name_formatter
393397
self._data_index = 0
394398

395399
def set_options(self, init_kwargs=None, data_kwargs=None, meta_kwargs=None, write_kwargs=None):
@@ -420,9 +424,8 @@ def __call__(self, img: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dic
420424
meta_data: key-value pairs of metadata corresponding to the data.
421425
"""
422426
meta_data = img.meta if isinstance(img, MetaTensor) else meta_data
423-
subject = meta_data[Key.FILENAME_OR_OBJ] if meta_data else str(self._data_index)
424-
patch_index = meta_data.get(Key.PATCH_INDEX, None) if meta_data else None
425-
filename = self.folder_layout.filename(subject=f"{subject}", idx=patch_index)
427+
kw = self.fname_formatter(meta_data, self)
428+
filename = self.folder_layout.filename(**kw)
426429
if meta_data and len(ensure_tuple(meta_data.get("spatial_shape", ()))) == len(img.shape):
427430
self.data_kwargs["channel_dim"] = None
428431

monai/transforms/io/dictionary.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,8 @@ class SaveImaged(MapTransform):
233233
if `None`, use the default writer from `monai.data.image_writer` according to `output_ext`.
234234
if it's a string, it's treated as a class name or dotted path;
235235
the supported built-in writer classes are ``"NibabelWriter"``, ``"ITKWriter"``, ``"PILWriter"``.
236+
output_name_formatter: a callable function (returning a kwargs dict) to format the output file name.
237+
see also: :py:func:`monai.data.folder_layout.default_name_formatter`.
236238
237239
"""
238240

@@ -257,6 +259,7 @@ def __init__(
257259
print_log: bool = True,
258260
output_format: str = "",
259261
writer: Union[Type[image_writer.ImageWriter], str, None] = None,
262+
output_name_formatter=None,
260263
) -> None:
261264
super().__init__(keys, allow_missing_keys)
262265
self.meta_keys = ensure_tuple_rep(meta_keys, len(self.keys))
@@ -277,6 +280,7 @@ def __init__(
277280
print_log=print_log,
278281
output_format=output_format,
279282
writer=writer,
283+
output_name_formatter=output_name_formatter,
280284
)
281285

282286
def set_options(self, init_kwargs=None, data_kwargs=None, meta_kwargs=None, write_kwargs=None):

tests/test_deepedit_interaction.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def run_interaction(self, train):
9292

9393
# set up engine
9494
engine = SupervisedTrainer(
95-
device=torch.device("cpu"),
95+
device="cpu",
9696
max_epochs=1,
9797
train_data_loader=data_loader,
9898
network=network,

tests/test_save_image.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def test_saved_content(self, test_data, meta_data, output_ext, resample):
4949
output_ext=output_ext,
5050
resample=resample,
5151
separate_folder=False, # test saving into the same folder
52+
output_name_formatter=lambda x, xform: dict(subject=x["filename_or_obj"] if x else "0"),
5253
)
5354
trans(test_data)
5455

tests/test_shuffle_buffer.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,11 @@ def test_shape(self):
3737
np.testing.assert_allclose(output, [[2, 3], [1, 4]], err_msg=f"seed {buffer.seed}")
3838
np.testing.assert_allclose(output2, [[1, 4], [2, 3]], err_msg=f"seed {buffer.seed}")
3939

40+
def test_epochs(self):
41+
buffer = ShuffleBuffer([1, 2, 3, 4], seed=0, epochs=2)
42+
output = [convert_data_type(x, np.ndarray)[0] for x in DataLoader(dataset=buffer, batch_size=2)]
43+
np.testing.assert_allclose(output, [[2, 1], [3, 4], [4, 2], [3, 1]])
44+
4045

4146
if __name__ == "__main__":
4247
unittest.main()

0 commit comments

Comments
 (0)