Skip to content
This repository was archived by the owner on Jun 2, 2025. It is now read-only.

Commit 43af752

Browse files
committed
Wrap batch utility functions to enforce typing
1 parent 9f9d2de commit 43af752

File tree

2 files changed

+49
-15
lines changed

2 files changed

+49
-15
lines changed

ocf_datapipes/batch/utils.py

Lines changed: 38 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,15 @@
22
import numpy as np
33
import torch
44

5+
from ocf_datapipes.batch import NumpyBatch, TensorBatch
56

6-
def copy_batch_to_device(batch: dict, device: torch.device) -> dict:
7+
8+
def _copy_batch_to_device(batch: dict, device: torch.device) -> dict:
79
"""
8-
Moves a dict-batch of tensors to new device.
10+
Moves tensor leaves in a nested dict to a new device
911
1012
Args:
11-
batch: dict with tensors to move
13+
batch: nested dict with tensors to move
1214
device: Device to move tensors to
1315
1416
Returns:
@@ -19,28 +21,55 @@ def copy_batch_to_device(batch: dict, device: torch.device) -> dict:
1921
for k, v in batch.items():
2022
if isinstance(v, dict):
2123
# Recursion to reach the nested NWP
22-
batch_copy[k] = copy_batch_to_device(v, device)
24+
batch_copy[k] = _copy_batch_to_device(v, device)
2325
elif isinstance(v, torch.Tensor):
2426
batch_copy[k] = v.to(device)
2527
else:
2628
batch_copy[k] = v
2729
return batch_copy
2830

2931

30-
def batch_to_tensor(batch: dict) -> dict:
32+
def copy_batch_to_device(batch: TensorBatch, device: torch.device) -> TensorBatch:
33+
"""
34+
Moves the tensors in a TensorBatch to a new device.
35+
36+
Args:
37+
batch: TensorBatch with tensors to move
38+
device: Device to move tensors to
39+
40+
Returns:
41+
TensorBatch with tensors moved to new device
42+
"""
43+
return _copy_batch_to_device(batch, device)
44+
45+
46+
def _batch_to_tensor(batch: dict) -> dict:
3147
"""
32-
Moves numpy batch to a tensor
48+
Moves ndarrays in a nested dict to torch tensors
3349
3450
Args:
35-
batch: dict-like batch with data in numpy arrays
51+
batch: nested dict with data in numpy arrays
3652
3753
Returns:
38-
A batch with data in torch tensors
54+
Nested dict with data in torch tensors
3955
"""
4056
for k, v in batch.items():
4157
if isinstance(v, dict):
4258
# Recursion to reach the nested NWP
43-
batch[k] = batch_to_tensor(v)
59+
batch[k] = _batch_to_tensor(v)
4460
elif isinstance(v, np.ndarray) and np.issubdtype(v.dtype, np.number):
4561
batch[k] = torch.as_tensor(v)
4662
return batch
63+
64+
65+
def batch_to_tensor(batch: NumpyBatch) -> TensorBatch:
66+
"""
67+
Moves data in a NumpyBatch to a TensorBatch
68+
69+
Args:
70+
batch: NumpyBatch with data in numpy arrays
71+
72+
Returns:
73+
TensorBatch with data in torch tensors
74+
"""
75+
return _batch_to_tensor(batch)

tests/batch/test_utils.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import numpy as np
22
import torch
33

4-
from ocf_datapipes.batch import BatchKey, NumpyBatch
4+
from ocf_datapipes.batch import BatchKey, NumpyBatch, TensorBatch
55
from ocf_datapipes.batch import copy_batch_to_device, batch_to_tensor
66

77

@@ -11,15 +11,20 @@ def _create_test_batch() -> NumpyBatch:
1111
return sample
1212

1313

14-
def test_batch_to_tensor():
15-
batch = _create_test_batch()
14+
def test_batch_to_tensor() -> None:
15+
batch: NumpyBatch = _create_test_batch()
1616
tensor_batch = batch_to_tensor(batch)
1717
assert isinstance(tensor_batch[BatchKey.satellite_actual], torch.Tensor)
1818

1919

20-
def test_copy_batch_to_device():
20+
def test_copy_batch_to_device() -> None:
2121
batch = _create_test_batch()
2222
tensor_batch = batch_to_tensor(batch)
2323
device = torch.device("cpu")
24-
batch_copy = copy_batch_to_device(tensor_batch, device)
25-
assert batch_copy[BatchKey.satellite_actual].device == device
24+
batch_copy: TensorBatch = copy_batch_to_device(tensor_batch, device)
25+
assert batch_copy[BatchKey.satellite_actual].device == device # type: ignore
26+
27+
28+
if __name__ == "__main__":
29+
test_batch_to_tensor()
30+
test_copy_batch_to_device()

0 commit comments

Comments
 (0)