Skip to content
This repository was archived by the owner on Jun 2, 2025. It is now read-only.

Commit c37ea9f

Browse files
committed
Add tests
1 parent 49dfe54 commit c37ea9f

File tree

1 file changed

+25
-0
lines changed

1 file changed

+25
-0
lines changed

tests/batch/test_utils.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import numpy as np
2+
import torch
3+
4+
from ocf_datapipes.batch import BatchKey, NumpyBatch
5+
from ocf_datapipes.batch import copy_batch_to_device, batch_to_tensor
6+
7+
8+
def _create_test_batch() -> NumpyBatch:
9+
sample: NumpyBatch = {}
10+
sample[BatchKey.satellite_actual] = np.full((12, 10, 24, 24), 0)
11+
return sample
12+
13+
14+
def test_batch_to_tensor():
15+
batch = _create_test_batch()
16+
tensor_batch = batch_to_tensor(batch)
17+
assert isinstance(tensor_batch[BatchKey.satellite_actual], torch.Tensor)
18+
19+
20+
def test_copy_batch_to_device():
21+
batch = _create_test_batch()
22+
tensor_batch = batch_to_tensor(batch)
23+
device = torch.device("cpu")
24+
batch_copy = copy_batch_to_device(tensor_batch, device)
25+
assert batch_copy[BatchKey.satellite_actual].device == device

0 commit comments

Comments
 (0)