Skip to content
This repository was archived by the owner on Jun 2, 2025. It is now read-only.

Commit 2e8b064

Browse files
authored
Merge pull request #287 from markus-kreft/utils_from_pvnet
Add NumpyBatch utils from PVNet
2 parents a066adc + 53a2e6d commit 2e8b064

File tree

6 files changed

+111
-4
lines changed

6 files changed

+111
-4
lines changed

ocf_datapipes/batch/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""Datapipes for batching together data"""
2-
from .batches import BatchKey, NumpyBatch, NWPBatchKey, NWPNumpyBatch, XarrayBatch
2+
from .batches import BatchKey, NumpyBatch, NWPBatchKey, NWPNumpyBatch, TensorBatch, XarrayBatch
33
from .merge_numpy_examples_to_batch import (
44
MergeNumpyBatchIterDataPipe as MergeNumpyBatch,
55
)
@@ -12,3 +12,4 @@
1212
)
1313
from .merge_numpy_modalities import MergeNumpyModalitiesIterDataPipe as MergeNumpyModalities
1414
from .merge_numpy_modalities import MergeNWPNumpyModalitiesIterDataPipe as MergeNWPNumpyModalities
15+
from .utils import batch_to_tensor, copy_batch_to_device

ocf_datapipes/batch/batches.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from typing import Union
55

66
import numpy as np
7+
import torch
78
import xarray as xr
89

910

@@ -229,3 +230,5 @@ class NWPBatchKey(Enum):
229230
NumpyBatch = dict[BatchKey, Union[np.ndarray, dict[str, NWPNumpyBatch]]]
230231

231232
XarrayBatch = dict[BatchKey, Union[xr.DataArray, xr.Dataset]]
233+
234+
# Batch whose leaves are torch tensors: each BatchKey maps either to a tensor or to a
# nested dict of the form {str: {NWPBatchKey: tensor}}.
TensorBatch = dict[BatchKey, Union[torch.Tensor, dict[str, dict[NWPBatchKey, torch.Tensor]]]]

ocf_datapipes/batch/utils.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
"""Additional utils for working with batches"""
2+
import numpy as np
3+
import torch
4+
5+
from ocf_datapipes.batch import NumpyBatch, TensorBatch
6+
7+
8+
def _copy_batch_to_device(batch: dict, device: torch.device) -> dict:
    """
    Build a new nested dict whose tensor leaves live on the requested device

    Nested dicts are rebuilt recursively, tensors are passed through `Tensor.to`, and any
    other value is carried over unchanged (same object). The input dict is not modified.

    Args:
        batch: nested dict with tensors to move
        device: Device to move tensors to

    Returns:
        A new dict with the same keys and tensors moved to the new device
    """

    def _relocate(value):
        # Nested dicts (e.g. the NWP entry) are handled by recursing into them
        if isinstance(value, dict):
            return _copy_batch_to_device(value, device)
        if isinstance(value, torch.Tensor):
            return value.to(device)
        # Anything that is neither a dict nor a tensor is passed through as-is
        return value

    return {key: _relocate(value) for key, value in batch.items()}
30+
31+
32+
def copy_batch_to_device(batch: TensorBatch, device: torch.device) -> TensorBatch:
    """
    Return a copy of a TensorBatch with its tensors placed on the given device.

    This is the typed public entry point; the work is delegated to
    `_copy_batch_to_device`.

    Args:
        batch: TensorBatch with tensors to move
        device: Device to move tensors to

    Returns:
        TensorBatch with tensors moved to new device
    """
    moved: TensorBatch = _copy_batch_to_device(batch, device)
    return moved
44+
45+
46+
def _batch_to_tensor(batch: dict) -> dict:
    """
    Convert numeric ndarrays in a nested dict to torch tensors, in place

    The dict passed in is modified and returned (no copy is made). Nested dicts are
    converted recursively. Arrays whose dtype is not a subtype of `np.number`, and values
    that are not ndarrays, are left untouched.

    Args:
        batch: nested dict with data in numpy arrays

    Returns:
        The same dict object, with numeric arrays replaced by torch tensors
    """
    # Walk a snapshot of the keys; only existing entries are reassigned
    for key in list(batch):
        value = batch[key]

        if isinstance(value, dict):
            # Recursion to reach the nested NWP
            batch[key] = _batch_to_tensor(value)
            continue

        if not isinstance(value, np.ndarray):
            continue

        if np.issubdtype(value.dtype, np.number):
            batch[key] = torch.as_tensor(value)

    return batch
63+
64+
65+
def batch_to_tensor(batch: NumpyBatch) -> TensorBatch:
    """
    Turn a NumpyBatch into a TensorBatch.

    This is the typed public entry point; the work is delegated to `_batch_to_tensor`,
    which reassigns entries of the dict it is given and returns that same dict.

    Args:
        batch: NumpyBatch with data in numpy arrays

    Returns:
        TensorBatch with data in torch tensors
    """
    converted: TensorBatch = _batch_to_tensor(batch)
    return converted

ocf_datapipes/training/pvnet_site.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,8 @@
66
import xarray as xr
77
from torch.utils.data import IterDataPipe, functional_datapipe
88
from torch.utils.data.datapipes.iter import IterableWrapper
9-
from ocf_datapipes.batch import BatchKey, NumpyBatch
109

11-
from ocf_datapipes.batch import MergeNumpyModalities, MergeNWPNumpyModalities
10+
from ocf_datapipes.batch import BatchKey, MergeNumpyModalities, MergeNWPNumpyModalities
1211
from ocf_datapipes.training.common import (
1312
DatapipeKeyForker,
1413
_get_datapipes_dict,

ocf_datapipes/training/windnet.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,6 @@ def __init__(self, filenames: List[str], keys: List[str]):
123123

124124
def __iter__(self):
125125
"""Iterate through each filename, loading it, uncombining it, and then yielding it"""
126-
import numpy as np
127126

128127
while True:
129128
for filename in self.filenames:

tests/batch/test_utils.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import numpy as np
2+
import torch
3+
4+
from ocf_datapipes.batch import BatchKey, NumpyBatch, TensorBatch
5+
from ocf_datapipes.batch import copy_batch_to_device, batch_to_tensor
6+
7+
8+
def _create_test_batch() -> NumpyBatch:
    """Build a minimal NumpyBatch holding a single zero-filled satellite array."""
    satellite = np.full((12, 10, 24, 24), 0)
    return {BatchKey.satellite_actual: satellite}
12+
13+
14+
def test_batch_to_tensor() -> None:
    """Check that batch_to_tensor yields a torch tensor for the satellite entry."""
    converted = batch_to_tensor(_create_test_batch())
    satellite = converted[BatchKey.satellite_actual]
    assert isinstance(satellite, torch.Tensor)
18+
19+
20+
def test_copy_batch_to_device() -> None:
    """Check that copy_batch_to_device reports the requested device on the copied tensor."""
    target = torch.device("cpu")
    as_tensors = batch_to_tensor(_create_test_batch())
    moved: TensorBatch = copy_batch_to_device(as_tensors, target)
    assert moved[BatchKey.satellite_actual].device == target  # type: ignore
26+
27+
28+
if __name__ == "__main__":
    # Allow running this file directly: execute each test in order
    for _check in (test_batch_to_tensor, test_copy_batch_to_device):
        _check()

0 commit comments

Comments
 (0)