|
1 | 1 | """Common functionality for datapipes"""
|
2 | 2 | import logging
|
3 | 3 | from datetime import datetime, timedelta
|
4 |
| -from typing import Dict, List, Optional, Tuple, Union |
| 4 | +from typing import Callable, Dict, List, Optional, Tuple, Union |
5 | 5 |
|
6 | 6 | import numpy as np
|
7 | 7 | import xarray as xr
|
8 | 8 | from torch.utils.data import functional_datapipe
|
9 | 9 | from torch.utils.data.datapipes.datapipe import IterDataPipe
|
10 | 10 |
|
11 | 11 | from ocf_datapipes.batch import BatchKey, NumpyBatch
|
12 |
| -from ocf_datapipes.config.model import Configuration |
| 12 | +from ocf_datapipes.config.model import Configuration, InputData |
13 | 13 | from ocf_datapipes.load import (
|
14 | 14 | OpenAWOSFromNetCDF,
|
15 | 15 | OpenConfiguration,
|
|
37 | 37 | logger = logging.getLogger(__name__)
|
38 | 38 |
|
39 | 39 |
|
| 40 | +def is_config_and_path_valid( |
| 41 | + use_flag: bool, |
| 42 | + config, |
| 43 | + filepath_resolver: Union[str, Callable[[InputData], str]], |
| 44 | +) -> bool: |
| 45 | + """ |
| 46 | + Checks if the given configuration should be used based on specific criteria. |
| 47 | +
|
| 48 | + Args: |
| 49 | + use_flag (bool): Indicates whether to consider using the configuration. |
| 50 | + config (object): The configuration object to check. |
| 51 | + filepath_resolver (str or callable): Specifies how to access the file path within config; |
| 52 | + can be an attribute name (str) or a function (callable) that returns the file path. |
| 53 | +
|
| 54 | + Returns: |
| 55 | + bool: True if all conditions are met (use_flag is True, config is not None, |
| 56 | + and the resolved file path is not empty), otherwise False. |
| 57 | + """ |
| 58 | + |
| 59 | + if not use_flag or config is None: |
| 60 | + return False |
| 61 | + |
| 62 | + filepath = ( |
| 63 | + filepath_resolver(config) |
| 64 | + if callable(filepath_resolver) |
| 65 | + else getattr(config, filepath_resolver, "") |
| 66 | + ) |
| 67 | + return bool(filepath) |
| 68 | + |
| 69 | + |
40 | 70 | def open_and_return_datapipes(
|
41 | 71 | configuration_filename: str,
|
42 | 72 | use_gsp: bool = True,
|
@@ -77,32 +107,23 @@ def open_and_return_datapipes(
|
77 | 107 | and len(conf_in.nwp) != 0
|
78 | 108 | and all(v.nwp_zarr_path != "" for _, v in conf_in.nwp.items())
|
79 | 109 | )
|
80 |
| - use_pv = ( |
81 |
| - use_pv and (conf_in.pv is not None) and (conf_in.pv.pv_files_groups[0].pv_filename != "") |
82 |
| - ) |
83 |
| - use_sat = ( |
84 |
| - use_sat |
85 |
| - and (conf_in.satellite is not None) |
86 |
| - and (conf_in.satellite.satellite_zarr_path != "") |
87 |
| - ) |
88 |
| - use_hrv = ( |
89 |
| - use_hrv |
90 |
| - and (conf_in.hrvsatellite is not None) |
91 |
| - and (conf_in.hrvsatellite.hrvsatellite_zarr_path != "") |
92 |
| - ) |
93 |
| - use_topo = ( |
94 |
| - use_topo |
95 |
| - and (conf_in.topographic is not None) |
96 |
| - and (conf_in.topographic.topographic_filename != "") |
97 |
| - ) |
98 |
| - use_gsp = use_gsp and (conf_in.gsp is not None) and (conf_in.gsp.gsp_zarr_path != "") |
99 |
| - use_sensor = ( |
100 |
| - use_sensor and (conf_in.sensor is not None) and (conf_in.sensor.sensor_filename != "") |
| 110 | + |
| 111 | + use_pv = is_config_and_path_valid( |
| 112 | + use_pv, |
| 113 | + conf_in.pv, |
| 114 | + lambda config: config.pv_files_groups[0].pv_filename if config.pv_files_groups else "", |
101 | 115 | )
|
102 |
| - use_wind = ( |
103 |
| - use_wind |
104 |
| - and (conf_in.wind is not None) |
105 |
| - and (conf_in.wind.wind_files_groups[0].wind_filename != "") |
| 116 | + use_sat = is_config_and_path_valid(use_sat, conf_in.satellite, "satellite_zarr_path") |
| 117 | + use_hrv = is_config_and_path_valid(use_hrv, conf_in.hrvsatellite, "hrvsatellite_zarr_path") |
| 118 | + use_topo = is_config_and_path_valid(use_topo, conf_in.topographic, "topographic_filename") |
| 119 | + use_gsp = is_config_and_path_valid(use_gsp, conf_in.gsp, "gsp_zarr_path") |
| 120 | + use_sensor = is_config_and_path_valid(use_sensor, conf_in.sensor, "sensor_filename") |
| 121 | + use_wind = is_config_and_path_valid( |
| 122 | + use_wind, |
| 123 | + conf_in.wind, |
| 124 | + lambda config: config.wind_files_groups[0].wind_filename |
| 125 | + if config.wind_files_groups |
| 126 | + else "", |
106 | 127 | )
|
107 | 128 |
|
108 | 129 | logger.debug(
|
|
0 commit comments