Skip to content
This repository was archived by the owner on Jun 2, 2025. It is now read-only.

Commit e227a41

Browse files
committed
Update SelectLiveT0Time docstring and typing and test
1 parent 7d36094 commit e227a41

20 files changed

+129
-69
lines changed

ocf_datapipes/load/configuration.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,11 @@ def __init__(self, configuration_filename: str):
1616
self.configuration_filename = configuration_filename
1717

1818
def __iter__(self):
19-
logger.debug(f'Going to open {self.configuration_filename}')
19+
logger.debug(f"Going to open {self.configuration_filename}")
2020
with fsspec.open(self.configuration_filename, mode="r") as stream:
2121
configuration = parse_config(data=stream)
2222

23-
logger.debug(f'Converting to Configuration ({configuration})')
23+
logger.debug(f"Converting to Configuration ({configuration})")
2424
configuration = Configuration(**configuration)
2525

2626
while True:

ocf_datapipes/load/pv.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -214,10 +214,10 @@ def _load_pv_metadata(filename: str) -> pd.DataFrame:
214214
latitude, longitude, system_id, x_osgb, y_osgb
215215
"""
216216
_log.info(f"Loading PV metadata from {filename}")
217-
if 'passiv' in str(filename):
218-
index_col = 'ss_id'
217+
if "passiv" in str(filename):
218+
index_col = "ss_id"
219219
else:
220-
index_col = 'system_id'
220+
index_col = "system_id"
221221
pv_metadata = pd.read_csv(filename, index_col=index_col)
222222

223223
if "Unnamed: 0" in pv_metadata.columns:

ocf_datapipes/load/satellite.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,12 +45,24 @@ def open_sat_data(
4545
# Note that `rename` renames *both* the coordinates and dimensions, and keeps
4646
# the connection between the dims and coordinates, so we don't have to manually
4747
# use `data_array.set_index()`.
48-
dataset = dataset.rename({"time": "time_utc",})
48+
dataset = dataset.rename(
49+
{
50+
"time": "time_utc",
51+
}
52+
)
4953
if "y" in dataset.coords.keys():
50-
dataset = dataset.rename({"y": "y_geostationary",})
54+
dataset = dataset.rename(
55+
{
56+
"y": "y_geostationary",
57+
}
58+
)
5159

5260
if "x" in dataset.coords.keys():
53-
dataset = dataset.rename({"x": "x_geostationary",})
61+
dataset = dataset.rename(
62+
{
63+
"x": "x_geostationary",
64+
}
65+
)
5466

5567
# Flip coordinates to top-left first
5668
if dataset.y_geostationary[0] < dataset.y_geostationary[-1]:

ocf_datapipes/select/select_live_t0_time.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,18 @@ class SelectLiveT0TimeIterDataPipe(IterDataPipe):
88
"""Select the history for the live data"""
99

1010
def __init__(self, source_datapipe: IterDataPipe, dim_name: str = "time_utc"):
11+
"""
12+
Select history for the Xarray object
13+
14+
Args:
15+
source_datapipe: Datapipe emitting Xarray objects
16+
dim_name: The time dimension name to use
17+
"""
1118
self.source_datapipe = source_datapipe
1219
self.dim_name = dim_name
1320

14-
def __iter__(self):
21+
def __iter__(self) -> pd.Timestamp:
22+
"""Get the latest timestamp and return it"""
1523
for xr_data in self.source_datapipe:
1624
# Get most recent time in data
1725
# Select the history that goes back that far

ocf_datapipes/transform/xarray/add_t0idx_and_sample_period_duration.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,17 @@
55
from torchdata.datapipes.iter import IterDataPipe
66
from datetime import timedelta
77

8+
89
@functional_datapipe("add_t0_idx_and_sample_period_duration")
910
class AddT0IdxAndSamplePeriodDurationIterDataPipe(IterDataPipe):
1011
"""Add t0_idx and sample_period_duration attributes to datasets for downstream tasks"""
11-
def __init__(self, source_datapipe: IterDataPipe, sample_period_duration: timedelta, history_duration: timedelta):
12+
13+
def __init__(
14+
self,
15+
source_datapipe: IterDataPipe,
16+
sample_period_duration: timedelta,
17+
history_duration: timedelta,
18+
):
1219
"""
1320
Adds two attributes, t0_idx, and sample_period_duration for downstream datapipes to use
1421

