Skip to content
This repository was archived by the owner on Jun 2, 2025. It is now read-only.

Commit c1c059c

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 7d36094 commit c1c059c

21 files changed

+116
-64
lines changed

.github/workflows/workflows.yaml

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -14,5 +14,5 @@ jobs:
1414
pytest_cov_dir: "ocf_datapipes"
1515
# extra things to install
1616
sudo_apt_install: "libgeos++-dev libproj-dev proj-data proj-bin"
17-
# brew_install: "proj geos librttopo"
18-
os_list: '["ubuntu-latest"]'
17+
# brew_install: "proj geos librttopo"
18+
os_list: '["ubuntu-latest"]'

ocf_datapipes/load/configuration.py

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
1+
import logging
2+
13
import fsspec
24
from pathy import Pathy
3-
import logging
45
from pyaml_env import parse_config
56
from torchdata.datapipes import functional_datapipe
67
from torchdata.datapipes.iter import IterDataPipe
@@ -16,11 +17,11 @@ def __init__(self, configuration_filename: str):
1617
self.configuration_filename = configuration_filename
1718

1819
def __iter__(self):
19-
logger.debug(f'Going to open {self.configuration_filename}')
20+
logger.debug(f"Going to open {self.configuration_filename}")
2021
with fsspec.open(self.configuration_filename, mode="r") as stream:
2122
configuration = parse_config(data=stream)
2223

23-
logger.debug(f'Converting to Configuration ({configuration})')
24+
logger.debug(f"Converting to Configuration ({configuration})")
2425
configuration = Configuration(**configuration)
2526

2627
while True:

ocf_datapipes/load/pv.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -214,10 +214,10 @@ def _load_pv_metadata(filename: str) -> pd.DataFrame:
214214
latitude, longitude, system_id, x_osgb, y_osgb
215215
"""
216216
_log.info(f"Loading PV metadata from {filename}")
217-
if 'passiv' in str(filename):
218-
index_col = 'ss_id'
217+
if "passiv" in str(filename):
218+
index_col = "ss_id"
219219
else:
220-
index_col = 'system_id'
220+
index_col = "system_id"
221221
pv_metadata = pd.read_csv(filename, index_col=index_col)
222222

223223
if "Unnamed: 0" in pv_metadata.columns:

ocf_datapipes/load/satellite.py

Lines changed: 15 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -45,12 +45,24 @@ def open_sat_data(
4545
# Note that `rename` renames *both* the coordinates and dimensions, and keeps
4646
# the connection between the dims and coordinates, so we don't have to manually
4747
# use `data_array.set_index()`.
48-
dataset = dataset.rename({"time": "time_utc",})
48+
dataset = dataset.rename(
49+
{
50+
"time": "time_utc",
51+
}
52+
)
4953
if "y" in dataset.coords.keys():
50-
dataset = dataset.rename({"y": "y_geostationary",})
54+
dataset = dataset.rename(
55+
{
56+
"y": "y_geostationary",
57+
}
58+
)
5159

5260
if "x" in dataset.coords.keys():
53-
dataset = dataset.rename({"x": "x_geostationary",})
61+
dataset = dataset.rename(
62+
{
63+
"x": "x_geostationary",
64+
}
65+
)
5466

5567
# Flip coordinates to top-left first
5668
if dataset.y_geostationary[0] < dataset.y_geostationary[-1]:

ocf_datapipes/transform/xarray/add_t0idx_and_sample_period_duration.py

Lines changed: 9 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,14 +1,21 @@
1+
from datetime import timedelta
12
from typing import Union
23

34
import xarray as xr
45
from torchdata.datapipes import functional_datapipe
56
from torchdata.datapipes.iter import IterDataPipe
6-
from datetime import timedelta
7+
78

89
@functional_datapipe("add_t0_idx_and_sample_period_duration")
910
class AddT0IdxAndSamplePeriodDurationIterDataPipe(IterDataPipe):
1011
"""Add t0_idx and sample_period_duration attributes to datasets for downstream tasks"""
11-
def __init__(self, source_datapipe: IterDataPipe, sample_period_duration: timedelta, history_duration: timedelta):
12+
13+
def __init__(
14+
self,
15+
source_datapipe: IterDataPipe,
16+
sample_period_duration: timedelta,
17+
history_duration: timedelta,
18+
):
1219
"""
1320
Adds two attributes, t0_idx, and sample_period_duration for downstream datapipes to use
1421

ocf_datapipes/transform/xarray/convert_to_nwp_target_times.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
1-
from typing import Union
21
from datetime import timedelta
2+
from typing import Union
33

