Skip to content
This repository was archived by the owner on Jun 2, 2025. It is now read-only.

Commit 05aab77

Browse files
committed
Add vars and dims tests
1 parent 09f1b3d commit 05aab77

File tree

2 files changed

+40
-16
lines changed

2 files changed

+40
-16
lines changed

ocf_datapipes/validation/check_vars_and_dims.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@ class CheckVarsAndDimsIterDataPipe(IterDataPipe):
1313
def __init__(
1414
self,
1515
source_datapipe: IterDataPipe,
16-
expected_dimensions: Iterable[str],
17-
expected_data_vars: Iterable[str],
16+
expected_dimensions: Optional[Iterable[str]] = None,
17+
expected_data_vars: Optional[Iterable[str]] = None,
1818
dataset_name: Optional[str] = None,
1919
):
2020
"""
@@ -41,19 +41,23 @@ def __iter__(self) -> Union[xr.DataArray, xr.Dataset]:
4141
"""
4242
for xr_data in self.source_datapipe:
4343
if self.dataset_name is None:
44-
xr_data = validate_data_vars(xr_data, self.expected_data_vars)
45-
xr_data = validate_dims(xr_data, self.expected_dimensions)
46-
xr_data = validate_coords(xr_data, self.expected_dimensions)
44+
if self.expected_data_vars is not None:
45+
xr_data = validate_data_vars(xr_data, self.expected_data_vars)
46+
if self.expected_dimensions is not None:
47+
xr_data = validate_dims(xr_data, self.expected_dimensions)
48+
xr_data = validate_coords(xr_data, self.expected_dimensions)
4749
else:
48-
xr_data[self.dataset_name] = validate_data_vars(
49-
xr_data[self.dataset_name], self.expected_data_vars
50-
)
51-
xr_data[self.dataset_name] = validate_dims(
52-
xr_data[self.dataset_name], self.expected_dimensions
53-
)
54-
xr_data[self.dataset_name] = validate_coords(
55-
xr_data[self.dataset_name], self.expected_dimensions
56-
)
50+
if self.expected_data_vars is not None:
51+
xr_data[self.dataset_name] = validate_data_vars(
52+
xr_data[self.dataset_name], self.expected_data_vars
53+
)
54+
if self.expected_dimensions is not None:
55+
xr_data[self.dataset_name] = validate_dims(
56+
xr_data[self.dataset_name], self.expected_dimensions
57+
)
58+
xr_data[self.dataset_name] = validate_coords(
59+
xr_data[self.dataset_name], self.expected_dimensions
60+
)
5761
yield xr_data
5862

5963

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,25 @@
11
from ocf_datapipes.validation import CheckVarsAndDims
22

33

4-
def test_check_vars_and_dims():
5-
pass
4+
def test_check_vars_and_dims_gsp(gsp_dp):
5+
gsp_dp = gsp_dp.check_vars_and_dims(expected_dimensions=("time_utc", "gsp_id"))
6+
next(iter(gsp_dp))
7+
8+
def test_check_vars_and_dims_sat(sat_dp):
9+
sat_dp = sat_dp.check_vars_and_dims(expected_dimensions=("time_utc", "channel", "x_geostationary", "y_geostationary"))
10+
next(iter(sat_dp))
11+
12+
def test_check_vars_and_dim_passiv(passiv_dp):
13+
passiv_dp = passiv_dp.check_vars_and_dims(
14+
expected_dimensions=("time_utc", "pv_system_id"))
15+
next(iter(passiv_dp))
16+
17+
def test_check_vars_and_dim_nwp(nwp_dp):
18+
nwp_dp = nwp_dp.check_vars_and_dims(
19+
expected_dimensions=("init_time_utc", "channel", "step", "x_osgb", "y_osgb"))
20+
next(iter(nwp_dp))
21+
22+
def test_check_vars_and_dim_topo(topo_dp):
23+
topo_dp = topo_dp.check_vars_and_dims(
24+
expected_dimensions=("x_osgb", "y_osgb"))
25+
next(iter(topo_dp))

0 commit comments

Comments
 (0)