Skip to content

Commit c3ad8da

Browse files
committed
Merge branch 'datamodel-filter'
2 parents 1b296f5 + 43118ce commit c3ad8da

File tree

2 files changed

+69
-43
lines changed

2 files changed

+69
-43
lines changed

nowcasting_datamodel/read/read.py

Lines changed: 58 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -136,54 +136,56 @@ def get_latest_forecast(
136136
return: List of forecasts objects from database
137137
"""
138138

139-
logger.debug("getting latest forecast")
139+
logger.debug(f"Getting latest forecast for gsp {gsp_id}")
140140

141141
# start main query
142142
query = session.query(ForecastSQL)
143143
order_by_items = []
144144

145145
if historic:
146146
query = query.filter(ForecastSQL.historic == true())
147-
join_object = ForecastSQL.forecast_values_latest
148-
forecast_value_model = ForecastValueLatestSQL
149147
else:
150148
query = query.filter(ForecastSQL.historic == false())
151-
join_object = ForecastSQL.forecast_values
152-
forecast_value_model = ForecastValueSQL
149+
150+
if start_target_time is not None:
151+
query = filter_query_on_target_time(query=query,
152+
start_target_time=start_target_time,
153+
historic=historic)
153154

154155
# filter on gsp_id
155156
if gsp_id is not None:
156157
query = query.join(LocationSQL)
157158
query = query.filter(LocationSQL.gsp_id == gsp_id)
158159
order_by_items.append(LocationSQL.gsp_id)
159160

160-
if start_target_time is not None:
161-
query = (
162-
query.join(join_object)
163-
.filter(forecast_value_model.target_time >= start_target_time)
164-
.options(contains_eager(join_object))
165-
.populate_existing()
166-
)
167-
168161
order_by_items.append(ForecastSQL.created_utc.desc())
169162

170163
# this make the newest ones comes to the top
171164
query = query.order_by(*order_by_items)
172165

173166
# get all results
174-
forecasts = query.first()
167+
if not historic:
168+
query = query.limit(1)
169+
forecasts = query.all()
170+
171+
if forecasts is None:
172+
return None
173+
if len(forecasts) == 0:
174+
return None
175+
176+
forecast = forecasts[0]
175177

176178
# sort list
177-
if forecasts is not None:
178-
logger.debug("sorting 'forecast_values_latest' values")
179-
if forecasts.forecast_values_latest is not None:
180-
forecasts.forecast_values_latest = sorted(
181-
forecasts.forecast_values_latest, key=lambda d: d.target_time
182-
)
179+
logger.debug(f"sorting 'forecast_values_latest' values. "
180+
f"There are {len(forecast.forecast_values_latest)}")
181+
if forecast.forecast_values_latest is not None:
182+
forecast.forecast_values_latest = sorted(
183+
forecast.forecast_values_latest, key=lambda d: d.target_time
184+
)
183185

184-
logger.debug(f"Found forecasts for gsp id: {gsp_id} {historic=} {forecasts=}")
186+
logger.debug(f"Found forecasts for gsp id: {gsp_id} {historic=} {forecast=}")
185187

186-
return forecasts
188+
return forecast
187189

188190

189191
def get_all_gsp_ids_latest_forecast(
@@ -206,12 +208,7 @@ def get_all_gsp_ids_latest_forecast(
206208
return: List of forecasts objects from database
207209
"""
208210

209-
if historic:
210-
forecast_value_model = ForecastValueLatestSQL
211-
join_object = ForecastSQL.forecast_values_latest
212-
else:
213-
forecast_value_model = ForecastValueSQL
214-
join_object = ForecastSQL.forecast_values
211+
logger.debug("Getting latest forecast for all gsps")
215212

216213
# start main query
217214
query = session.query(ForecastSQL)
@@ -220,33 +217,56 @@ def get_all_gsp_ids_latest_forecast(
220217
query = query.filter(ForecastSQL.created_utc >= start_created_utc)
221218

222219
if start_target_time is not None:
223-
query = (
224-
query.join(join_object)
225-
.filter(forecast_value_model.target_time >= start_target_time)
226-
.options(contains_eager(join_object))
227-
.populate_existing()
228-
)
220+
query = filter_query_on_target_time(query=query,
221+
start_target_time=start_target_time,
222+
historic=historic)
229223

230224
# join with tables
231-
query = query.distinct(LocationSQL.gsp_id)
225+
if not historic:
226+
query = query.distinct(LocationSQL.gsp_id)
232227
query = query.join(LocationSQL)
233228

234229
query = query.filter(ForecastSQL.historic == historic)
235230

236231
query = query.order_by(LocationSQL.gsp_id, desc(ForecastSQL.created_utc))
237232

238233
if preload_children:
239-
query = query.options(joinedload(ForecastSQL.forecast_values_latest))
240-
query = query.options(joinedload(ForecastSQL.forecast_values))
241234
query = query.options(joinedload(ForecastSQL.location))
242235
query = query.options(joinedload(ForecastSQL.model))
243236
query = query.options(joinedload(ForecastSQL.input_data_last_updated))
244237

245-
forecasts = query.limit(339).all()
238+
forecasts = query.all()
246239

247240
return forecasts
248241

249242

243+
def filter_query_on_target_time(query, start_target_time, historic: bool):
244+
"""
245+
Filter query on start target time
246+
247+
:param query: sql query
248+
:param start_target_time: datetime, target times only included after this
249+
:param historic: bool, if data is historic or latest
250+
:return: query
251+
"""
252+
if historic:
253+
forecast_value_model = ForecastValueLatestSQL
254+
join_object = ForecastSQL.forecast_values_latest
255+
else:
256+
forecast_value_model = ForecastValueSQL
257+
join_object = ForecastSQL.forecast_values
258+
259+
if start_target_time is not None:
260+
query = (
261+
query.join(join_object)
262+
.filter(forecast_value_model.target_time >= start_target_time)
263+
.options(contains_eager(join_object))
264+
.populate_existing()
265+
)
266+
267+
return query
268+
269+
250270
def get_forecast_values(
251271
session: Session,
252272
gsp_id: Optional[int] = None,

tests/read/test_read.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,10 @@ def test_read_target_time(db_session):
102102
f2 = ForecastValueLatestSQL(
103103
target_time=datetime(2022, 1, 1, 0, 30), expected_power_generation_megawatts=2, gsp_id=1
104104
)
105-
f = make_fake_forecast(gsp_id=1, session=db_session, forecast_values_latest=[f1, f2])
105+
f3 = ForecastValueLatestSQL(
106+
target_time=datetime(2022, 1, 1, 1), expected_power_generation_megawatts=3, gsp_id=1
107+
)
108+
f = make_fake_forecast(gsp_id=1, session=db_session, forecast_values_latest=[f1, f2, f3])
106109
f.historic = True
107110

108111
forecast_read = get_latest_forecast(
@@ -114,7 +117,7 @@ def test_read_target_time(db_session):
114117
assert forecast_read is not None
115118
assert forecast_read.location.gsp_id == f.location.gsp_id
116119

117-
assert len(forecast_read.forecast_values_latest) == 1
120+
assert len(forecast_read.forecast_values_latest) == 2
118121

119122

120123
def test_get_forecast_values(db_session, forecasts):
@@ -257,7 +260,7 @@ def test_get_all_gsp_ids_latest_forecast_filter(db_session):
257260

258261

259262
def test_get_all_gsp_ids_latest_forecast_filter_historic(db_session):
260-
"""Test get all historic forecast from all gsp "but filter on starttime"""
263+
"""Test get all historic forecast from all gsp but filter on starttime"""
261264

262265
f1 = make_fake_forecasts(
263266
gsp_ids=[1, 2], session=db_session, t0_datetime_utc=datetime(2020, 1, 1)
@@ -268,7 +271,10 @@ def test_get_all_gsp_ids_latest_forecast_filter_historic(db_session):
268271
gsp_id=1, expected_power_generation_megawatts=1, target_time=datetime(2022, 1, 1)
269272
),
270273
ForecastValueLatestSQL(
271-
gsp_id=1, expected_power_generation_megawatts=1, target_time=datetime(2022, 1, 1, 0, 30)
274+
gsp_id=1, expected_power_generation_megawatts=2, target_time=datetime(2022, 1, 1, 0, 30)
275+
),
276+
ForecastValueLatestSQL(
277+
gsp_id=1, expected_power_generation_megawatts=3, target_time=datetime(2022, 1, 1, 1)
272278
),
273279
]
274280

@@ -278,7 +284,7 @@ def test_get_all_gsp_ids_latest_forecast_filter_historic(db_session):
278284
forecast = get_all_gsp_ids_latest_forecast(
279285
session=db_session, start_target_time=target_time, historic=True
280286
)[0]
281-
assert len(forecast.forecast_values_latest) == 1
287+
assert len(forecast.forecast_values_latest) == 2
282288

283289

284290
def test_get_national_latest_forecast(db_session):

0 commit comments

Comments
 (0)