Skip to content
This repository was archived by the owner on Jun 2, 2025. It is now read-only.

Enforce forecast and history duration divisibility by time resolution #356

Merged
merged 4 commits into from
Aug 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 71 additions & 18 deletions ocf_datapipes/config/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import git
import numpy as np
from pathy import Pathy
from pydantic import BaseModel, Field, RootModel, model_validator, validator
from pydantic import BaseModel, Field, RootModel, ValidationInfo, field_validator, model_validator

# nowcasting_dataset imports
from ocf_datapipes.utils.consts import (
Expand Down Expand Up @@ -93,8 +93,8 @@ class DataSourceMixin(Base):

log_level: str = Field(
"DEBUG",
description="The logging level for this data source. T"
"his is the default value and can be set in each data source",
description="The logging level for this data source. "
"This is the default value and can be set in each data source",
)

@property
Expand Down Expand Up @@ -139,16 +139,16 @@ class DropoutMixin(Base):

dropout_fraction: float = Field(0, description="Chance of dropout being applied to each sample")

@validator("dropout_timedeltas_minutes")
def dropout_timedeltas_minutes_negative(cls, v):
@field_validator("dropout_timedeltas_minutes")
def dropout_timedeltas_minutes_negative(cls, v: List[int]) -> List[int]:
"""Validate 'dropout_timedeltas_minutes'"""
if v is not None:
for m in v:
assert m <= 0
return v

@validator("dropout_fraction")
def dropout_fraction_valid(cls, v):
@field_validator("dropout_fraction")
def dropout_fraction_valid(cls, v: float) -> float:
"""Validate 'dropout_fraction'"""
assert 0 <= v <= 1
return v
Expand All @@ -169,8 +169,8 @@ class SystemDropoutMixin(Base):
system_dropout_fraction_min: float = Field(0, description="Min chance of system dropout")
system_dropout_fraction_max: float = Field(0, description="Max chance of system dropout")

@validator("system_dropout_fraction_min", "system_dropout_fraction_max")
def validate_system_dropout_fractions(cls, v):
@field_validator("system_dropout_fraction_min", "system_dropout_fraction_max")
def validate_system_dropout_fractions(cls, v: float):
"""Validate dropout fraction values"""
assert 0 <= v <= 1
return v
Expand All @@ -192,8 +192,8 @@ class TimeResolutionMixin(Base):
"Note that this needs to be divisible by 5.",
)

