Skip to content
This repository was archived by the owner on Jun 2, 2025. It is now read-only.

Commit 8055507

Browse files
authored
Merge pull request #40 from openclimatefix/jacob/more-tests
Add more tests/linting updates
2 parents 2c26c02 + d9a6295 commit 8055507

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

51 files changed

+536
-174
lines changed

.github/workflows/workflows.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,5 +14,5 @@ jobs:
1414
pytest_cov_dir: "ocf_datapipes"
1515
# extra things to install
1616
sudo_apt_install: "libgeos++-dev libproj-dev proj-data proj-bin"
17-
# brew_install: "proj geos librttopo"
18-
os_list: '["ubuntu-latest"]'
17+
# brew_install: "proj geos librttopo"
18+
os_list: '["ubuntu-latest"]'

ocf_datapipes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Datapipes"""

ocf_datapipes/batch/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
"""Datapipes for batching together data"""
12
from .merge_numpy_examples_to_batch import (
23
MergeNumpyExamplesToBatchIterDataPipe as MergeNumpyExamplesToBatch,
34
)

ocf_datapipes/config/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Configuration model"""

ocf_datapipes/convert/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
"""Various conversion datapipes"""
12
from .coordinates import ConvertGeostationaryToLatLonIterDataPipe as ConvertGeostationaryToLatLon
23
from .coordinates import ConvertLatLonToOSGBIterDataPipe as ConvertLatLonToOSGB
34
from .coordinates import ConvertOSGBToLatLonIterDataPipe as ConvertOSGBToLatLon

ocf_datapipes/convert/coordinates.py

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
from typing import Union
2+
3+
import xarray as xr
14
from torchdata.datapipes import functional_datapipe
25
from torchdata.datapipes.iter import IterDataPipe
36

@@ -10,10 +13,19 @@
1013

1114
@functional_datapipe("convert_latlon_to_osgb")
1215
class ConvertLatLonToOSGBIterDataPipe(IterDataPipe):
16+
"""Convert from Lat/Lon object to OSGB"""
17+
1318
def __init__(self, source_datapipe: IterDataPipe):
19+
"""
20+
Convert from Lat/Lon to OSGB
21+
22+
Args:
23+
source_datapipe: Datapipe emitting Xarray objects with latitude and longitude data
24+
"""
1425
self.source_datapipe = source_datapipe
1526

