Skip to content

Commit 2db578a

Browse files
committed
Merge branch 'datamodel-filter'
2 parents 5ec39e9 + f8a4834 commit 2db578a

File tree

2 files changed

+42
-22
lines changed

2 files changed

+42
-22
lines changed

nowcasting_datamodel/read/read.py

Lines changed: 34 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -148,9 +148,9 @@ def get_latest_forecast(
148148
query = query.filter(ForecastSQL.historic == false())
149149

150150
if start_target_time is not None:
151-
query = filter_query_on_target_time(query=query,
152-
start_target_time=start_target_time,
153-
historic=historic)
151+
query = filter_query_on_target_time(
152+
query=query, start_target_time=start_target_time, historic=historic
153+
)
154154

155155
# filter on gsp_id
156156
if gsp_id is not None:
@@ -159,13 +159,14 @@ def get_latest_forecast(
159159
order_by_items.append(LocationSQL.gsp_id)
160160

161161
order_by_items.append(ForecastSQL.created_utc.desc())
162+
if not historic:
163+
created_utc = get_latest_forecast_created_utc(session=session, gsp_id=gsp_id)
164+
query = query.filter(ForecastSQL.created_utc == created_utc)
162165

163166
# this make the newest ones comes to the top
164167
query = query.order_by(*order_by_items)
165168

166169
# get all results
167-
if not historic:
168-
query = query.limit(1)
169170
forecasts = query.all()
170171

171172
if forecasts is None:
@@ -176,8 +177,10 @@ def get_latest_forecast(
176177
forecast = forecasts[0]
177178

178179
# sort list
179-
logger.debug(f"sorting 'forecast_values_latest' values. "
180-
f"There are {len(forecast.forecast_values_latest)}")
180+
logger.debug(
181+
f"sorting 'forecast_values_latest' values. "
182+
f"There are {len(forecast.forecast_values_latest)}"
183+
)
181184
if forecast.forecast_values_latest is not None:
182185
forecast.forecast_values_latest = sorted(
183186
forecast.forecast_values_latest, key=lambda d: d.target_time
@@ -216,26 +219,33 @@ def get_all_gsp_ids_latest_forecast(
216219
if start_created_utc is not None:
217220
query = query.filter(ForecastSQL.created_utc >= start_created_utc)
218221

219-
if start_target_time is not None:
220-
query = filter_query_on_target_time(query=query,
221-
start_target_time=start_target_time,
222-
historic=historic)
223-
224222
# join with tables
225223
if not historic:
226224
query = query.distinct(LocationSQL.gsp_id)
227225
query = query.join(LocationSQL)
228226

229227
query = query.filter(ForecastSQL.historic == historic)
230228

231-
query = query.order_by(LocationSQL.gsp_id, desc(ForecastSQL.created_utc))
229+
if start_target_time is not None:
230+
query = filter_query_on_target_time(
231+
query=query, start_target_time=start_target_time, historic=historic
232+
)
232233

233234
if preload_children:
234235
query = query.options(joinedload(ForecastSQL.location))
235236
query = query.options(joinedload(ForecastSQL.model))
236237
query = query.options(joinedload(ForecastSQL.input_data_last_updated))
238+
if not historic:
239+
query = query.options(joinedload(ForecastSQL.forecast_values))
240+
query = query.options(joinedload(ForecastSQL.forecast_values_latest))
237241

238-
forecasts = query.all()
242+
query = query.order_by(LocationSQL.gsp_id, desc(ForecastSQL.created_utc))
243+
244+
forecasts = query.populate_existing().all()
245+
246+
logger.debug(f"Found {len(forecasts)} forecasts")
247+
if len(forecasts) > 0:
248+
logger.debug(f"The first forecast has {len(forecasts[0].forecast_values)} forecast_values")
239249

240250
return forecasts
241251

@@ -257,13 +267,14 @@ def filter_query_on_target_time(query, start_target_time, historic: bool):
257267
join_object = ForecastSQL.forecast_values
258268

259269
if start_target_time is not None:
260-
query = (
261-
query.join(join_object)
262-
.filter(forecast_value_model.target_time >= start_target_time)
263-
.options(contains_eager(join_object))
264-
.populate_existing()
270+
logger.debug(f"Filtering '{start_target_time=}'")
271+
query = query.join(join_object).filter(
272+
forecast_value_model.target_time >= start_target_time
265273
)
266274

275+
if historic:
276+
query = query.options(contains_eager(join_object)).populate_existing()
277+
267278
return query
268279

269280

@@ -347,7 +358,7 @@ def get_latest_national_forecast(
347358
return forecast
348359

349360

350-
def get_latest_forecast_created_utc(session: Session, gsp_id: int) -> datetime:
361+
def get_latest_forecast_created_utc(session: Session, gsp_id: Optional[int] = None) -> datetime:
351362
"""
352363
Get the latest forecast created utc value. Can choose for different gsps
353364
@@ -360,8 +371,9 @@ def get_latest_forecast_created_utc(session: Session, gsp_id: int) -> datetime:
360371
query = session.query(ForecastSQL.created_utc)
361372

362373
# filter on gsp_id
363-
query = query.join(LocationSQL)
364-
query = query.filter(LocationSQL.gsp_id == gsp_id)
374+
if gsp_id is not None:
375+
query = query.join(LocationSQL)
376+
query = query.filter(LocationSQL.gsp_id == gsp_id)
365377

366378
# order, so latest is at the top
367379
query = query.order_by(ForecastSQL.created_utc.desc())

tests/read/test_read.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from nowcasting_datamodel.models import (
1414
Forecast,
1515
ForecastValue,
16+
ForecastValueSQL,
1617
ForecastValueLatestSQL,
1718
InputDataLastUpdatedSQL,
1819
LocationSQL,
@@ -248,15 +249,22 @@ def test_get_all_gsp_ids_latest_forecast_filter(db_session):
248249
gsp_ids=[1, 2], session=db_session, t0_datetime_utc=datetime(2020, 1, 1)
249250
)
250251
db_session.add_all(f1)
252+
db_session.add_all(f1[0].forecast_values)
253+
db_session.commit()
254+
assert len(f1[0].forecast_values) == 2
251255

252256
start_created_utc = datetime.now() - timedelta(days=1)
253257
target_time = datetime(2020, 1, 1) - timedelta(days=1)
258+
assert len(f1[0].forecast_values) == 2
254259
forecast_values_read = get_all_gsp_ids_latest_forecast(
255260
session=db_session, start_created_utc=start_created_utc, start_target_time=target_time
256261
)
262+
assert len(db_session.query(ForecastValueSQL).all()) == 4
257263
assert len(forecast_values_read) == 2
258264
assert forecast_values_read[0] == f1[0]
259265
assert forecast_values_read[1] == f1[1]
266+
assert len(f1[0].forecast_values) == 2
267+
assert len(forecast_values_read[0].forecast_values) == 2
260268

261269

262270
def test_get_all_gsp_ids_latest_forecast_filter_historic(db_session):

0 commit comments

Comments
 (0)