Skip to content

Commit ded1d3c

Browse files
committed
add test
1 parent 4229684 commit ded1d3c

File tree

2 files changed

+49
-14
lines changed

2 files changed

+49
-14
lines changed

nowcasting_datamodel/read/read.py

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from typing import List, Optional
1010

1111
from sqlalchemy import desc
12-
from sqlalchemy.orm import joinedload
12+
from sqlalchemy.orm import joinedload, contains_eager
1313
from sqlalchemy.orm.session import Session
1414
from sqlalchemy.sql.expression import false, true
1515

@@ -168,11 +168,13 @@ def get_latest_forecast(
168168
else:
169169
data_model_forecast_value = ForecastValueSQL
170170

171-
forecast_values = session.query(data_model_forecast_value).\
172-
filter(data_model_forecast_value.target_time >= start_target_time).\
173-
filter(forecasts.id == data_model_forecast_value.forecast_id). \
174-
order_by(data_model_forecast_value.target_time).\
175-
all()
171+
forecast_values = (
172+
session.query(data_model_forecast_value)
173+
.filter(data_model_forecast_value.target_time >= start_target_time)
174+
.filter(forecasts.id == data_model_forecast_value.forecast_id)
175+
.order_by(data_model_forecast_value.target_time)
176+
.all()
177+
)
176178

177179
forecasts.forecast_values_latest = forecast_values
178180

@@ -210,22 +212,30 @@ def get_all_gsp_ids_latest_forecast(
210212

211213
if historic:
212214
forecast_value_model = ForecastValueLatestSQL
215+
join_object = ForecastSQL.forecast_values_latest
213216
else:
214217
forecast_value_model = ForecastValueSQL
218+
join_object = ForecastSQL.forecast_values
215219

216220
# start main query
217221
query = session.query(ForecastSQL)
218-
query = query.distinct(LocationSQL.gsp_id)
219-
query = query.join(LocationSQL)
220-
query = query.join(forecast_value_model)
221-
222-
query = query.filter(ForecastSQL.historic == historic)
223222

224223
if start_created_utc is not None:
225-
query = query.filter(ForecastSQL.created_utc > start_created_utc)
224+
query = query.filter(ForecastSQL.created_utc >= start_created_utc)
226225

227226
if start_target_time is not None:
228-
query = query.filter(forecast_value_model.target_time > start_target_time)
227+
query = (
228+
query.join(join_object)
229+
.filter(forecast_value_model.target_time >= start_target_time)
230+
.options(contains_eager(join_object))
231+
.populate_existing()
232+
)
233+
234+
# join with tables
235+
query = query.distinct(LocationSQL.gsp_id)
236+
query = query.join(LocationSQL)
237+
238+
query = query.filter(ForecastSQL.historic == historic)
229239

230240
query = query.order_by(LocationSQL.gsp_id, desc(ForecastSQL.created_utc))
231241

@@ -236,7 +246,7 @@ def get_all_gsp_ids_latest_forecast(
236246
query = query.options(joinedload(ForecastSQL.model))
237247
query = query.options(joinedload(ForecastSQL.input_data_last_updated))
238248

239-
forecasts = query.all()
249+
forecasts = query.limit(339).all()
240250

241251
return forecasts
242252

tests/read/test_read.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,31 @@ def test_get_all_gsp_ids_latest_forecast_filter(db_session):
256256
assert forecast_values_read[1] == f1[1]
257257

258258

259+
def test_get_all_gsp_ids_latest_forecast_filter_historic(db_session):
260+
"""Test get all historic forecast from all gsp "but filter on starttime"""
261+
262+
f1 = make_fake_forecasts(
263+
gsp_ids=[1, 2], session=db_session, t0_datetime_utc=datetime(2020, 1, 1)
264+
)
265+
f1[0].historic = True
266+
f1[0].forecast_values_latest = [
267+
ForecastValueLatestSQL(
268+
gsp_id=1, expected_power_generation_megawatts=1, target_time=datetime(2022, 1, 1)
269+
),
270+
ForecastValueLatestSQL(
271+
gsp_id=1, expected_power_generation_megawatts=1, target_time=datetime(2022, 1, 1, 0, 30)
272+
),
273+
]
274+
275+
db_session.add_all(f1)
276+
277+
target_time = datetime(2022, 1, 1, 0, 30)
278+
forecast = get_all_gsp_ids_latest_forecast(
279+
session=db_session, start_target_time=target_time, historic=True
280+
)[0]
281+
assert len(forecast.forecast_values_latest) == 1
282+
283+
259284
def test_get_national_latest_forecast(db_session):
260285

261286
f1 = make_fake_national_forecast(session=db_session)

0 commit comments

Comments
 (0)