Skip to content
This repository was archived by the owner on Jun 2, 2025. It is now read-only.

Commit 7d36094

Browse files
committed
Update PVRollingPowerWindow and docstrings
1 parent dcbe63c commit 7d36094

File tree

3 files changed

+54
-27
lines changed

3 files changed

+54
-27
lines changed

ocf_datapipes/transform/xarray/add_t0idx_and_sample_period_duration.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,27 @@
33
import xarray as xr
44
from torchdata.datapipes import functional_datapipe
55
from torchdata.datapipes.iter import IterDataPipe
6-
6+
from datetime import timedelta
77

88
@functional_datapipe("add_t0_idx_and_sample_period_duration")
99
class AddT0IdxAndSamplePeriodDurationIterDataPipe(IterDataPipe):
10-
def __init__(self, source_datapipe: IterDataPipe, sample_period_duration, history_duration):
10+
"""Add t0_idx and sample_period_duration attributes to datasets for downstream tasks"""
11+
def __init__(self, source_datapipe: IterDataPipe, sample_period_duration: timedelta, history_duration: timedelta):
12+
"""
13+
Adds two attributes, t0_idx and sample_period_duration, for downstream datapipes to use
14+
15+
Args:
16+
source_datapipe: Datapipe emitting an Xarray Dataset or DataArray
17+
sample_period_duration: Time between samples
18+
history_duration: Amount of history in each example
19+
"""
1120
self.source_datapipe = source_datapipe
1221
self.sample_period_duration = sample_period_duration
1322
self.history_duration = history_duration
1423
self.t0_idx = int(self.history_duration / self.sample_period_duration)
1524

1625
def __iter__(self) -> Union[xr.DataArray, xr.Dataset]:
26+
"""Adds the two attributes to the xarray objects and returns them"""
1727
for xr_data in self.source_datapipe:
1828
xr_data.attrs["t0_idx"] = self.t0_idx
1929
xr_data.attrs["sample_period_duration"] = self.sample_period_duration

ocf_datapipes/transform/xarray/pv_power_rolling_window.py

Lines changed: 40 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from typing import Optional, Union
22

3-
import numpy as np
43
import pandas as pd
54
import xarray as xr
65
from torchdata.datapipes import functional_datapipe
@@ -9,24 +8,60 @@
98

109
@functional_datapipe("pv_power_rolling_window")
1110
class PVPowerRollingWindowIterDataPipe(IterDataPipe):
11+
"""Compute rolling mean of PV power."""
1212
def __init__(
1313
self,
14-
source_dp: IterDataPipe,
14+
source_datapipe: IterDataPipe,
1515
window: Union[int, pd.tseries.offsets.DateOffset, pd.core.indexers.objects.BaseIndexer] = 3,
1616
min_periods: Optional[int] = 2,
1717
center: bool = True,
1818
win_type: Optional[str] = None,
1919
expect_dataset: bool = True,
2020
):
21-
self.source_dp = source_dp
21+
"""
22+
Compute the rolling mean of PV power data
23+
24+
Args:
25+
source_datapipe: Datapipe emitting PV Xarray objects
26+
27+
window: Size of the moving window.
28+
If an integer, the fixed number of observations used for each window.
29+
30+
If an offset, the time period of each window. Each window will be variable-sized,
31+
based on the observations included in the time-period. This is only valid for
32+
datetimelike indexes. To learn more about the offsets & frequency strings, please see:
33+
https://pandas.pydata.org/pandas-docs/stable/user_guide/timeseries.html#offset-aliases
34+
35+
If a BaseIndexer subclass, the window boundaries based on the defined
36+
`get_window_bounds` method. Additional rolling keyword arguments,
37+
namely `min_periods` and `center` will be passed to `get_window_bounds`.
38+
39+
min_periods: Minimum number of observations in window required to have a value;
40+
otherwise, result is `np.nan`.
41+
42+
To avoid NaNs at the start and end of the timeseries, this should be <= ceil(window/2).
43+
44+
For a window that is specified by an offset, `min_periods` will default to 1.
45+
46+
For a window that is specified by an integer, `min_periods` will default to the size of
47+
the window.
48+
49+
center: If False, set the window labels as the right edge of the window index.
50+
If True, set the window labels as the center of the window index.
51+
52+
win_type: Window type
53+
expect_dataset: If True, expect a Dataset (its "power_w" variable is used); if False, expect a DataArray
54+
"""
55+
self.source_datapipe = source_datapipe
2256
self.window = window
2357
self.min_periods = min_periods
2458
self.center = center
2559
self.win_type = win_type
2660
self.expect_dataset = expect_dataset
2761

28-
def __iter__(self):
29-
for xr_data in self.source_dp:
62+
def __iter__(self) -> Union[xr.DataArray, xr.Dataset]:
63+
"""Compute rolling mean of PV power"""
64+
for xr_data in self.source_datapipe:
3065
if self.expect_dataset:
3166
data_to_resample = xr_data["power_w"]
3267
else:
@@ -47,22 +82,3 @@ def __iter__(self):
4782
resampled.attrs[attr_name] = xr_data.attrs[attr_name]
4883

4984
yield resampled
50-
51-
52-
def set_new_sample_period_and_t0_idx_attrs(xr_data, new_sample_period) -> xr.DataArray:
53-
orig_sample_period = xr_data.attrs["sample_period_duration"]
54-
orig_t0_idx = xr_data.attrs["t0_idx"]
55-
new_sample_period = pd.Timedelta(new_sample_period)
56-
assert new_sample_period >= orig_sample_period
57-
new_t0_idx = orig_t0_idx / (new_sample_period / orig_sample_period)
58-
np.testing.assert_almost_equal(
59-
int(new_t0_idx),
60-
new_t0_idx,
61-
err_msg=(
62-
"The original t0_idx must be exactly divisible by"
63-
" (new_sample_period / orig_sample_period)"
64-
),
65-
)
66-
xr_data.attrs["sample_period_duration"] = new_sample_period
67-
xr_data.attrs["t0_idx"] = int(new_t0_idx)
68-
return xr_data

tests/transform/xarray/test_pv_power_rolling_window.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ def test_pv_power_rolling_window_passiv(passiv_dp):
99
history_duration=timedelta(minutes=60),
1010
sample_period_duration=timedelta(minutes=5),
1111
)
12+
data_before = next(iter(passiv_dp))
1213
passiv_dp = PVPowerRollingWindow(passiv_dp, expect_dataset=False)
1314
data = next(iter(passiv_dp))
14-
assert data is not None
15+
assert len(data.time_utc) == len(data_before.time_utc)

0 commit comments

Comments
 (0)