Skip to content
This repository was archived by the owner on Jun 2, 2025. It is now read-only.

Commit 275d68b

Browse files
committed
Merge commit '7f6b13cbdbd38c6ecd5ed96d89e6fbdcb9b20561' into issue/fix-tests
2 parents 681daa8 + 7f6b13c commit 275d68b

File tree

2 files changed

+8
-14
lines changed

2 files changed

+8
-14
lines changed

tests/batch/test_utils.py

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,29 +2,23 @@
22
import torch
33

44
from ocf_datapipes.batch import BatchKey, NumpyBatch, TensorBatch
5-
from ocf_datapipes.batch import copy_batch_to_device, batch_to_tensor
5+
from ocf_datapipes.batch.utils import copy_batch_to_device, batch_to_tensor
6+
import pytest
67

7-
8-
def _create_test_batch() -> NumpyBatch:
8+
@pytest.fixture()
9+
def sample_numpy_batch():
910
sample: NumpyBatch = {}
1011
sample[BatchKey.satellite_actual] = np.full((12, 10, 24, 24), 0)
1112
return sample
1213

1314

14-
def test_batch_to_tensor() -> None:
15-
batch: NumpyBatch = _create_test_batch()
16-
tensor_batch = batch_to_tensor(batch)
15+
def test_batch_to_tensor(sample_numpy_batch):
16+
tensor_batch = batch_to_tensor(sample_numpy_batch)
1717
assert isinstance(tensor_batch[BatchKey.satellite_actual], torch.Tensor)
1818

1919

20-
def test_copy_batch_to_device() -> None:
21-
batch = _create_test_batch()
22-
tensor_batch = batch_to_tensor(batch)
20+
def test_copy_batch_to_device(sample_numpy_batch):
21+
tensor_batch = batch_to_tensor(sample_numpy_batch)
2322
device = torch.device("cpu")
2423
batch_copy: TensorBatch = copy_batch_to_device(tensor_batch, device)
2524
assert batch_copy[BatchKey.satellite_actual].device == device # type: ignore
26-
27-
28-
if __name__ == "__main__":
29-
test_batch_to_tensor()
30-
test_copy_batch_to_device()

tests/utils/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)