Skip to content

Commit 884aef2

Browse files
Merge pull request #64 from openclimatefix/datamodel-filter
tidy up
2 parents 9a555df + a331db3 commit 884aef2

File tree

2 files changed

+63
-36
lines changed

2 files changed

+63
-36
lines changed

nowcasting_datamodel/read/read.py

Lines changed: 62 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111
from sqlalchemy import desc
1212
from sqlalchemy.orm import contains_eager, joinedload
1313
from sqlalchemy.orm.session import Session
14-
from sqlalchemy.sql.expression import false, true
1514

15+
from nowcasting_datamodel import N_GSP
1616
from nowcasting_datamodel.models import (
1717
ForecastSQL,
1818
ForecastValueLatestSQL,
@@ -138,36 +138,14 @@ def get_latest_forecast(
138138

139139
logger.debug(f"Getting latest forecast for gsp {gsp_id}")
140140

141-
# start main query
142-
query = session.query(ForecastSQL)
143-
order_by_items = []
144-
145-
if historic:
146-
query = query.filter(ForecastSQL.historic == true())
147-
else:
148-
query = query.filter(ForecastSQL.historic == false())
149-
150-
if start_target_time is not None:
151-
query = filter_query_on_target_time(
152-
query=query, start_target_time=start_target_time, historic=historic
153-
)
154-
155-
# filter on gsp_id
156141
if gsp_id is not None:
157-
query = query.join(LocationSQL)
158-
query = query.filter(LocationSQL.gsp_id == gsp_id)
159-
order_by_items.append(LocationSQL.gsp_id)
160-
161-
order_by_items.append(ForecastSQL.created_utc.desc())
162-
if not historic:
163-
created_utc = get_latest_forecast_created_utc(session=session, gsp_id=gsp_id)
164-
query = query.filter(ForecastSQL.created_utc == created_utc)
165-
166-
# this make the newest ones comes to the top
167-
query = query.order_by(*order_by_items)
142+
gsp_ids = [gsp_id]
143+
else:
144+
gsp_ids = None
168145

169-
# get all results
170-
forecasts = query.all()
146+
forecasts = get_latest_forecast_for_gsps(
147+
session=session, start_target_time=start_target_time, historic=historic, gsp_ids=gsp_ids
148+
)
171149

172150
if forecasts is None:
173151
return None
@@ -213,35 +191,84 @@ def get_all_gsp_ids_latest_forecast(
213191

214192
logger.debug("Getting latest forecast for all gsps")
215193

194+
return get_latest_forecast_for_gsps(
195+
session=session,
196+
start_created_utc=start_created_utc,
197+
start_target_time=start_target_time,
198+
preload_children=preload_children,
199+
historic=historic,
200+
gsp_ids=list(range(0, N_GSP + 1)),
201+
)
202+
203+
204+
def get_latest_forecast_for_gsps(
205+
session: Session,
206+
start_created_utc: Optional[datetime] = None,
207+
start_target_time: Optional[datetime] = None,
208+
preload_children: Optional[bool] = False,
209+
historic: bool = False,
210+
gsp_ids: List[int] = None,
211+
):
212+
"""
213+
Read forecasts
214+
215+
:param session: database session
216+
:param start_created_utc: Filter: forecast creation time should be greater than or equal to this datetime
217+
:param start_target_time:
218+
Filter: forecast values target time should be larger than this datetime
219+
:param preload_children: Option to preload children. This is a speed up, if we need them.
220+
:param historic: Option to load historic values or not
221+
:param gsp_ids: Option to filter on gsps. If None, then only the latest forecast is loaded.
222+
223+
:return: List of forecast objects from database
224+
225+
:param session:
226+
:param start_created_utc:
227+
:param start_target_time:
228+
:param preload_children:
229+
:param historic:
230+
:param gsp_ids:
231+
:return:
232+
"""
233+
order_by_cols = []
234+
216235
# start main query
217236
query = session.query(ForecastSQL)
218237

238+
# filter on created_utc
219239
if start_created_utc is not None:
220240
query = query.filter(ForecastSQL.created_utc >= start_created_utc)
221241

222-
# join with tables
223-
if not historic:
224-
query = query.distinct(LocationSQL.gsp_id)
242+
# join with location table and filter
243+
if gsp_ids is not None:
244+
if not historic:
245+
# for historic they are already distinct
246+
query = query.distinct(LocationSQL.gsp_id)
247+
query = query.filter(LocationSQL.gsp_id.in_(gsp_ids))
248+
order_by_cols.append(LocationSQL.gsp_id)
225249
query = query.join(LocationSQL)
226250

251+
# filter on historic
227252
query = query.filter(ForecastSQL.historic == historic)
228253

254+
# filter on target time
229255
if start_target_time is not None:
230256
query = filter_query_on_target_time(
231257
query=query, start_target_time=start_target_time, historic=historic
232258
)
233259

260+
# option to preload values, makes querying quicker
234261
if preload_children:
235262
query = query.options(joinedload(ForecastSQL.location))
236263
query = query.options(joinedload(ForecastSQL.model))
237264
query = query.options(joinedload(ForecastSQL.input_data_last_updated))
238265
if not historic:
239266
query = query.options(joinedload(ForecastSQL.forecast_values))
240-
query = query.options(joinedload(ForecastSQL.forecast_values_latest))
241267

242-
query = query.order_by(LocationSQL.gsp_id, desc(ForecastSQL.created_utc))
268+
order_by_cols.append(desc(ForecastSQL.created_utc))
269+
query = query.order_by(*order_by_cols)
243270

244-
forecasts = query.populate_existing().all()
271+
forecasts = query.all()
245272

246273
logger.debug(f"Found {len(forecasts)} forecasts")
247274
if len(forecasts) > 0:

tests/read/test_read.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@
1313
from nowcasting_datamodel.models import (
1414
Forecast,
1515
ForecastValue,
16-
ForecastValueSQL,
1716
ForecastValueLatestSQL,
17+
ForecastValueSQL,
1818
InputDataLastUpdatedSQL,
1919
LocationSQL,
2020
MLModel,

0 commit comments

Comments
 (0)