|
2 | 2 | import torch
|
3 | 3 |
|
4 | 4 | from ocf_datapipes.batch import BatchKey, NumpyBatch, TensorBatch
|
5 |
| -from ocf_datapipes.batch import copy_batch_to_device, batch_to_tensor |
| 5 | +from ocf_datapipes.batch.utils import copy_batch_to_device, batch_to_tensor |
| 6 | +import pytest |
6 | 7 |
|
7 |
| - |
8 |
| -def _create_test_batch() -> NumpyBatch: |
| 8 | +@pytest.fixture() |
| 9 | +def sample_numpy_batch(): |
9 | 10 | sample: NumpyBatch = {}
|
10 | 11 | sample[BatchKey.satellite_actual] = np.full((12, 10, 24, 24), 0)
|
11 | 12 | return sample
|
12 | 13 |
|
13 | 14 |
|
14 |
| -def test_batch_to_tensor() -> None: |
15 |
| - batch: NumpyBatch = _create_test_batch() |
16 |
| - tensor_batch = batch_to_tensor(batch) |
| 15 | +def test_batch_to_tensor(sample_numpy_batch): |
| 16 | + tensor_batch = batch_to_tensor(sample_numpy_batch) |
17 | 17 | assert isinstance(tensor_batch[BatchKey.satellite_actual], torch.Tensor)
|
18 | 18 |
|
19 | 19 |
|
20 |
| -def test_copy_batch_to_device() -> None: |
21 |
| - batch = _create_test_batch() |
22 |
| - tensor_batch = batch_to_tensor(batch) |
| 20 | +def test_copy_batch_to_device(sample_numpy_batch): |
| 21 | + tensor_batch = batch_to_tensor(sample_numpy_batch) |
23 | 22 | device = torch.device("cpu")
|
24 | 23 | batch_copy: TensorBatch = copy_batch_to_device(tensor_batch, device)
|
25 | 24 | assert batch_copy[BatchKey.satellite_actual].device == device # type: ignore
|
26 |
| - |
27 |
| - |
28 |
| -if __name__ == "__main__": |
29 |
| - test_batch_to_tensor() |
30 |
| - test_copy_batch_to_device() |
0 commit comments