Skip to content
This repository was archived by the owner on Jun 2, 2025. It is now read-only.

Commit 5d9b1e3

Browse files
Merge pull request #295 from openclimatefix/issue/fix-tests
add if statement
2 parents c2c363e + 282677f commit 5d9b1e3

File tree

3 files changed

+17
-19
lines changed

3 files changed

+17
-19
lines changed

ocf_datapipes/training/windnet.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -128,12 +128,15 @@ def __iter__(self):
128128
for filename in self.filenames:
129129
dataset = xr.open_dataset(filename)
130130
datasets = uncombine_from_single_dataset(dataset)
131-
# print(datasets)
132-
datasets["nwp"]["ecmwf"] = potentially_coarsen(datasets["nwp"]["ecmwf"])
133-
# Select the specific keys desired
134-
datasets["nwp"]["ecmwf"] = datasets["nwp"]["ecmwf"].sel(
135-
channel=["u10", "u100", "u200", "v10", "v100", "v200"]
136-
)
131+
132+
if "ecmwf" in datasets["nwp"]:
133+
datasets["nwp"]["ecmwf"] = potentially_coarsen(datasets["nwp"]["ecmwf"])
134+
135+
# Select the specific keys desired
136+
datasets["nwp"]["ecmwf"] = datasets["nwp"]["ecmwf"].sel(
137+
channel=["u10", "u100", "u200", "v10", "v100", "v200"]
138+
)
139+
137140
# Yield a dictionary of the data, using the keys in self.keys
138141
# print(datasets)
139142
dataset_dict = {}

tests/batch/test_utils.py

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,29 +2,24 @@
22
import torch
33

44
from ocf_datapipes.batch import BatchKey, NumpyBatch, TensorBatch
5-
from ocf_datapipes.batch import copy_batch_to_device, batch_to_tensor
5+
from ocf_datapipes.batch.utils import copy_batch_to_device, batch_to_tensor
6+
import pytest
67

78

8-
def _create_test_batch() -> NumpyBatch:
9+
@pytest.fixture()
10+
def sample_numpy_batch():
911
sample: NumpyBatch = {}
1012
sample[BatchKey.satellite_actual] = np.full((12, 10, 24, 24), 0)
1113
return sample
1214

1315

14-
def test_batch_to_tensor() -> None:
15-
batch: NumpyBatch = _create_test_batch()
16-
tensor_batch = batch_to_tensor(batch)
16+
def test_batch_to_tensor(sample_numpy_batch):
17+
tensor_batch = batch_to_tensor(sample_numpy_batch)
1718
assert isinstance(tensor_batch[BatchKey.satellite_actual], torch.Tensor)
1819

1920

20-
def test_copy_batch_to_device() -> None:
21-
batch = _create_test_batch()
22-
tensor_batch = batch_to_tensor(batch)
21+
def test_copy_batch_to_device(sample_numpy_batch):
22+
tensor_batch = batch_to_tensor(sample_numpy_batch)
2323
device = torch.device("cpu")
2424
batch_copy: TensorBatch = copy_batch_to_device(tensor_batch, device)
2525
assert batch_copy[BatchKey.satellite_actual].device == device # type: ignore
26-
27-
28-
if __name__ == "__main__":
29-
test_batch_to_tensor()
30-
test_copy_batch_to_device()

tests/utils/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)