@validator("time_resolution_minutes")
def forecast_minutes_divide_by_5(cls, v):
@field_validator("time_resolution_minutes")
def forecast_minutes_divide_by_5(cls, v: int) -> int:
"""Validate 'forecast_minutes'"""
assert v % 5 == 0, f"The time resolution ({v}) is not divisible by 5"
return v
Expand Down Expand Up @@ -257,7 +257,6 @@ class Wind(DataSourceMixin, TimeResolutionMixin, XYDimensionalNames, DropoutMixi
None,
description="List of the ML IDs of the Wind systems you'd like to filter to.",
)
time_resolution_minutes: int = Field(15, description="The temporal resolution (in minutes).")
wind_image_size_meters_height: int = METERS_PER_ROI
wind_image_size_meters_width: int = METERS_PER_ROI
n_wind_systems_per_example: int = Field(
Expand Down Expand Up @@ -286,6 +285,24 @@ class Wind(DataSourceMixin, TimeResolutionMixin, XYDimensionalNames, DropoutMixi
"Note that this needs to be divisible by 5.",
)

@field_validator("forecast_minutes")
def forecast_minutes_divide_by_time_resolution(cls, v: int, info: ValidationInfo) -> int:
"""Check forecast length requested will give stable number of timesteps"""
if v % info.data["time_resolution_minutes"] != 0:
message = "Forecast duration must be divisible by time resolution"
logger.error(message)
raise Exception(message)
return v

@field_validator("history_minutes")
def history_minutes_divide_by_time_resolution(cls, v: int, info: ValidationInfo) -> int:
"""Check history length requested will give stable number of timesteps"""
if v % info.data["time_resolution_minutes"] != 0:
message = "History duration must be divisible by time resolution"
logger.error(message)
raise Exception(message)
return v


class PVFiles(BaseModel):
"""Model to hold pv file and metadata file"""
Expand All @@ -305,8 +322,8 @@ class PVFiles(BaseModel):

label: Optional[str] = Field(providers[0], description="Label of where the pv data came from")

@validator("label")
def v_label0(cls, v):
@field_validator("label")
def v_label0(cls, v: str) -> str:
"""Validate 'label'"""
if v not in providers:
message = f"provider {v} not in {providers}"
Expand Down Expand Up @@ -385,6 +402,24 @@ def model_validation(cls, v):

return v

@field_validator("forecast_minutes")
def forecast_minutes_divide_by_time_resolution(cls, v: int, info: ValidationInfo) -> int:
"""Check forecast length requested will give stable number of timesteps"""
if v % info.data["time_resolution_minutes"] != 0:
message = "Forecast duration must be divisible by time resolution"
logger.error(message)
raise Exception(message)
return v

@field_validator("history_minutes")
def history_minutes_divide_by_time_resolution(cls, v: int, info: ValidationInfo) -> int:
"""Check history length requested will give stable number of timesteps"""
if v % info.data["time_resolution_minutes"] != 0:
message = "History duration must be divisible by time resolution"
logger.error(message)
raise Exception(message)
return v


class Sensor(DataSourceMixin, TimeResolutionMixin, XYDimensionalNames):
"""PV configuration model"""
Expand Down Expand Up @@ -599,15 +634,33 @@ class NWP(DataSourceMixin, TimeResolutionMixin, XYDimensionalNames, DropoutMixin
0.1, description="The number of degrees to coarsen the NWP data to"
)

@validator("nwp_provider")
def validate_nwp_provider(cls, v):
@field_validator("nwp_provider")
def validate_nwp_provider(cls, v: str) -> str:
"""Validate 'nwp_provider'"""
if v.lower() not in NWP_PROVIDERS:
message = f"NWP provider {v} is not in {NWP_PROVIDERS}"
logger.warning(message)
assert Exception(message)
return v

@field_validator("forecast_minutes")
def forecast_minutes_divide_by_time_resolution(cls, v: int, info: ValidationInfo) -> int:
"""Check forecast length requested will give stable number of timesteps"""
if v % info.data["time_resolution_minutes"] != 0:
message = "Forecast duration must be divisible by time resolution"
logger.error(message)
raise Exception(message)
return v

@field_validator("history_minutes")
def history_minutes_divide_by_time_resolution(cls, v: int, info: ValidationInfo) -> int:
"""Check history length requested will give stable number of timesteps"""
if v % info.data["time_resolution_minutes"] != 0:
message = "History duration must be divisible by time resolution"
logger.error(message)
raise Exception(message)
return v


class MultiNWP(RootModel):
"""Configuration for multiple NWPs"""
Expand Down Expand Up @@ -668,13 +721,13 @@ class GSP(DataSourceMixin, TimeResolutionMixin, DropoutMixin):
"Note that this needs to be divisible by 5.",
)

@validator("history_minutes")
@field_validator("history_minutes")
def history_minutes_divide_by_30(cls, v):
"""Validate 'history_minutes'"""
assert v % 30 == 0 # this means it also divides by 5
return v

@validator("forecast_minutes")
@field_validator("forecast_minutes")
def forecast_minutes_divide_by_30(cls, v):
"""Validate 'forecast_minutes'"""
assert v % 30 == 0 # this means it also divides by 5
Expand Down
24 changes: 24 additions & 0 deletions tests/config/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,3 +98,27 @@ def test_config_git(configuration_filename):
assert type(config.git.message) == str
assert type(config.git.hash) == str
assert type(config.git.committed_date) == datetime


def test_incorrect_forecast_minutes():
"""
Check a forecast length no divisible by time resolution causes error
"""

configuration = Configuration()
configuration.input_data = configuration.input_data.set_all_to_defaults()
configuration.input_data.wind.forecast_minutes = 1111
with pytest.raises(Exception):
_ = Configuration(**configuration.dict())


def test_incorrect_history_minutes():
"""
Check a forecast length no divisible by time resolution causes error
"""

configuration = Configuration()
configuration.input_data = configuration.input_data.set_all_to_defaults()
configuration.input_data.wind.history_minutes = 1111
with pytest.raises(Exception):
_ = Configuration(**configuration.dict())
Loading