16-
def __iter__(self):
27+
def __iter__(self) -> Union[xr.DataArray, xr.Dataset]:
28+
"""Convert from Lat/Lon to OSGB"""
1729
for xr_data in self.source_datapipe:
1830
xr_data["x_osgb"], xr_data["y_osgb"] = lat_lon_to_osgb(
1931
latitude=xr_data["latitude"], longitude=xr_data["longitude"]
@@ -23,10 +35,19 @@ def __iter__(self):
2335

2436
@functional_datapipe("convert_osgb_to_latlon")
2537
class ConvertOSGBToLatLonIterDataPipe(IterDataPipe):
38+
"""Convert from OSGB to Lat/Lon"""
39+
2640
def __init__(self, source_datapipe: IterDataPipe):
41+
"""
42+
Convert from OSGB to Lat/Lon
43+
44+
Args:
45+
source_datapipe: Datapipe emitting Xarray objects with OSGB data
46+
"""
2747
self.source_datapipe = source_datapipe
2848

29-
def __iter__(self):
49+
def __iter__(self) -> Union[xr.DataArray, xr.Dataset]:
50+
"""Convert and add lat/lon to Xarray object"""
3051
for xr_data in self.source_datapipe:
3152
xr_data["latitude"], xr_data["longitude"] = osgb_to_lat_lon(
3253
x=xr_data["x_osgb"], y=xr_data["y_osgb"]
@@ -36,10 +57,19 @@ def __iter__(self):
3657

3758
@functional_datapipe("convert_geostationary_to_latlon")
3859
class ConvertGeostationaryToLatLonIterDataPipe(IterDataPipe):
60+
"""Convert from geostationary to Lat/Lon points"""
61+
3962
def __init__(self, source_datapipe: IterDataPipe):
63+
"""
64+
Convert from Geostationary to Lat/Lon points and add to Xarray object
65+
66+
Args:
67+
source_datapipe: Datapipe emitting Xarray object with geostationary points
68+
"""
4069
self.source_datapipe = source_datapipe
4170

42-
def __iter__(self):
71+
def __iter__(self) -> Union[xr.DataArray, xr.Dataset]:
72+
"""Convert from geostationary to Lat/Lon and yield the Xarray object"""
4373
for xr_data in self.source_datapipe:
4474
transform = load_geostationary_area_definition_and_transform_latlon(xr_data)
4575
xr_data["latitude"], xr_data["longitude"] = transform(

ocf_datapipes/convert/gsp.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,21 @@
77

88
@functional_datapipe("convert_gsp_to_numpy_batch")
99
class ConvertGSPToNumpyBatchIterDataPipe(IterDataPipe):
10-
def __init__(self, source_dp: IterDataPipe):
10+
"""Convert GSP Xarray to NumpyBatch"""
11+
12+
def __init__(self, source_datapipe: IterDataPipe):
13+
"""
14+
Convert GSP Xarray to NumpyBatch object
15+
16+
Args:
17+
source_datapipe: Datapipe emitting GSP Xarray object
18+
"""
1119
super().__init__()
12-
self.source_dp = source_dp
20+
self.source_datapipe = source_datapipe
1321

1422
def __iter__(self) -> NumpyBatch:
15-
for xr_data in self.source_dp:
23+
"""Convert from Xarray to NumpyBatch"""
24+
for xr_data in self.source_datapipe:
1625
example: NumpyBatch = {
1726
BatchKey.gsp: xr_data.values,
1827
BatchKey.gsp_t0_idx: xr_data.attrs["t0_idx"],

ocf_datapipes/convert/nwp.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,21 @@
88

99
@functional_datapipe("convert_nwp_to_numpy_batch")
1010
class ConvertNWPToNumpyBatchIterDataPipe(IterDataPipe):
11-
def __init__(self, source_dp: IterDataPipe):
11+
"""Convert NWP Xarray objects to NumpyBatch ones"""
12+
13+
def __init__(self, source_datapipe: IterDataPipe):
14+
"""
15+
Convert NWP Xarray objects to NumpyBatch ones
16+
17+
Args:
18+
source_datapipe: Datapipe emitting NWP Xarray objects
19+
"""
1220
super().__init__()
13-
self.source_dp = source_dp
21+
self.source_datapipe = source_datapipe
1422

1523
def __iter__(self) -> NumpyBatch:
16-
for xr_data in self.source_dp:
24+
"""Convert from Xarray to NumpyBatch"""
25+
for xr_data in self.source_datapipe:
1726
example: NumpyBatch = {
1827
BatchKey.nwp: xr_data.values,
1928
BatchKey.nwp_t0_idx: xr_data.attrs["t0_idx"],

ocf_datapipes/convert/pv.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,21 @@
88

99
@functional_datapipe("convert_pv_to_numpy_batch")
1010
class ConvertPVToNumpyBatchIterDataPipe(IterDataPipe):
11-
def __init__(self, source_dp: IterDataPipe):
11+
"""Convert PV Xarray to NumpyBatch"""
12+
13+
def __init__(self, source_datapipe: IterDataPipe):
14+
"""
15+
Convert PV Xarray objects to NumpyBatch objects
16+
17+
Args:
18+
source_datapipe: Datapipe emitting PV Xarray objects
19+
"""
1220
super().__init__()
13-
self.source_dp = source_dp
21+
self.source_datapipe = source_datapipe
1422

1523
def __iter__(self) -> NumpyBatch:
16-
for xr_data in self.source_dp:
24+
"""Iterate and convert PV Xarray to NumpyBatch"""
25+
for xr_data in self.source_datapipe:
1726
example: NumpyBatch = {
1827
BatchKey.pv: xr_data.values,
1928
BatchKey.pv_t0_idx: xr_data.attrs["t0_idx"],

ocf_datapipes/convert/satellite.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,23 @@
77

88
@functional_datapipe("convert_satellite_to_numpy_batch")
99
class ConvertSatelliteToNumpyBatchIterDataPipe(IterDataPipe):
10-
def __init__(self, source_dp: IterDataPipe, is_hrv: bool = False):
10+
"""Converts Xarray Satellite to NumpyBatch object"""
11+
12+
def __init__(self, source_datapipe: IterDataPipe, is_hrv: bool = False):
13+
"""
14+
Converts Xarray satellite object to NumpyBatch object
15+
16+
Args:
17+
source_datapipe: Datapipe emitting Xarray satellite objects
18+
is_hrv: Whether this is HRV satellite data or non-HRV data
19+
"""
1120
super().__init__()
12-
self.source_dp = source_dp
21+
self.source_datapipe = source_datapipe
1322
self.is_hrv = is_hrv
1423

1524
def __iter__(self) -> NumpyBatch:
16-
for xr_data in self.source_dp:
25+
"""Convert each example to a NumpyBatch object"""
26+
for xr_data in self.source_datapipe:
1727
if self.is_hrv:
1828
example: NumpyBatch = {
1929
BatchKey.hrvsatellite_actual: xr_data.values,

ocf_datapipes/fake/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Fake data generators for testing"""

ocf_datapipes/load/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
"""Loading datapipes from the raw data"""
12
from .configuration import OpenConfigurationIterDataPipe as OpenConfiguration
23
from .gsp import OpenGSPIterDataPipe as OpenGSP
34
from .nwp import OpenNWPIterDataPipe as OpenNWP

ocf_datapipes/load/configuration.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1+
import logging
2+
13
import fsspec
24
from pathy import Pathy
3-
import logging
45
from pyaml_env import parse_config
56
from torchdata.datapipes import functional_datapipe
67
from torchdata.datapipes.iter import IterDataPipe
@@ -16,11 +17,11 @@ def __init__(self, configuration_filename: str):
1617
self.configuration_filename = configuration_filename
1718

1819
def __iter__(self):
19-
logger.debug(f'Going to open {self.configuration_filename}')
20+
logger.debug(f"Going to open {self.configuration_filename}")
2021
with fsspec.open(self.configuration_filename, mode="r") as stream:
2122
configuration = parse_config(data=stream)
2223

23-
logger.debug(f'Converting to Configuration ({configuration})')
24+
logger.debug(f"Converting to Configuration ({configuration})")
2425
configuration = Configuration(**configuration)
2526

2627
while True:

ocf_datapipes/load/pv.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -214,10 +214,10 @@ def _load_pv_metadata(filename: str) -> pd.DataFrame:
214214
latitude, longitude, system_id, x_osgb, y_osgb
215215
"""
216216
_log.info(f"Loading PV metadata from {filename}")
217-
if 'passiv' in str(filename):
218-
index_col = 'ss_id'
217+
if "passiv" in str(filename):
218+
index_col = "ss_id"
219219
else:
220-
index_col = 'system_id'
220+
index_col = "system_id"
221221
pv_metadata = pd.read_csv(filename, index_col=index_col)
222222

223223
if "Unnamed: 0" in pv_metadata.columns:

ocf_datapipes/load/satellite.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,12 +45,24 @@ def open_sat_data(
4545
# Note that `rename` renames *both* the coordinates and dimensions, and keeps
4646
# the connection between the dims and coordinates, so we don't have to manually
4747
# use `data_array.set_index()`.
48-
dataset = dataset.rename({"time": "time_utc",})
48+
dataset = dataset.rename(
49+
{
50+
"time": "time_utc",
51+
}
52+
)
4953
if "y" in dataset.coords.keys():
50-
dataset = dataset.rename({"y": "y_geostationary",})
54+
dataset = dataset.rename(
55+
{
56+
"y": "y_geostationary",
57+
}
58+
)
5159

5260
if "x" in dataset.coords.keys():
53-
dataset = dataset.rename({"x": "x_geostationary",})
61+
dataset = dataset.rename(
62+
{
63+
"x": "x_geostationary",
64+
}
65+
)
5466

5567
# Flip coordinates to top-left first
5668
if dataset.y_geostationary[0] < dataset.y_geostationary[-1]:

ocf_datapipes/production/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Wrappers to make complete data pipelines for production systems"""

ocf_datapipes/select/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
"""Selection datapipes"""
12
from .location_picker import LocationPickerIterDataPipe as LocationPicker
23
from .offset_t0 import OffsetT0IterDataPipe as OffsetT0
34
from .select_live_t0_time import SelectLiveT0TimeIterDataPipe as SelectLiveT0Time

ocf_datapipes/select/location_picker.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,23 @@
77

88
@functional_datapipe("location_picker")
99
class LocationPickerIterDataPipe(IterDataPipe):
10-
def __init__(self, source_dp: IterDataPipe, return_all_locations: bool = False):
10+
"""Picks locations from a dataset and returns them"""
11+
12+
def __init__(self, source_datapipe: IterDataPipe, return_all_locations: bool = False):
13+
"""
14+
Picks locations from a dataset and returns them
15+
16+
Args:
17+
source_datapipe: Datapipe emitting Xarray Dataset
18+
return_all_locations: Whether to return all locations; if True, they are also returned in order
19+
"""
1120
super().__init__()
12-
self.source_dp = source_dp
21+
self.source_datapipe = source_datapipe
1322
self.return_all_locations = return_all_locations
1423

1524
def __iter__(self) -> Location:
16-
for xr_dataset in self.source_dp:
25+
"""Returns locations from the input datapipe"""
26+
for xr_dataset in self.source_datapipe:
1727
if self.return_all_locations:
1828
# Iterate through all locations in dataset
1929
for location_idx in range(len(xr_dataset["x_osgb"])):

ocf_datapipes/select/select_live_t0_time.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,18 @@ class SelectLiveT0TimeIterDataPipe(IterDataPipe):
88
"""Select the history for the live data"""
99

1010
def __init__(self, source_datapipe: IterDataPipe, dim_name: str = "time_utc"):
11+
"""
12+
Select history for the Xarray object
13+
14+
Args:
15+
source_datapipe: Datapipe emitting Xarray objects
16+
dim_name: The time dimension name to use
17+
"""
1118
self.source_datapipe = source_datapipe
1219
self.dim_name = dim_name
1320

14-
def __iter__(self):
21+
def __iter__(self) -> pd.Timestamp:
22+
"""Get the latest timestamp and return it"""
1523
for xr_data in self.source_datapipe:
1624
# Get most recent time in data
1725
# Select the history that goes back that far

ocf_datapipes/select/select_live_time_slice.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,22 @@ def __init__(
1919
history_duration: timedelta,
2020
dim_name: str = "time_utc",
2121
):
22+
"""
23+
Select the history for the live time slice
24+
25+
Args:
26+
source_datapipe: Datapipe emitting Xarray object
27+
t0_datapipe: Datapipe emitting t0 timestamps
28+
history_duration: Amount of time for the history
29+
dim_name: Time dimension name
30+
"""
2231
self.source_datapipe = source_datapipe
2332
self.t0_datapipe = t0_datapipe
2433
self.history_duration = np.timedelta64(history_duration)
2534
self.dim_name = dim_name
2635

2736
def __iter__(self):
37+
"""Select the recent live data"""
2838
for xr_data, t0 in Zipper(self.source_datapipe, self.t0_datapipe):
2939
xr_data = xr_data.sel({self.dim_name: slice(t0 - self.history_duration, t0)})
3040
yield xr_data

ocf_datapipes/transform/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""Transforms for the data in both xarray and numpy formats"""

ocf_datapipes/transform/numpy/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
"""Numpy transforms"""
12
from .add_topographic_data import AddTopographicDataIterDataPipe as AddTopographicData
23
from .align_gsp_to_5_min import AlignGSPto5MinIterDataPipe as AlignGSPto5Min
34
from .encode_space_time import EncodeSpaceTimeIterDataPipe as EncodeSpaceTime

ocf_datapipes/transform/xarray/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
"""Xarray transforms"""
12
from .add_t0idx_and_sample_period_duration import (
23
AddT0IdxAndSamplePeriodDurationIterDataPipe as AddT0IdxAndSamplePeriodDuration,
34
)

ocf_datapipes/transform/xarray/add_t0idx_and_sample_period_duration.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from datetime import timedelta
12
from typing import Union
23

34
import xarray as xr
@@ -7,13 +8,29 @@
78

89
@functional_datapipe("add_t0_idx_and_sample_period_duration")
910
class AddT0IdxAndSamplePeriodDurationIterDataPipe(IterDataPipe):
10-
def __init__(self, source_datapipe: IterDataPipe, sample_period_duration, history_duration):
11+
"""Add t0_idx and sample_period_duration attributes to datasets for downstream tasks"""
12+
13+
def __init__(
14+
self,
15+
source_datapipe: IterDataPipe,
16+
sample_period_duration: timedelta,
17+
history_duration: timedelta,
18+
):
19+
"""
20+
Adds two attributes, t0_idx, and sample_period_duration for downstream datapipes to use
21+
22+
Args:
23+
source_datapipe: Datapipe emitting an Xarray Dataset or DataArray
24+
sample_period_duration: Time between samples
25+
history_duration: Amount of history in each example
26+
"""
1127
self.source_datapipe = source_datapipe
1228
self.sample_period_duration = sample_period_duration
1329
self.history_duration = history_duration
1430
self.t0_idx = int(self.history_duration / self.sample_period_duration)
1531

1632
def __iter__(self) -> Union[xr.DataArray, xr.Dataset]:
33+
"""Adds the two attributes to the xarray objects and returns them"""
1734
for xr_data in self.source_datapipe:
1835
xr_data.attrs["t0_idx"] = self.t0_idx
1936
xr_data.attrs["sample_period_duration"] = self.sample_period_duration

0 commit comments

Comments
 (0)