Skip to content
This repository was archived by the owner on Jun 2, 2025. It is now read-only.

Commit 9ff98d2

Browse files
authored
Merge pull request #356 from openclimatefix/time-divisibility
Enforce forecast and history duration divisibility by time resolution
2 parents a6eec2b + ddae643 commit 9ff98d2

File tree

2 files changed

+95
-18
lines changed

2 files changed

+95
-18
lines changed

ocf_datapipes/config/model.py

Lines changed: 71 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import git
2020
import numpy as np
2121
from pathy import Pathy
22-
from pydantic import BaseModel, Field, RootModel, model_validator, validator
22+
from pydantic import BaseModel, Field, RootModel, ValidationInfo, field_validator, model_validator
2323

2424
# nowcasting_dataset imports
2525
from ocf_datapipes.utils.consts import (
@@ -93,8 +93,8 @@ class DataSourceMixin(Base):
9393

9494
log_level: str = Field(
9595
"DEBUG",
96-
description="The logging level for this data source. T"
97-
"his is the default value and can be set in each data source",
96+
description="The logging level for this data source. "
97+
"This is the default value and can be set in each data source",
9898
)
9999

100100
@property
@@ -139,16 +139,16 @@ class DropoutMixin(Base):
139139

140140
dropout_fraction: float = Field(0, description="Chance of dropout being applied to each sample")
141141

142-
@validator("dropout_timedeltas_minutes")
143-
def dropout_timedeltas_minutes_negative(cls, v):
142+
@field_validator("dropout_timedeltas_minutes")
143+
def dropout_timedeltas_minutes_negative(cls, v: List[int]) -> List[int]:
144144
"""Validate 'dropout_timedeltas_minutes'"""
145145
if v is not None:
146146
for m in v:
147147
assert m <= 0
148148
return v
149149

150-
@validator("dropout_fraction")
151-
def dropout_fraction_valid(cls, v):
150+
@field_validator("dropout_fraction")
151+
def dropout_fraction_valid(cls, v: float) -> float:
152152
"""Validate 'dropout_fraction'"""
153153
assert 0 <= v <= 1
154154
return v
@@ -169,8 +169,8 @@ class SystemDropoutMixin(Base):
169169
system_dropout_fraction_min: float = Field(0, description="Min chance of system dropout")
170170
system_dropout_fraction_max: float = Field(0, description="Max chance of system dropout")
171171

172-
@validator("system_dropout_fraction_min", "system_dropout_fraction_max")
173-
def validate_system_dropout_fractions(cls, v):
172+
@field_validator("system_dropout_fraction_min", "system_dropout_fraction_max")
173+
def validate_system_dropout_fractions(cls, v: float):
174174
"""Validate dropout fraction values"""
175175
assert 0 <= v <= 1
176176
return v
@@ -192,8 +192,8 @@ class TimeResolutionMixin(Base):
192192
"Note that this needs to be divisible by 5.",
193193
)
194194