44
import pandas as pd
55
import xarray as xr
@@ -10,6 +10,7 @@
1010
@functional_datapipe("convert_to_nwp_target_time")
1111
class ConvertToNWPTargetTimeIterDataPipe(IterDataPipe):
1212
"""Converts NWP Xarray to use the target time"""
13+
1314
def __init__(
1415
self,
1516
source_datapipe: IterDataPipe,

ocf_datapipes/transform/xarray/downsample.py

Lines changed: 14 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,15 @@
55
@functional_datapipe("downsample")
66
class DownsampleIterDataPipe(IterDataPipe):
77
"""Downsample Xarray dataset with coarsen"""
8-
def __init__(self, source_datapipe: IterDataPipe, y_coarsen: int, x_coarsen: int, x_dim_name: str = "x_osgb", y_dim_name: str = "y_osgb"):
8+
9+
def __init__(
10+
self,
11+
source_datapipe: IterDataPipe,
12+
y_coarsen: int,
13+
x_coarsen: int,
14+
x_dim_name: str = "x_osgb",
15+
y_dim_name: str = "y_osgb",
16+
):
917
"""
1018
Downsample xarray dataset/dataarrays with coarsen
1119
@@ -25,8 +33,10 @@ def __init__(self, source_datapipe: IterDataPipe, y_coarsen: int, x_coarsen: int
2533
def __iter__(self):
2634
"""Coarsen the data on the specified dimensions"""
2735
for xr_data in self.source_datapipe:
28-
yield xr_data.coarsen({
29-
self.y_dim_name: self.y_coarsen,
30-
self.x_dim_name: self.x_coarsen,},
36+
yield xr_data.coarsen(
37+
{
38+
self.y_dim_name: self.y_coarsen,
39+
self.x_dim_name: self.x_coarsen,
40+
},
3141
boundary="trim",
3242
).mean()

ocf_datapipes/transform/xarray/normalize.py

Lines changed: 7 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1,19 +1,21 @@
1-
import xarray as xr
1+
from typing import Callable, Optional, Union
2+
23
import numpy as np
3-
from typing import Optional, Union, Callable
4+
import xarray as xr
45
from torchdata.datapipes import functional_datapipe
56
from torchdata.datapipes.iter import IterDataPipe
67

78

89
@functional_datapipe("normalize")
910
class NormalizeIterDataPipe(IterDataPipe):
1011
"""Normalize the data in various methods"""
12+
1113
def __init__(
1214
self,
1315
source_datapipe: IterDataPipe,
14-
mean: Optional[Union[xr.Dataset, xr.DataArray, np.ndarray]]=None,
15-
std: Optional[Union[xr.Dataset, xr.DataArray, np.ndarray]]=None,
16-
max_value: Optional[Union[int, float]]=None,
16+
mean: Optional[Union[xr.Dataset, xr.DataArray, np.ndarray]] = None,
17+
std: Optional[Union[xr.Dataset, xr.DataArray, np.ndarray]] = None,
18+
max_value: Optional[Union[int, float]] = None,
1719
calculate_mean_std_from_example: bool = False,
1820
normalize_fn: Optional[Callable] = None,
1921
):

ocf_datapipes/transform/xarray/pv_power_rolling_window.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -9,6 +9,7 @@
99
@functional_datapipe("pv_power_rolling_window")
1010
class PVPowerRollingWindowIterDataPipe(IterDataPipe):
1111
"""Compute rolling mean of PV power."""
12+
1213
def __init__(
1314
self,
1415
source_datapipe: IterDataPipe,

tests/convert/test_gsp.py

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,7 @@
1-
from ocf_datapipes.convert import ConvertGSPToNumpyBatch
2-
from ocf_datapipes.transform.xarray import AddT0IdxAndSamplePeriodDuration
3-
41
from datetime import timedelta
52

3+
from ocf_datapipes.convert import ConvertGSPToNumpyBatch
4+
from ocf_datapipes.transform.xarray import AddT0IdxAndSamplePeriodDuration
65

76

87
def test_convert_gsp_to_numpy_batch(gsp_dp):

tests/convert/test_nwp.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1,15 +1,15 @@
1+
from datetime import timedelta
2+
13
from ocf_datapipes.convert import ConvertNWPToNumpyBatch
2-
from ocf_datapipes.transform.xarray import AddT0IdxAndSamplePeriodDuration, ConvertToNWPTargetTime
34
from ocf_datapipes.select import SelectLiveT0Time
4-
5-
from datetime import timedelta
5+
from ocf_datapipes.transform.xarray import AddT0IdxAndSamplePeriodDuration, ConvertToNWPTargetTime
66

77

88
def test_convert_nwp_to_numpy_batch(nwp_dp):
99
nwp_dp = AddT0IdxAndSamplePeriodDuration(
1010
nwp_dp, sample_period_duration=timedelta(minutes=60), history_duration=timedelta(minutes=60)
1111
)
12-
t0_dp = SelectLiveT0Time(nwp_dp, dim_name='init_time_utc')
12+
t0_dp = SelectLiveT0Time(nwp_dp, dim_name="init_time_utc")
1313

1414
nwp_dp = ConvertToNWPTargetTime(
1515
nwp_dp,

tests/convert/test_pv.py

Lines changed: 8 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1,12 +1,14 @@
1+
from datetime import timedelta
2+
13
from ocf_datapipes.convert import ConvertPVToNumpyBatch
24
from ocf_datapipes.transform.xarray import AddT0IdxAndSamplePeriodDuration
35

4-
from datetime import timedelta
5-
66

77
def test_convert_passiv_to_numpy_batch(passiv_dp):
88
passiv_dp = AddT0IdxAndSamplePeriodDuration(
9-
passiv_dp, sample_period_duration=timedelta(minutes=5), history_duration=timedelta(minutes=60)
9+
passiv_dp,
10+
sample_period_duration=timedelta(minutes=5),
11+
history_duration=timedelta(minutes=60),
1012
)
1113
passiv_dp = ConvertPVToNumpyBatch(passiv_dp)
1214
data = next(iter(passiv_dp))
@@ -15,7 +17,9 @@ def test_convert_passiv_to_numpy_batch(passiv_dp):
1517

1618
def test_convert_pvoutput_to_numpy_batch(pvoutput_dp):
1719
pvoutput_dp = AddT0IdxAndSamplePeriodDuration(
18-
pvoutput_dp, sample_period_duration=timedelta(minutes=5), history_duration=timedelta(minutes=60)
20+
pvoutput_dp,
21+
sample_period_duration=timedelta(minutes=5),
22+
history_duration=timedelta(minutes=60),
1923
)
2024

2125
pvoutput_dp = ConvertPVToNumpyBatch(pvoutput_dp)

tests/convert/test_satellite.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,8 @@
1+
from datetime import timedelta
2+
13
from ocf_datapipes.convert import ConvertSatelliteToNumpyBatch
24
from ocf_datapipes.transform.xarray import AddT0IdxAndSamplePeriodDuration
35

4-
from datetime import timedelta
5-
66

77
def test_convert_satellite_to_numpy_batch(sat_dp):
88

tests/end2end/test_power_perceiver_production.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,7 @@
1717
ConvertSatelliteToNumpyBatch,
1818
)
1919
from ocf_datapipes.experimental import EnsureNNWPVariables, SetSystemIDsToOne
20+
from ocf_datapipes.production.power_perceiver import GSPIterator
2021
from ocf_datapipes.select import (
2122
LocationPicker,
2223
SelectLiveT0Time,
@@ -43,7 +44,6 @@
4344
ReprojectTopography,
4445
)
4546
from ocf_datapipes.utils.consts import NWP_MEAN, NWP_STD, SAT_MEAN, SAT_STD, BatchKey
46-
from ocf_datapipes.production.power_perceiver import GSPIterator
4747

4848

4949
def test_power_perceiver_production(sat_hrv_dp, passiv_dp, topo_dp, gsp_dp, nwp_dp):

tests/load/test_load_config.py

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,9 +1,7 @@
1-
21
from ocf_datapipes.load import OpenConfiguration
32

43

54
def test_open_config():
65
config_dp = OpenConfiguration("tests/config/test.yaml")
76
configuration = next(iter(config_dp))
87
print(configuration)
9-

tests/production/test_pp_production.py

Lines changed: 6 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1,12 +1,13 @@
1-
from ocf_datapipes.production.power_perceiver import power_perceiver_production_datapipe
2-
from ocf_datapipes.utils.consts import BatchKey
3-
4-
import ocf_datapipes
51
import os
2+
63
import pytest
74

5+
import ocf_datapipes
6+
from ocf_datapipes.production.power_perceiver import power_perceiver_production_datapipe
7+
from ocf_datapipes.utils.consts import BatchKey
8+
89

9-
@pytest.mark.skip('Need to set up laod PV from database first')
10+
@pytest.mark.skip("Need to set up laod PV from database first")
1011
def test_pp_production_datapipe():
1112

1213
filename = os.path.join(os.path.dirname(ocf_datapipes.__file__), "../tests/config/test.yaml")

tests/select/test_select_live_t0_time_slice.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -58,7 +58,7 @@ def test_select_passiv(passiv_dp):
5858

5959
def test_select_pvoutput(pvoutput_dp):
6060
time_len = len(next(iter(pvoutput_dp)).time_utc.values)
61-
t0_dp = SelectLiveT0Time(pvoutput_dp, dim_name='time_utc')
61+
t0_dp = SelectLiveT0Time(pvoutput_dp, dim_name="time_utc")
6262
pvoutput_dp = SelectLiveTimeSlice(
6363
pvoutput_dp,
6464
t0_datapipe=t0_dp,

tests/select/test_select_live_time_slice.py

Lines changed: 10 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,13 +1,16 @@
11
from datetime import timedelta
22

33
from ocf_datapipes.select import SelectLiveT0Time, SelectLiveTimeSlice
4+
45
# from ocf_datapipes.transform.xarray import AddNWPTargetTime
56

67

78
def test_select_hrv(sat_hrv_dp):
89
time_len = len(next(iter(sat_hrv_dp)).time_utc.values)
910
t0_dp = SelectLiveT0Time(sat_hrv_dp, dim_name="time_utc")
10-
sat_hrv_dp = SelectLiveTimeSlice(sat_hrv_dp, history_duration=timedelta(minutes=60), t0_datapipe=t0_dp)
11+
sat_hrv_dp = SelectLiveTimeSlice(
12+
sat_hrv_dp, history_duration=timedelta(minutes=60), t0_datapipe=t0_dp
13+
)
1114
data = next(iter(sat_hrv_dp))
1215
assert len(data.time_utc.values) == 13
1316
assert len(data.time_utc.values) < time_len
@@ -49,7 +52,9 @@ def test_select_gsp(gsp_dp):
4952
def test_select_passiv(passiv_dp):
5053
time_len = len(next(iter(passiv_dp)).time_utc.values)
5154
t0_dp = SelectLiveT0Time(passiv_dp, dim_name="time_utc")
52-
passiv_dp = SelectLiveTimeSlice(passiv_dp, history_duration=timedelta(minutes=60), t0_datapipe=t0_dp)
55+
passiv_dp = SelectLiveTimeSlice(
56+
passiv_dp, history_duration=timedelta(minutes=60), t0_datapipe=t0_dp
57+
)
5358
data = next(iter(passiv_dp))
5459
assert len(data.time_utc.values) == 13
5560
assert len(data.time_utc.values) < time_len
@@ -58,7 +63,9 @@ def test_select_passiv(passiv_dp):
5863
def test_select_pvoutput(pvoutput_dp):
5964
time_len = len(next(iter(pvoutput_dp)).time_utc.values)
6065
t0_dp = SelectLiveT0Time(pvoutput_dp, dim_name="time_utc")
61-
pvoutput_dp = SelectLiveTimeSlice(pvoutput_dp, history_duration=timedelta(minutes=60), t0_datapipe=t0_dp)
66+
pvoutput_dp = SelectLiveTimeSlice(
67+
pvoutput_dp, history_duration=timedelta(minutes=60), t0_datapipe=t0_dp
68+
)
6269
data = next(iter(pvoutput_dp))
6370
assert len(data.time_utc.values) == 13
6471
assert len(data.time_utc.values) < time_len

tests/transform/xarray/test_add_nwp_target_time.py

Lines changed: 11 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1,11 +1,18 @@
1-
from ocf_datapipes.transform.xarray import ConvertToNWPTargetTime
2-
from ocf_datapipes.select import SelectLiveT0Time
31
from datetime import timedelta
42

3+
from ocf_datapipes.select import SelectLiveT0Time
4+
from ocf_datapipes.transform.xarray import ConvertToNWPTargetTime
5+
6+
57
def test_add_nwp_target_time(nwp_dp):
68
t0_dp = SelectLiveT0Time(nwp_dp, dim_name="init_time_utc")
7-
nwp_dp = ConvertToNWPTargetTime(nwp_dp, t0_dp, sample_period_duration=timedelta(minutes=60), history_duration=timedelta(hours=2), forecast_duration=timedelta(hours=3))
9+
nwp_dp = ConvertToNWPTargetTime(
10+
nwp_dp,
11+
t0_dp,
12+
sample_period_duration=timedelta(minutes=60),
13+
history_duration=timedelta(hours=2),
14+
forecast_duration=timedelta(hours=3),
15+
)
816
data = next(iter(nwp_dp))
917
assert "target_time_utc" in data.coords
1018
assert len(data.coords["target_time_utc"]) == 6
11-

tests/transform/xarray/test_downsample.py

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,13 @@ def test_nwp_downsample(nwp_dp):
1010

1111

1212
def test_sat_downsample(sat_dp):
13-
sat_dp = Downsample(sat_dp, y_coarsen=16, x_coarsen=16, y_dim_name="y_geostationary", x_dim_name="x_geostationary")
13+
sat_dp = Downsample(
14+
sat_dp,
15+
y_coarsen=16,
16+
x_coarsen=16,
17+
y_dim_name="y_geostationary",
18+
x_dim_name="x_geostationary",
19+
)
1420
data = next(iter(sat_dp))
1521
assert data.shape[-1] == 38
1622
assert data.shape[-2] == 18

0 commit comments

Comments
 (0)