Skip to content
This repository was archived by the owner on Jun 2, 2025. It is now read-only.

Commit 78eb766

Browse files
committed
Add docstrings for converters and better tests
1 parent ed22fa2 commit 78eb766

File tree

9 files changed

+94
-22
lines changed

9 files changed

+94
-22
lines changed

ocf_datapipes/convert/coordinates.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from torchdata.datapipes import functional_datapipe
22
from torchdata.datapipes.iter import IterDataPipe
3+
import xarray as xr
4+
from typing import Union
35

46
from ocf_datapipes.utils.geospatial import (
57
lat_lon_to_osgb,
@@ -10,10 +12,18 @@
1012

1113
@functional_datapipe("convert_latlon_to_osgb")
1214
class ConvertLatLonToOSGBIterDataPipe(IterDataPipe):
15+
"""Convert from Lat/Lon object to OSGB"""
1316
def __init__(self, source_datapipe: IterDataPipe):
17+
"""
18+
Convert from Lat/Lon to OSGB
19+
20+
Args:
21+
source_datapipe: Datapipe emitting Xarray objects with latitude and longitude data
22+
"""
1423
self.source_datapipe = source_datapipe
1524

16-
def __iter__(self):
25+
def __iter__(self) -> Union[xr.DataArray, xr.Dataset]:
26+
"""Convert from Lat/Lon to OSGB"""
1727
for xr_data in self.source_datapipe:
1828
xr_data["x_osgb"], xr_data["y_osgb"] = lat_lon_to_osgb(
1929
latitude=xr_data["latitude"], longitude=xr_data["longitude"]
@@ -23,10 +33,18 @@ def __iter__(self):
2333

2434
@functional_datapipe("convert_osgb_to_latlon")
2535
class ConvertOSGBToLatLonIterDataPipe(IterDataPipe):
36+
"""Convert from OSGB to Lat/Lon"""
2637
def __init__(self, source_datapipe: IterDataPipe):
38+
"""
39+
Convert from OSGB to Lat/Lon
40+
41+
Args:
42+
source_datapipe: Datapipe emitting Xarray objects with OSGB data
43+
"""
2744
self.source_datapipe = source_datapipe
2845

29-
def __iter__(self):
46+
def __iter__(self) -> Union[xr.DataArray, xr.Dataset]:
47+
"""Convert and add lat/lon to Xarray object"""
3048
for xr_data in self.source_datapipe:
3149
xr_data["latitude"], xr_data["longitude"] = osgb_to_lat_lon(
3250
x=xr_data["x_osgb"], y=xr_data["y_osgb"]
@@ -36,10 +54,18 @@ def __iter__(self):
3654

3755
@functional_datapipe("convert_geostationary_to_latlon")
3856
class ConvertGeostationaryToLatLonIterDataPipe(IterDataPipe):
57+
"""Convert from geostationary to Lat/Lon points"""
3958
def __init__(self, source_datapipe: IterDataPipe):
59+
"""
60+
Convert from Geostationary to Lat/Lon points and add to Xarray object
61+
62+
Args:
63+
source_datapipe: Datapipe emitting Xarray object with geostationary points
64+
"""
4065
self.source_datapipe = source_datapipe
4166

42-
def __iter__(self):
67+
def __iter__(self) -> Union[xr.DataArray, xr.Dataset]:
68+
"""Convert from geostationary to Lat/Lon and yield the Xarray object"""
4369
for xr_data in self.source_datapipe:
4470
transform = load_geostationary_area_definition_and_transform_latlon(xr_data)
4571
xr_data["latitude"], xr_data["longitude"] = transform(

ocf_datapipes/convert/gsp.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,20 @@
77

88
@functional_datapipe("convert_gsp_to_numpy_batch")
99
class ConvertGSPToNumpyBatchIterDataPipe(IterDataPipe):
10-
def __init__(self, source_dp: IterDataPipe):
10+
"""Convert GSP Xarray to NumpyBatch"""
11+
def __init__(self, source_datapipe: IterDataPipe):
12+
"""
13+
Convert GSP Xarray to NumpyBatch object
14+
15+
Args:
16+
source_datapipe: Datapipe emitting GSP Xarray object
17+
"""
1118
super().__init__()
12-
self.source_dp = source_dp
19+
self.source_datapipe = source_datapipe
1320

1421
def __iter__(self) -> NumpyBatch:
15-
for xr_data in self.source_dp:
22+
"""Convert from Xarray to NumpyBatch"""
23+
for xr_data in self.source_datapipe:
1624
example: NumpyBatch = {
1725
BatchKey.gsp: xr_data.values,
1826
BatchKey.gsp_t0_idx: xr_data.attrs["t0_idx"],

ocf_datapipes/convert/nwp.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,20 @@
88

99
@functional_datapipe("convert_nwp_to_numpy_batch")
1010
class ConvertNWPToNumpyBatchIterDataPipe(IterDataPipe):
11-
def __init__(self, source_dp: IterDataPipe):
11+
"""Convert NWP Xarray objects to NumpyBatch ones"""
12+
def __init__(self, source_datapipe: IterDataPipe):
13+
"""
14+
Convert NWP Xarray objects to NumpyBatch ones
15+
16+
Args:
17+
source_datapipe: Datapipe emitting NWP Xarray objects
18+
"""
1219
super().__init__()
13-
self.source_dp = source_dp
20+
self.source_datapipe = source_datapipe
1421

1522
def __iter__(self) -> NumpyBatch:
16-
for xr_data in self.source_dp:
23+
"""Convert from Xarray to NumpyBatch"""
24+
for xr_data in self.source_datapipe:
1725
example: NumpyBatch = {
1826
BatchKey.nwp: xr_data.values,
1927
BatchKey.nwp_t0_idx: xr_data.attrs["t0_idx"],

ocf_datapipes/convert/pv.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,20 @@
88

99
@functional_datapipe("convert_pv_to_numpy_batch")
1010
class ConvertPVToNumpyBatchIterDataPipe(IterDataPipe):
11-
def __init__(self, source_dp: IterDataPipe):
11+
"""Convert PV Xarray to NumpyBatch"""
12+
def __init__(self, source_datapipe: IterDataPipe):
13+
"""
14+
Convert PV Xarray objects to NumpyBatch objects
15+
16+
Args:
17+
source_datapipe: Datapipe emitting PV Xarray objects
18+
"""
1219
super().__init__()
13-
self.source_dp = source_dp
20+
self.source_datapipe = source_datapipe
1421

1522
def __iter__(self) -> NumpyBatch:
16-
for xr_data in self.source_dp:
23+
"""Iterate and convert PV Xarray to NumpyBatch"""
24+
for xr_data in self.source_datapipe:
1725
example: NumpyBatch = {
1826
BatchKey.pv: xr_data.values,
1927
BatchKey.pv_t0_idx: xr_data.attrs["t0_idx"],

ocf_datapipes/convert/satellite.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,22 @@
77

88
@functional_datapipe("convert_satellite_to_numpy_batch")
99
class ConvertSatelliteToNumpyBatchIterDataPipe(IterDataPipe):
10-
def __init__(self, source_dp: IterDataPipe, is_hrv: bool = False):
10+
"""Converts Xarray Satellite to NumpyBatch object"""
11+
def __init__(self, source_datapipe: IterDataPipe, is_hrv: bool = False):
12+
"""
13+
Converts Xarray satellite object to NumpyBatch object
14+
15+
Args:
16+
source_datapipe: Datapipe emitting Xarray satellite objects
17+
is_hrv: Whether this is HRV satellite data or non-HRV data
18+
"""
1119
super().__init__()
12-
self.source_dp = source_dp
20+
self.source_datapipe = source_datapipe
1321
self.is_hrv = is_hrv
1422

1523
def __iter__(self) -> NumpyBatch:
16-
for xr_data in self.source_dp:
24+
"""Convert each example to a NumpyBatch object"""
25+
for xr_data in self.source_datapipe:
1726
if self.is_hrv:
1827
example: NumpyBatch = {
1928
BatchKey.hrvsatellite_actual: xr_data.values,

tests/convert/test_gsp.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,13 @@
22

33
from ocf_datapipes.convert import ConvertGSPToNumpyBatch
44
from ocf_datapipes.transform.xarray import AddT0IdxAndSamplePeriodDuration
5-
5+
from ocf_datapipes.utils.consts import BatchKey
66

77
def test_convert_gsp_to_numpy_batch(gsp_dp):
88
gsp_dp = AddT0IdxAndSamplePeriodDuration(
99
gsp_dp, sample_period_duration=timedelta(minutes=5), history_duration=timedelta(minutes=60)
1010
)
1111
gsp_dp = ConvertGSPToNumpyBatch(gsp_dp)
1212
data = next(iter(gsp_dp))
13-
assert data is not None
13+
assert BatchKey.gsp in data
14+
assert BatchKey.gsp_id in data

tests/convert/test_nwp.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from ocf_datapipes.convert import ConvertNWPToNumpyBatch
44
from ocf_datapipes.select import SelectLiveT0Time
55
from ocf_datapipes.transform.xarray import AddT0IdxAndSamplePeriodDuration, ConvertToNWPTargetTime
6+
from ocf_datapipes.utils.consts import BatchKey
67

78

89
def test_convert_nwp_to_numpy_batch(nwp_dp):
@@ -20,4 +21,5 @@ def test_convert_nwp_to_numpy_batch(nwp_dp):
2021
)
2122
nwp_dp = ConvertNWPToNumpyBatch(nwp_dp)
2223
data = next(iter(nwp_dp))
23-
assert data is not None
24+
assert BatchKey.nwp in data
25+
assert BatchKey.nwp_channel_names in data

tests/convert/test_pv.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from ocf_datapipes.convert import ConvertPVToNumpyBatch
44
from ocf_datapipes.transform.xarray import AddT0IdxAndSamplePeriodDuration
5+
from ocf_datapipes.utils.consts import BatchKey
56

67

78
def test_convert_passiv_to_numpy_batch(passiv_dp):
@@ -12,7 +13,8 @@ def test_convert_passiv_to_numpy_batch(passiv_dp):
1213
)
1314
passiv_dp = ConvertPVToNumpyBatch(passiv_dp)
1415
data = next(iter(passiv_dp))
15-
assert data is not None
16+
assert BatchKey.pv in data
17+
assert BatchKey.pv_t0_idx in data
1618

1719

1820
def test_convert_pvoutput_to_numpy_batch(pvoutput_dp):
@@ -25,4 +27,5 @@ def test_convert_pvoutput_to_numpy_batch(pvoutput_dp):
2527
pvoutput_dp = ConvertPVToNumpyBatch(pvoutput_dp)
2628

2729
data = next(iter(pvoutput_dp))
28-
assert data is not None
30+
assert BatchKey.pv in data
31+
assert BatchKey.pv_t0_idx in data

tests/convert/test_satellite.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from ocf_datapipes.convert import ConvertSatelliteToNumpyBatch
44
from ocf_datapipes.transform.xarray import AddT0IdxAndSamplePeriodDuration
5+
from ocf_datapipes.utils.consts import BatchKey
56

67

78
def test_convert_satellite_to_numpy_batch(sat_dp):
@@ -11,7 +12,10 @@ def test_convert_satellite_to_numpy_batch(sat_dp):
1112
)
1213
sat_dp = ConvertSatelliteToNumpyBatch(sat_dp, is_hrv=False)
1314
data = next(iter(sat_dp))
14-
assert data is not None
15+
assert BatchKey.satellite_actual in data
16+
assert BatchKey.satellite_t0_idx in data
17+
assert BatchKey.hrvsatellite_actual not in data
18+
assert BatchKey.hrvsatellite_t0_idx not in data
1519

1620

1721
def test_convert_hrvsatellite_to_numpy_batch(sat_dp):
@@ -20,4 +24,7 @@ def test_convert_hrvsatellite_to_numpy_batch(sat_dp):
2024
)
2125
sat_dp = ConvertSatelliteToNumpyBatch(sat_dp, is_hrv=True)
2226
data = next(iter(sat_dp))
23-
assert data is not None
27+
assert BatchKey.hrvsatellite_actual in data
28+
assert BatchKey.hrvsatellite_t0_idx in data
29+
assert BatchKey.satellite_actual not in data
30+
assert BatchKey.satellite_t0_idx not in data

0 commit comments

Comments
 (0)