195-
@validator("time_resolution_minutes")
196-
def forecast_minutes_divide_by_5(cls, v):
195+
@field_validator("time_resolution_minutes")
196+
def forecast_minutes_divide_by_5(cls, v: int) -> int:
197197
"""Validate 'forecast_minutes'"""
198198
assert v % 5 == 0, f"The time resolution ({v}) is not divisible by 5"
199199
return v
@@ -257,7 +257,6 @@ class Wind(DataSourceMixin, TimeResolutionMixin, XYDimensionalNames, DropoutMixi
257257
None,
258258
description="List of the ML IDs of the Wind systems you'd like to filter to.",
259259
)
260-
time_resolution_minutes: int = Field(15, description="The temporal resolution (in minutes).")
261260
wind_image_size_meters_height: int = METERS_PER_ROI
262261
wind_image_size_meters_width: int = METERS_PER_ROI
263262
n_wind_systems_per_example: int = Field(
@@ -286,6 +285,24 @@ class Wind(DataSourceMixin, TimeResolutionMixin, XYDimensionalNames, DropoutMixi
286285
"Note that this needs to be divisible by 5.",
287286
)
288287

288+
@field_validator("forecast_minutes")
289+
def forecast_minutes_divide_by_time_resolution(cls, v: int, info: ValidationInfo) -> int:
290+
"""Check forecast length requested will give stable number of timesteps"""
291+
if v % info.data["time_resolution_minutes"] != 0:
292+
message = "Forecast duration must be divisible by time resolution"
293+
logger.error(message)
294+
raise Exception(message)
295+
return v
296+
297+
@field_validator("history_minutes")
298+
def history_minutes_divide_by_time_resolution(cls, v: int, info: ValidationInfo) -> int:
299+
"""Check history length requested will give stable number of timesteps"""
300+
if v % info.data["time_resolution_minutes"] != 0:
301+
message = "History duration must be divisible by time resolution"
302+
logger.error(message)
303+
raise Exception(message)
304+
return v
305+
289306

290307
class PVFiles(BaseModel):
291308
"""Model to hold pv file and metadata file"""
@@ -305,8 +322,8 @@ class PVFiles(BaseModel):
305322

306323
label: Optional[str] = Field(providers[0], description="Label of where the pv data came from")
307324

308-
@validator("label")
309-
def v_label0(cls, v):
325+
@field_validator("label")
326+
def v_label0(cls, v: str) -> str:
310327
"""Validate 'label'"""
311328
if v not in providers:
312329
message = f"provider {v} not in {providers}"
@@ -385,6 +402,24 @@ def model_validation(cls, v):
385402

386403
return v
387404

405+
@field_validator("forecast_minutes")
406+
def forecast_minutes_divide_by_time_resolution(cls, v: int, info: ValidationInfo) -> int:
407+
"""Check forecast length requested will give stable number of timesteps"""
408+
if v % info.data["time_resolution_minutes"] != 0:
409+
message = "Forecast duration must be divisible by time resolution"
410+
logger.error(message)
411+
raise Exception(message)
412+
return v
413+
414+
@field_validator("history_minutes")
415+
def history_minutes_divide_by_time_resolution(cls, v: int, info: ValidationInfo) -> int:
416+
"""Check history length requested will give stable number of timesteps"""
417+
if v % info.data["time_resolution_minutes"] != 0:
418+
message = "History duration must be divisible by time resolution"
419+
logger.error(message)
420+
raise Exception(message)
421+
return v
422+
388423

389424
class Sensor(DataSourceMixin, TimeResolutionMixin, XYDimensionalNames):
390425
"""PV configuration model"""
@@ -599,15 +634,33 @@ class NWP(DataSourceMixin, TimeResolutionMixin, XYDimensionalNames, DropoutMixin
599634
0.1, description="The number of degrees to coarsen the NWP data to"
600635
)
601636

602-
@validator("nwp_provider")
603-
def validate_nwp_provider(cls, v):
637+
@field_validator("nwp_provider")
638+
def validate_nwp_provider(cls, v: str) -> str:
604639
"""Validate 'nwp_provider'"""
605640
if v.lower() not in NWP_PROVIDERS:
606641
message = f"NWP provider {v} is not in {NWP_PROVIDERS}"
607642
logger.warning(message)
608643
assert Exception(message)
609644
return v
610645

646+
@field_validator("forecast_minutes")
647+
def forecast_minutes_divide_by_time_resolution(cls, v: int, info: ValidationInfo) -> int:
648+
"""Check forecast length requested will give stable number of timesteps"""
649+
if v % info.data["time_resolution_minutes"] != 0:
650+
message = "Forecast duration must be divisible by time resolution"
651+
logger.error(message)
652+
raise Exception(message)
653+
return v
654+
655+
@field_validator("history_minutes")
656+
def history_minutes_divide_by_time_resolution(cls, v: int, info: ValidationInfo) -> int:
657+
"""Check history length requested will give stable number of timesteps"""
658+
if v % info.data["time_resolution_minutes"] != 0:
659+
message = "History duration must be divisible by time resolution"
660+
logger.error(message)
661+
raise Exception(message)
662+
return v
663+
611664

612665
class MultiNWP(RootModel):
613666
"""Configuration for multiple NWPs"""
@@ -668,13 +721,13 @@ class GSP(DataSourceMixin, TimeResolutionMixin, DropoutMixin):
668721
"Note that this needs to be divisible by 5.",
669722
)
670723

671-
@validator("history_minutes")
724+
@field_validator("history_minutes")
672725
def history_minutes_divide_by_30(cls, v):
673726
"""Validate 'history_minutes'"""
674727
assert v % 30 == 0 # this means it also divides by 5
675728
return v
676729

677-
@validator("forecast_minutes")
730+
@field_validator("forecast_minutes")
678731
def forecast_minutes_divide_by_30(cls, v):
679732
"""Validate 'forecast_minutes'"""
680733
assert v % 30 == 0 # this means it also divides by 5

tests/config/test_config.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,3 +98,27 @@ def test_config_git(configuration_filename):
9898
assert type(config.git.message) == str
9999
assert type(config.git.hash) == str
100100
assert type(config.git.committed_date) == datetime
101+
102+
103+
def test_incorrect_forecast_minutes():
104+
"""
105+
Check a forecast length no divisible by time resolution causes error
106+
"""
107+
108+
configuration = Configuration()
109+
configuration.input_data = configuration.input_data.set_all_to_defaults()
110+
configuration.input_data.wind.forecast_minutes = 1111
111+
with pytest.raises(Exception):
112+
_ = Configuration(**configuration.dict())
113+
114+
115+
def test_incorrect_history_minutes():
116+
"""
117+
Check a forecast length no divisible by time resolution causes error
118+
"""
119+
120+
configuration = Configuration()
121+
configuration.input_data = configuration.input_data.set_all_to_defaults()
122+
configuration.input_data.wind.history_minutes = 1111
123+
with pytest.raises(Exception):
124+
_ = Configuration(**configuration.dict())

0 commit comments

Comments
 (0)