ocf_datapipes/transform/xarray/convert_to_nwp_target_times.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
@functional_datapipe("convert_to_nwp_target_time")
1111
class ConvertToNWPTargetTimeIterDataPipe(IterDataPipe):
1212
"""Converts NWP Xarray to use the target time"""
13+
1314
def __init__(
1415
self,
1516
source_datapipe: IterDataPipe,

ocf_datapipes/transform/xarray/downsample.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,15 @@
55
@functional_datapipe("downsample")
66
class DownsampleIterDataPipe(IterDataPipe):
77
"""Downsample Xarray dataset with coarsen"""
8-
def __init__(self, source_datapipe: IterDataPipe, y_coarsen: int, x_coarsen: int, x_dim_name: str = "x_osgb", y_dim_name: str = "y_osgb"):
8+
9+
def __init__(
10+
self,
11+
source_datapipe: IterDataPipe,
12+
y_coarsen: int,
13+
x_coarsen: int,
14+
x_dim_name: str = "x_osgb",
15+
y_dim_name: str = "y_osgb",
16+
):
917
"""
1018
Downsample xarray dataset/dataarrays with coarsen
1119
@@ -25,8 +33,10 @@ def __init__(self, source_datapipe: IterDataPipe, y_coarsen: int, x_coarsen: int
2533
def __iter__(self):
2634
"""Coarsen the data on the specified dimensions"""
2735
for xr_data in self.source_datapipe:
28-
yield xr_data.coarsen({
29-
self.y_dim_name: self.y_coarsen,
30-
self.x_dim_name: self.x_coarsen,},
36+
yield xr_data.coarsen(
37+
{
38+
self.y_dim_name: self.y_coarsen,
39+
self.x_dim_name: self.x_coarsen,
40+
},
3141
boundary="trim",
3242
).mean()

ocf_datapipes/transform/xarray/get_contiguous_time_periods.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,39 @@
1-
import datetime
2-
31
import pandas as pd
4-
import xarray as xr
2+
import numpy as np
53
from torchdata.datapipes import functional_datapipe
64
from torchdata.datapipes.iter import IterDataPipe
5+
from datetime import timedelta
76

87

98
@functional_datapipe("add_contiguous_time_periods")
109
class GetContiguousT0TimePeriodsIterDataPipe(IterDataPipe):
10+
"""Get contiguous time periods for training"""
11+
1112
def __init__(
12-
self, source_dp: IterDataPipe, history_duration, forecast_duration, sample_period_duration
13+
self,
14+
source_datapipe: IterDataPipe,
15+
history_duration: timedelta,
16+
forecast_duration: timedelta,
17+
sample_period_duration: timedelta,
1318
):
14-
self.source_dp = source_dp
19+
"""
20+
Get contiguous time periods for use in determining t0 times for training
21+
22+
Args:
23+
source_datapipe: Datapipe emitting an Xarray dataset
24+
history_duration: Amount of time for the history of an example
25+
forecast_duration: Amount of time for the forecast of an example
26+
sample_period_duration: The sampling period of the data source
27+
"""
28+
self.source_datapipe = source_datapipe
1529
self.history_duration = history_duration
1630
self.forecast_duration = forecast_duration
1731
self.total_duration = history_duration + forecast_duration
1832
self.sample_period_duration = sample_period_duration
1933

2034
def __iter__(self) -> pd.DataFrame:
21-
for xr_data in self.source_dp:
35+
"""Calculate contiguous time periods and return a dataframe containing them"""
36+
for xr_data in self.source_datapipe:
2237
contiguous_time_periods = get_contiguous_time_periods(
2338
datetimes=pd.DatetimeIndex(xr_data["time_utc"]),
2439
min_seq_length=int(self.total_duration / self.sample_period_duration) + 1,
@@ -33,7 +48,7 @@ def __iter__(self) -> pd.DataFrame:
3348

3449

3550
def get_contiguous_t0_time_periods(
36-
contiguous_time_periods, history_duration, forecast_duration
51+
contiguous_time_periods: pd.DataFrame, history_duration: timedelta, forecast_duration: timedelta
3752
) -> pd.DataFrame:
3853
"""Get all time periods which contain valid t0 datetimes.
3954
@@ -52,7 +67,7 @@ def get_contiguous_t0_time_periods(
5267
def get_contiguous_time_periods(
5368
datetimes: pd.DatetimeIndex,
5469
min_seq_length: int,
55-
max_gap_duration: datetime.timedelta,
70+
max_gap_duration: timedelta,
5671
) -> pd.DataFrame:
5772
"""Return a pd.DataFrame where each row records the boundary of a contiguous time period.
5873

ocf_datapipes/transform/xarray/normalize.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,13 @@
88
@functional_datapipe("normalize")
99
class NormalizeIterDataPipe(IterDataPipe):
1010
"""Normalize the data in various methods"""
11+
1112
def __init__(
1213
self,
1314
source_datapipe: IterDataPipe,
14-
mean: Optional[Union[xr.Dataset, xr.DataArray, np.ndarray]]=None,
15-
std: Optional[Union[xr.Dataset, xr.DataArray, np.ndarray]]=None,
16-
max_value: Optional[Union[int, float]]=None,
15+
mean: Optional[Union[xr.Dataset, xr.DataArray, np.ndarray]] = None,
16+
std: Optional[Union[xr.Dataset, xr.DataArray, np.ndarray]] = None,
17+
max_value: Optional[Union[int, float]] = None,
1718
calculate_mean_std_from_example: bool = False,
1819
normalize_fn: Optional[Callable] = None,
1920
):

ocf_datapipes/transform/xarray/pv_power_rolling_window.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
@functional_datapipe("pv_power_rolling_window")
1010
class PVPowerRollingWindowIterDataPipe(IterDataPipe):
1111
"""Compute rolling mean of PV power."""
12+
1213
def __init__(
1314
self,
1415
source_datapipe: IterDataPipe,

tests/convert/test_gsp.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from datetime import timedelta
55

66

7-
87
def test_convert_gsp_to_numpy_batch(gsp_dp):
98
gsp_dp = AddT0IdxAndSamplePeriodDuration(
109
gsp_dp, sample_period_duration=timedelta(minutes=5), history_duration=timedelta(minutes=60)

tests/convert/test_nwp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ def test_convert_nwp_to_numpy_batch(nwp_dp):
99
nwp_dp = AddT0IdxAndSamplePeriodDuration(
1010
nwp_dp, sample_period_duration=timedelta(minutes=60), history_duration=timedelta(minutes=60)
1111
)
12-
t0_dp = SelectLiveT0Time(nwp_dp, dim_name='init_time_utc')
12+
t0_dp = SelectLiveT0Time(nwp_dp, dim_name="init_time_utc")
1313

1414
nwp_dp = ConvertToNWPTargetTime(
1515
nwp_dp,

tests/convert/test_pv.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66

77
def test_convert_passiv_to_numpy_batch(passiv_dp):
88
passiv_dp = AddT0IdxAndSamplePeriodDuration(
9-
passiv_dp, sample_period_duration=timedelta(minutes=5), history_duration=timedelta(minutes=60)
9+
passiv_dp,
10+
sample_period_duration=timedelta(minutes=5),
11+
history_duration=timedelta(minutes=60),
1012
)
1113
passiv_dp = ConvertPVToNumpyBatch(passiv_dp)
1214
data = next(iter(passiv_dp))
@@ -15,7 +17,9 @@ def test_convert_passiv_to_numpy_batch(passiv_dp):
1517

1618
def test_convert_pvoutput_to_numpy_batch(pvoutput_dp):
1719
pvoutput_dp = AddT0IdxAndSamplePeriodDuration(
18-
pvoutput_dp, sample_period_duration=timedelta(minutes=5), history_duration=timedelta(minutes=60)
20+
pvoutput_dp,
21+
sample_period_duration=timedelta(minutes=5),
22+
history_duration=timedelta(minutes=60),
1923
)
2024

2125
pvoutput_dp = ConvertPVToNumpyBatch(pvoutput_dp)

tests/load/test_load_config.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
1-
21
from ocf_datapipes.load import OpenConfiguration
32

43

54
def test_open_config():
65
config_dp = OpenConfiguration("tests/config/test.yaml")
76
configuration = next(iter(config_dp))
87
print(configuration)
9-

tests/production/test_pp_production.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import pytest
77

88

9-
@pytest.mark.skip('Need to set up laod PV from database first')
9+
@pytest.mark.skip("Need to set up laod PV from database first")
1010
def test_pp_production_datapipe():
1111

1212
filename = os.path.join(os.path.dirname(ocf_datapipes.__file__), "../tests/config/test.yaml")

tests/select/test_select_live_t0_time_slice.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def test_select_passiv(passiv_dp):
5858

5959
def test_select_pvoutput(pvoutput_dp):
6060
time_len = len(next(iter(pvoutput_dp)).time_utc.values)
61-
t0_dp = SelectLiveT0Time(pvoutput_dp, dim_name='time_utc')
61+
t0_dp = SelectLiveT0Time(pvoutput_dp, dim_name="time_utc")
6262
pvoutput_dp = SelectLiveTimeSlice(
6363
pvoutput_dp,
6464
t0_datapipe=t0_dp,

tests/select/test_select_live_time_slice.py

Lines changed: 20 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from datetime import timedelta
22

33
from ocf_datapipes.select import SelectLiveT0Time, SelectLiveTimeSlice
4-
# from ocf_datapipes.transform.xarray import AddNWPTargetTime
4+
from ocf_datapipes.transform.xarray import ConvertToNWPTargetTime
55

66

77
def test_select_hrv(sat_hrv_dp):
@@ -22,28 +22,25 @@ def test_select_gsp(gsp_dp):
2222
assert len(data.time_utc.values) < time_len
2323

2424

25-
# def test_select_nwp(nwp_dp):
26-
# t0_dp = SelectLiveT0Time(nwp_dp, dim_name="init_time_utc")
27-
# nwp_dp = AddNWPTargetTime(
28-
# nwp_dp,
29-
# t0_dp,
30-
# sample_period_duration=timedelta(hours=1),
31-
# history_duration=timedelta(hours=2),
32-
# forecast_duration=timedelta(hours=4),
33-
# )
34-
# data = next(iter(nwp_dp))
35-
# print(data)
36-
# print(data.init_time_utc.values)
37-
# time_len = len(next(iter(nwp_dp)).target_time_utc.values)
38-
# nwp_dp = SelectLiveTimeSlice(
39-
# nwp_dp,
40-
# t0_datapipe=t0_dp,
41-
# history_duration=timedelta(minutes=120),
42-
# dim_name="target_time_utc",
43-
# )
44-
# data = next(iter(nwp_dp))
45-
# assert len(data.time_utc.values) == 3
46-
# assert len(data.time_utc.values) < time_len
25+
def test_select_nwp(nwp_dp):
26+
t0_dp = SelectLiveT0Time(nwp_dp, dim_name="init_time_utc")
27+
nwp_dp = ConvertToNWPTargetTime(
28+
nwp_dp,
29+
t0_dp,
30+
sample_period_duration=timedelta(hours=1),
31+
history_duration=timedelta(hours=2),
32+
forecast_duration=timedelta(hours=4),
33+
)
34+
time_len = len(next(iter(nwp_dp)).target_time_utc.values)
35+
nwp_dp = SelectLiveTimeSlice(
36+
nwp_dp,
37+
t0_datapipe=t0_dp,
38+
history_duration=timedelta(minutes=120),
39+
dim_name="target_time_utc",
40+
)
41+
data = next(iter(nwp_dp))
42+
assert len(data.target_time_utc.values) == 3
43+
assert len(data.target_time_utc.values) < time_len
4744

4845

4946
def test_select_passiv(passiv_dp):

tests/transform/xarray/test_add_nwp_target_time.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,16 @@
22
from ocf_datapipes.select import SelectLiveT0Time
33
from datetime import timedelta
44

5+
56
def test_add_nwp_target_time(nwp_dp):
67
t0_dp = SelectLiveT0Time(nwp_dp, dim_name="init_time_utc")
7-
nwp_dp = ConvertToNWPTargetTime(nwp_dp, t0_dp, sample_period_duration=timedelta(minutes=60), history_duration=timedelta(hours=2), forecast_duration=timedelta(hours=3))
8+
nwp_dp = ConvertToNWPTargetTime(
9+
nwp_dp,
10+
t0_dp,
11+
sample_period_duration=timedelta(minutes=60),
12+
history_duration=timedelta(hours=2),
13+
forecast_duration=timedelta(hours=3),
14+
)
815
data = next(iter(nwp_dp))
916
assert "target_time_utc" in data.coords
1017
assert len(data.coords["target_time_utc"]) == 6
11-

tests/transform/xarray/test_downsample.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,13 @@ def test_nwp_downsample(nwp_dp):
1010

1111

1212
def test_sat_downsample(sat_dp):
13-
sat_dp = Downsample(sat_dp, y_coarsen=16, x_coarsen=16, y_dim_name="y_geostationary", x_dim_name="x_geostationary")
13+
sat_dp = Downsample(
14+
sat_dp,
15+
y_coarsen=16,
16+
x_coarsen=16,
17+
y_dim_name="y_geostationary",
18+
x_dim_name="x_geostationary",
19+
)
1420
data = next(iter(sat_dp))
1521
assert data.shape[-1] == 38
1622
assert data.shape[-2] == 18

0 commit comments

Comments
 (0)