9
9
from typing import List , Optional
10
10
11
11
from sqlalchemy import desc
12
- from sqlalchemy .orm import joinedload
12
+ from sqlalchemy .orm import joinedload , contains_eager
13
13
from sqlalchemy .orm .session import Session
14
14
from sqlalchemy .sql .expression import false , true
15
15
@@ -168,11 +168,13 @@ def get_latest_forecast(
168
168
else :
169
169
data_model_forecast_value = ForecastValueSQL
170
170
171
- forecast_values = session .query (data_model_forecast_value ).\
172
- filter (data_model_forecast_value .target_time >= start_target_time ).\
173
- filter (forecasts .id == data_model_forecast_value .forecast_id ). \
174
- order_by (data_model_forecast_value .target_time ).\
175
- all ()
171
+ forecast_values = (
172
+ session .query (data_model_forecast_value )
173
+ .filter (data_model_forecast_value .target_time >= start_target_time )
174
+ .filter (forecasts .id == data_model_forecast_value .forecast_id )
175
+ .order_by (data_model_forecast_value .target_time )
176
+ .all ()
177
+ )
176
178
177
179
forecasts .forecast_values_latest = forecast_values
178
180
@@ -210,22 +212,30 @@ def get_all_gsp_ids_latest_forecast(
210
212
211
213
if historic :
212
214
forecast_value_model = ForecastValueLatestSQL
215
+ join_object = ForecastSQL .forecast_values_latest
213
216
else :
214
217
forecast_value_model = ForecastValueSQL
218
+ join_object = ForecastSQL .forecast_values
215
219
216
220
# start main query
217
221
query = session .query (ForecastSQL )
218
- query = query .distinct (LocationSQL .gsp_id )
219
- query = query .join (LocationSQL )
220
- query = query .join (forecast_value_model )
221
-
222
- query = query .filter (ForecastSQL .historic == historic )
223
222
224
223
if start_created_utc is not None :
225
- query = query .filter (ForecastSQL .created_utc > start_created_utc )
224
+ query = query .filter (ForecastSQL .created_utc >= start_created_utc )
226
225
227
226
if start_target_time is not None :
228
- query = query .filter (forecast_value_model .target_time > start_target_time )
227
+ query = (
228
+ query .join (join_object )
229
+ .filter (forecast_value_model .target_time >= start_target_time )
230
+ .options (contains_eager (join_object ))
231
+ .populate_existing ()
232
+ )
233
+
234
+ # join with tables
235
+ query = query .distinct (LocationSQL .gsp_id )
236
+ query = query .join (LocationSQL )
237
+
238
+ query = query .filter (ForecastSQL .historic == historic )
229
239
230
240
query = query .order_by (LocationSQL .gsp_id , desc (ForecastSQL .created_utc ))
231
241
@@ -236,7 +246,7 @@ def get_all_gsp_ids_latest_forecast(
236
246
query = query .options (joinedload (ForecastSQL .model ))
237
247
query = query .options (joinedload (ForecastSQL .input_data_last_updated ))
238
248
239
- forecasts = query .all ()
249
+ forecasts = query .limit ( 339 ). all ()
240
250
241
251
return forecasts
242
252
0 commit comments