2
2
import numpy as np
3
3
import torch
4
4
5
+ from ocf_datapipes .batch import NumpyBatch , TensorBatch
5
6
6
- def copy_batch_to_device (batch : dict , device : torch .device ) -> dict :
7
+
8
+ def _copy_batch_to_device (batch : dict , device : torch .device ) -> dict :
7
9
"""
8
- Moves a dict-batch of tensors to new device.
10
+ Moves tensor leaves in a nested dict to a new device
9
11
10
12
Args:
11
- batch: dict with tensors to move
13
+ batch: nested dict with tensors to move
12
14
device: Device to move tensors to
13
15
14
16
Returns:
@@ -19,28 +21,55 @@ def copy_batch_to_device(batch: dict, device: torch.device) -> dict:
19
21
for k , v in batch .items ():
20
22
if isinstance (v , dict ):
21
23
# Recursion to reach the nested NWP
22
- batch_copy [k ] = copy_batch_to_device (v , device )
24
+ batch_copy [k ] = _copy_batch_to_device (v , device )
23
25
elif isinstance (v , torch .Tensor ):
24
26
batch_copy [k ] = v .to (device )
25
27
else :
26
28
batch_copy [k ] = v
27
29
return batch_copy
28
30
29
31
30
- def batch_to_tensor (batch : dict ) -> dict :
32
+ def copy_batch_to_device (batch : TensorBatch , device : torch .device ) -> TensorBatch :
33
+ """
34
+ Moves the tensors in a TensorBatch to a new device.
35
+
36
+ Args:
37
+ batch: TensorBatch with tensors to move
38
+ device: Device to move tensors to
39
+
40
+ Returns:
41
+ TensorBatch with tensors moved to new device
42
+ """
43
+ return _copy_batch_to_device (batch , device )
44
+
45
+
46
+ def _batch_to_tensor (batch : dict ) -> dict :
31
47
"""
32
- Moves numpy batch to a tensor
48
+ Moves ndarrays in a nested dict to torch tensors
33
49
34
50
Args:
35
- batch: dict-like batch with data in numpy arrays
51
+ batch: nested dict with data in numpy arrays
36
52
37
53
Returns:
38
- A batch with data in torch tensors
54
+ Nested dict with data in torch tensors
39
55
"""
40
56
for k , v in batch .items ():
41
57
if isinstance (v , dict ):
42
58
# Recursion to reach the nested NWP
43
- batch [k ] = batch_to_tensor (v )
59
+ batch [k ] = _batch_to_tensor (v )
44
60
elif isinstance (v , np .ndarray ) and np .issubdtype (v .dtype , np .number ):
45
61
batch [k ] = torch .as_tensor (v )
46
62
return batch
63
+
64
+
65
+ def batch_to_tensor (batch : NumpyBatch ) -> TensorBatch :
66
+ """
67
+ Moves data in a NumpyBatch to a TensorBatch
68
+
69
+ Args:
70
+ batch: NumpyBatch with data in numpy arrays
71
+
72
+ Returns:
73
+ TensorBatch with data in torch tensors
74
+ """
75
+ return _batch_to_tensor (batch )
0 commit comments