Skip to content
This repository was archived by the owner on Jun 2, 2025. It is now read-only.

Commit ebd8e28

Browse files
authored
Refactor common.py: Added helper function to reduce code duplication (#291)
Added helper function to reduce the repetition of code for config validation checks
1 parent be712a1 commit ebd8e28

File tree

1 file changed

+48
-27
lines changed

1 file changed

+48
-27
lines changed

ocf_datapipes/training/common.py

Lines changed: 48 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
"""Common functionality for datapipes"""
22
import logging
33
from datetime import datetime, timedelta
4-
from typing import Dict, List, Optional, Tuple, Union
4+
from typing import Callable, Dict, List, Optional, Tuple, Union
55

66
import numpy as np
77
import xarray as xr
88
from torch.utils.data import functional_datapipe
99
from torch.utils.data.datapipes.datapipe import IterDataPipe
1010

1111
from ocf_datapipes.batch import BatchKey, NumpyBatch
12-
from ocf_datapipes.config.model import Configuration
12+
from ocf_datapipes.config.model import Configuration, InputData
1313
from ocf_datapipes.load import (
1414
OpenAWOSFromNetCDF,
1515
OpenConfiguration,
@@ -37,6 +37,36 @@
3737
logger = logging.getLogger(__name__)
3838

3939

40+
def is_config_and_path_valid(
41+
use_flag: bool,
42+
config,
43+
filepath_resolver: Union[str, Callable[[InputData], str]],
44+
) -> bool:
45+
"""
46+
Checks if the given configuration should be used based on specific criteria.
47+
48+
Args:
49+
use_flag (bool): Indicates whether to consider using the configuration.
50+
config (object): The configuration object to check.
51+
filepath_resolver (str or callable): Specifies how to access the file path within config;
52+
can be an attribute name (str) or a function (callable) that returns the file path.
53+
54+
Returns:
55+
bool: True if all conditions are met (use_flag is True, config is not None,
56+
and the resolved file path is not empty), otherwise False.
57+
"""
58+
59+
if not use_flag or config is None:
60+
return False
61+
62+
filepath = (
63+
filepath_resolver(config)
64+
if callable(filepath_resolver)
65+
else getattr(config, filepath_resolver, "")
66+
)
67+
return bool(filepath)
68+
69+
4070
def open_and_return_datapipes(
4171
configuration_filename: str,
4272
use_gsp: bool = True,
@@ -77,32 +107,23 @@ def open_and_return_datapipes(
77107
and len(conf_in.nwp) != 0
78108
and all(v.nwp_zarr_path != "" for _, v in conf_in.nwp.items())
79109
)
80-
use_pv = (
81-
use_pv and (conf_in.pv is not None) and (conf_in.pv.pv_files_groups[0].pv_filename != "")
82-
)
83-
use_sat = (
84-
use_sat
85-
and (conf_in.satellite is not None)
86-
and (conf_in.satellite.satellite_zarr_path != "")
87-
)
88-
use_hrv = (
89-
use_hrv
90-
and (conf_in.hrvsatellite is not None)
91-
and (conf_in.hrvsatellite.hrvsatellite_zarr_path != "")
92-
)
93-
use_topo = (
94-
use_topo
95-
and (conf_in.topographic is not None)
96-
and (conf_in.topographic.topographic_filename != "")
97-
)
98-
use_gsp = use_gsp and (conf_in.gsp is not None) and (conf_in.gsp.gsp_zarr_path != "")
99-
use_sensor = (
100-
use_sensor and (conf_in.sensor is not None) and (conf_in.sensor.sensor_filename != "")
110+
111+
use_pv = is_config_and_path_valid(
112+
use_pv,
113+
conf_in.pv,
114+
lambda config: config.pv_files_groups[0].pv_filename if config.pv_files_groups else "",
101115
)
102-
use_wind = (
103-
use_wind
104-
and (conf_in.wind is not None)
105-
and (conf_in.wind.wind_files_groups[0].wind_filename != "")
116+
use_sat = is_config_and_path_valid(use_sat, conf_in.satellite, "satellite_zarr_path")
117+
use_hrv = is_config_and_path_valid(use_hrv, conf_in.hrvsatellite, "hrvsatellite_zarr_path")
118+
use_topo = is_config_and_path_valid(use_topo, conf_in.topographic, "topographic_filename")
119+
use_gsp = is_config_and_path_valid(use_gsp, conf_in.gsp, "gsp_zarr_path")
120+
use_sensor = is_config_and_path_valid(use_sensor, conf_in.sensor, "sensor_filename")
121+
use_wind = is_config_and_path_valid(
122+
use_wind,
123+
conf_in.wind,
124+
lambda config: config.wind_files_groups[0].wind_filename
125+
if config.wind_files_groups
126+
else "",
106127
)
107128

108129
logger.debug(

0 commit comments

Comments
 (0)