Skip to content
This repository was archived by the owner on Jun 2, 2025. It is now read-only.

Refactor common.py: Added helper function to reduce code duplication #291

Merged
merged 9 commits into from
Mar 28, 2024
Merged
75 changes: 48 additions & 27 deletions ocf_datapipes/training/common.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
"""Common functionality for datapipes"""
import logging
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Tuple, Union
from typing import Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import xarray as xr
from torch.utils.data import functional_datapipe
from torch.utils.data.datapipes.datapipe import IterDataPipe

from ocf_datapipes.batch import BatchKey, NumpyBatch
from ocf_datapipes.config.model import Configuration
from ocf_datapipes.config.model import Configuration, InputData
from ocf_datapipes.load import (
OpenAWOSFromNetCDF,
OpenConfiguration,
Expand Down Expand Up @@ -37,6 +37,36 @@
logger = logging.getLogger(__name__)


def is_config_and_path_valid(
    use_flag: bool,
    config: Optional[object],
    filepath_resolver: Union[str, Callable[[InputData], str]],
) -> bool:
    """
    Decide whether a data source is both requested and has a file path configured.

    Args:
        use_flag: Whether the caller wants to use this data source at all.
        config: The configuration object for the data source, or None if the
            source is absent from the configuration.
        filepath_resolver: How to obtain the file path from ``config``. Either an
            attribute name (str) looked up on ``config``, or a callable that
            receives ``config`` and returns the file path.

    Returns:
        True only if ``use_flag`` is truthy, ``config`` is not None, and the
        resolved file path is truthy (e.g. a non-empty string). Otherwise False.
        When ``filepath_resolver`` is an attribute name that does not exist on
        ``config``, the path is treated as empty and False is returned.
    """
    # Nothing to resolve if the source is switched off or not configured.
    if not use_flag or config is None:
        return False

    if callable(filepath_resolver):
        filepath = filepath_resolver(config)
    else:
        # A missing attribute is deliberately treated like an empty path.
        filepath = getattr(config, filepath_resolver, "")

    # Empty string (or any other falsy value such as None) means "no path set".
    return bool(filepath)


def open_and_return_datapipes(
configuration_filename: str,
use_gsp: bool = True,
Expand Down Expand Up @@ -77,32 +107,23 @@ def open_and_return_datapipes(
and len(conf_in.nwp) != 0
and all(v.nwp_zarr_path != "" for _, v in conf_in.nwp.items())
)
use_pv = (
use_pv and (conf_in.pv is not None) and (conf_in.pv.pv_files_groups[0].pv_filename != "")
)
use_sat = (
use_sat
and (conf_in.satellite is not None)
and (conf_in.satellite.satellite_zarr_path != "")
)
use_hrv = (
use_hrv
and (conf_in.hrvsatellite is not None)
and (conf_in.hrvsatellite.hrvsatellite_zarr_path != "")
)
use_topo = (
use_topo
and (conf_in.topographic is not None)
and (conf_in.topographic.topographic_filename != "")
)
use_gsp = use_gsp and (conf_in.gsp is not None) and (conf_in.gsp.gsp_zarr_path != "")
use_sensor = (
use_sensor and (conf_in.sensor is not None) and (conf_in.sensor.sensor_filename != "")

use_pv = is_config_and_path_valid(
use_pv,
conf_in.pv,
lambda config: config.pv_files_groups[0].pv_filename if config.pv_files_groups else "",
)
use_wind = (
use_wind
and (conf_in.wind is not None)
and (conf_in.wind.wind_files_groups[0].wind_filename != "")
use_sat = is_config_and_path_valid(use_sat, conf_in.satellite, "satellite_zarr_path")
use_hrv = is_config_and_path_valid(use_hrv, conf_in.hrvsatellite, "hrvsatellite_zarr_path")
use_topo = is_config_and_path_valid(use_topo, conf_in.topographic, "topographic_filename")
use_gsp = is_config_and_path_valid(use_gsp, conf_in.gsp, "gsp_zarr_path")
use_sensor = is_config_and_path_valid(use_sensor, conf_in.sensor, "sensor_filename")
use_wind = is_config_and_path_valid(
use_wind,
conf_in.wind,
lambda config: config.wind_files_groups[0].wind_filename
if config.wind_files_groups
else "",
)

logger.debug(
Expand Down
3 changes: 1 addition & 2 deletions ocf_datapipes/training/pvnet_site.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,8 @@
import xarray as xr
from torch.utils.data import IterDataPipe, functional_datapipe
from torch.utils.data.datapipes.iter import IterableWrapper
from ocf_datapipes.batch import BatchKey, NumpyBatch

from ocf_datapipes.batch import MergeNumpyModalities, MergeNWPNumpyModalities
from ocf_datapipes.batch import BatchKey, MergeNumpyModalities, MergeNWPNumpyModalities
from ocf_datapipes.training.common import (
DatapipeKeyForker,
_get_datapipes_dict,
Expand Down
1 change: 0 additions & 1 deletion ocf_datapipes/training/windnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,6 @@ def __init__(self, filenames: List[str], keys: List[str]):

def __iter__(self):
"""Iterate through each filename, loading it, uncombining it, and then yielding it"""
import numpy as np
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jacobbieker just checking: this local import of numpy didn't serve some specific purpose here, so it can safely be removed?


while True:
for filename in self.filenames:
Expand Down