19
19
import git
20
20
import numpy as np
21
21
from pathy import Pathy
22
- from pydantic import BaseModel , Field , RootModel , model_validator , validator
22
+ from pydantic import BaseModel , Field , RootModel , ValidationInfo , field_validator , model_validator
23
23
24
24
# nowcasting_dataset imports
25
25
from ocf_datapipes .utils .consts import (
@@ -93,8 +93,8 @@ class DataSourceMixin(Base):
93
93
94
94
log_level : str = Field (
95
95
"DEBUG" ,
96
- description = "The logging level for this data source. T "
97
- "his is the default value and can be set in each data source" ,
96
+ description = "The logging level for this data source. "
97
+ "This is the default value and can be set in each data source" ,
98
98
)
99
99
100
100
@property
@@ -139,16 +139,16 @@ class DropoutMixin(Base):
139
139
140
140
dropout_fraction : float = Field (0 , description = "Chance of dropout being applied to each sample" )
141
141
142
- @validator ("dropout_timedeltas_minutes" )
143
- def dropout_timedeltas_minutes_negative (cls , v ) :
142
+ @field_validator ("dropout_timedeltas_minutes" )
143
+ def dropout_timedeltas_minutes_negative (cls , v : List [ int ]) -> List [ int ] :
144
144
"""Validate 'dropout_timedeltas_minutes'"""
145
145
if v is not None :
146
146
for m in v :
147
147
assert m <= 0
148
148
return v
149
149
150
- @validator ("dropout_fraction" )
151
- def dropout_fraction_valid (cls , v ) :
150
+ @field_validator ("dropout_fraction" )
151
+ def dropout_fraction_valid (cls , v : float ) -> float :
152
152
"""Validate 'dropout_fraction'"""
153
153
assert 0 <= v <= 1
154
154
return v
@@ -169,8 +169,8 @@ class SystemDropoutMixin(Base):
169
169
system_dropout_fraction_min : float = Field (0 , description = "Min chance of system dropout" )
170
170
system_dropout_fraction_max : float = Field (0 , description = "Max chance of system dropout" )
171
171
172
- @validator ("system_dropout_fraction_min" , "system_dropout_fraction_max" )
173
- def validate_system_dropout_fractions (cls , v ):
172
+ @field_validator ("system_dropout_fraction_min" , "system_dropout_fraction_max" )
173
+ def validate_system_dropout_fractions (cls , v : float ):
174
174
"""Validate dropout fraction values"""
175
175
assert 0 <= v <= 1
176
176
return v
@@ -192,8 +192,8 @@ class TimeResolutionMixin(Base):
192
192
"Note that this needs to be divisible by 5." ,
193
193
)
194
194
195
- @validator ("time_resolution_minutes" )
196
- def forecast_minutes_divide_by_5 (cls , v ) :
195
+ @field_validator ("time_resolution_minutes" )
196
+ def forecast_minutes_divide_by_5 (cls , v : int ) -> int :
197
197
"""Validate 'forecast_minutes'"""
198
198
assert v % 5 == 0 , f"The time resolution ({ v } ) is not divisible by 5"
199
199
return v
@@ -257,7 +257,6 @@ class Wind(DataSourceMixin, TimeResolutionMixin, XYDimensionalNames, DropoutMixi
257
257
None ,
258
258
description = "List of the ML IDs of the Wind systems you'd like to filter to." ,
259
259
)
260
- time_resolution_minutes : int = Field (15 , description = "The temporal resolution (in minutes)." )
261
260
wind_image_size_meters_height : int = METERS_PER_ROI
262
261
wind_image_size_meters_width : int = METERS_PER_ROI
263
262
n_wind_systems_per_example : int = Field (
@@ -286,6 +285,24 @@ class Wind(DataSourceMixin, TimeResolutionMixin, XYDimensionalNames, DropoutMixi
286
285
"Note that this needs to be divisible by 5." ,
287
286
)
288
287
288
+ @field_validator ("forecast_minutes" )
289
+ def forecast_minutes_divide_by_time_resolution (cls , v : int , info : ValidationInfo ) -> int :
290
+ """Check forecast length requested will give stable number of timesteps"""
291
+ if v % info .data ["time_resolution_minutes" ] != 0 :
292
+ message = "Forecast duration must be divisible by time resolution"
293
+ logger .error (message )
294
+ raise Exception (message )
295
+ return v
296
+
297
+ @field_validator ("history_minutes" )
298
+ def history_minutes_divide_by_time_resolution (cls , v : int , info : ValidationInfo ) -> int :
299
+ """Check history length requested will give stable number of timesteps"""
300
+ if v % info .data ["time_resolution_minutes" ] != 0 :
301
+ message = "History duration must be divisible by time resolution"
302
+ logger .error (message )
303
+ raise Exception (message )
304
+ return v
305
+
289
306
290
307
class PVFiles (BaseModel ):
291
308
"""Model to hold pv file and metadata file"""
@@ -305,8 +322,8 @@ class PVFiles(BaseModel):
305
322
306
323
label : Optional [str ] = Field (providers [0 ], description = "Label of where the pv data came from" )
307
324
308
- @validator ("label" )
309
- def v_label0 (cls , v ) :
325
+ @field_validator ("label" )
326
+ def v_label0 (cls , v : str ) -> str :
310
327
"""Validate 'label'"""
311
328
if v not in providers :
312
329
message = f"provider { v } not in { providers } "
@@ -385,6 +402,24 @@ def model_validation(cls, v):
385
402
386
403
return v
387
404
405
+ @field_validator ("forecast_minutes" )
406
+ def forecast_minutes_divide_by_time_resolution (cls , v : int , info : ValidationInfo ) -> int :
407
+ """Check forecast length requested will give stable number of timesteps"""
408
+ if v % info .data ["time_resolution_minutes" ] != 0 :
409
+ message = "Forecast duration must be divisible by time resolution"
410
+ logger .error (message )
411
+ raise Exception (message )
412
+ return v
413
+
414
+ @field_validator ("history_minutes" )
415
+ def history_minutes_divide_by_time_resolution (cls , v : int , info : ValidationInfo ) -> int :
416
+ """Check history length requested will give stable number of timesteps"""
417
+ if v % info .data ["time_resolution_minutes" ] != 0 :
418
+ message = "History duration must be divisible by time resolution"
419
+ logger .error (message )
420
+ raise Exception (message )
421
+ return v
422
+
388
423
389
424
class Sensor (DataSourceMixin , TimeResolutionMixin , XYDimensionalNames ):
390
425
"""PV configuration model"""
@@ -599,15 +634,33 @@ class NWP(DataSourceMixin, TimeResolutionMixin, XYDimensionalNames, DropoutMixin
599
634
0.1 , description = "The number of degrees to coarsen the NWP data to"
600
635
)
601
636
602
- @validator ("nwp_provider" )
603
- def validate_nwp_provider (cls , v ) :
637
+ @field_validator ("nwp_provider" )
638
+ def validate_nwp_provider (cls , v : str ) -> str :
604
639
"""Validate 'nwp_provider'"""
605
640
if v .lower () not in NWP_PROVIDERS :
606
641
message = f"NWP provider { v } is not in { NWP_PROVIDERS } "
607
642
logger .warning (message )
608
643
assert Exception (message )
609
644
return v
610
645
646
+ @field_validator ("forecast_minutes" )
647
+ def forecast_minutes_divide_by_time_resolution (cls , v : int , info : ValidationInfo ) -> int :
648
+ """Check forecast length requested will give stable number of timesteps"""
649
+ if v % info .data ["time_resolution_minutes" ] != 0 :
650
+ message = "Forecast duration must be divisible by time resolution"
651
+ logger .error (message )
652
+ raise Exception (message )
653
+ return v
654
+
655
+ @field_validator ("history_minutes" )
656
+ def history_minutes_divide_by_time_resolution (cls , v : int , info : ValidationInfo ) -> int :
657
+ """Check history length requested will give stable number of timesteps"""
658
+ if v % info .data ["time_resolution_minutes" ] != 0 :
659
+ message = "History duration must be divisible by time resolution"
660
+ logger .error (message )
661
+ raise Exception (message )
662
+ return v
663
+
611
664
612
665
class MultiNWP (RootModel ):
613
666
"""Configuration for multiple NWPs"""
@@ -668,13 +721,13 @@ class GSP(DataSourceMixin, TimeResolutionMixin, DropoutMixin):
668
721
"Note that this needs to be divisible by 5." ,
669
722
)
670
723
671
- @validator ("history_minutes" )
724
+ @field_validator ("history_minutes" )
672
725
def history_minutes_divide_by_30 (cls , v ):
673
726
"""Validate 'history_minutes'"""
674
727
assert v % 30 == 0 # this means it also divides by 5
675
728
return v
676
729
677
- @validator ("forecast_minutes" )
730
+ @field_validator ("forecast_minutes" )
678
731
def forecast_minutes_divide_by_30 (cls , v ):
679
732
"""Validate 'forecast_minutes'"""
680
733
assert v % 30 == 0 # this means it also divides by 5
0 commit comments