Skip to content

Commit 6f16fef

Browse files
add model to adjuster (#167)
* add model to adjuster * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add migration --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 34efef1 commit 6f16fef

File tree

6 files changed

+54
-4
lines changed

6 files changed

+54
-4
lines changed

nowcasting_datamodel/fake.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,8 @@ def make_fake_me_latest(session: Session):
280280
gsp_id=0, session=session, installed_capacity_mw=14000, label=national_gb_label
281281
)
282282

283+
model = get_model(session=session, name="fake_model", version="0.1.2")
284+
283285
session.add_all([location, datetime_interval, metric])
284286

285287
metric_values = []
@@ -294,6 +296,7 @@ def make_fake_me_latest(session: Session):
294296
datetime_interval=datetime_interval,
295297
metric=metric,
296298
location=location,
299+
model=model,
297300
)
298301
metric_values.append(m)
299302

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
"""
2+
3+
Add model_id in mertric_value
4+
5+
Revision ID: 37c68fd8e65c
6+
Revises: 489955d7a5a0
7+
Create Date: 2023-04-14 15:05:00.205258
8+
9+
"""
10+
import sqlalchemy as sa
11+
from alembic import op
12+
13+
# revision identifiers, used by Alembic.
14+
revision = "37c68fd8e65c"
15+
down_revision = "489955d7a5a0"
16+
branch_labels = None
17+
depends_on = None
18+
19+
20+
def upgrade(): # noqa
21+
# ### commands auto generated by Alembic - please adjust! ###
22+
op.add_column("metric_value", sa.Column("model_id", sa.Integer(), nullable=True))
23+
op.create_index(op.f("ix_metric_value_model_id"), "metric_value", ["model_id"], unique=False)
24+
op.create_foreign_key(None, "metric_value", "model", ["model_id"], ["id"])
25+
# ### end Alembic commands ###
26+
27+
28+
def downgrade(): # noqa
29+
# ### commands auto generated by Alembic - please adjust! ###
30+
op.drop_constraint(None, "metric_value", type_="foreignkey")
31+
op.drop_index(op.f("ix_metric_value_model_id"), table_name="metric_value")
32+
op.drop_column("metric_value", "model_id")
33+
# ### end Alembic commands ###

nowcasting_datamodel/models/metric.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,10 @@ class MetricValueSQL(Base_Forecast, CreatedMixin):
145145
datetime_interval = relationship("DatetimeIntervalSQL", back_populates="metric_value")
146146
datetime_interval_id = Column(Integer, ForeignKey("datetime_interval.id"), index=True)
147147

148+
# many (metric values) to one (model)
149+
model = relationship("MLModelSQL", back_populates="metric_value")
150+
model_id = Column(Integer, ForeignKey("model.id"), index=True)
151+
148152

149153
class MetricValue(EnhancedBaseModel):
150154
"""Location that the forecast is for"""

nowcasting_datamodel/models/models.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ class MLModelSQL(Base_Forecast):
3939
version = Column(String)
4040

4141
forecast = relationship("ForecastSQL", back_populates="model")
42+
metric_value = relationship("MetricValueSQL", back_populates="model")
4243

4344

4445
class MLModel(EnhancedBaseModel):

nowcasting_datamodel/read/read_metric.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,11 @@
55
"""
66
import logging
77
from datetime import datetime
8-
from typing import List
8+
from typing import List, Optional
99

1010
from sqlalchemy.orm.session import Session
1111

12+
from nowcasting_datamodel.models import MLModelSQL
1213
from nowcasting_datamodel.models.metric import DatetimeIntervalSQL, MetricSQL, MetricValueSQL
1314

1415
logger = logging.getLogger(__name__)
@@ -94,7 +95,7 @@ def get_datetime_interval(
9495

9596

9697
def read_latest_me_national(
97-
session: Session, metric_name: str = "Half Hourly ME"
98+
session: Session, metric_name: str = "Half Hourly ME", model_name: Optional[str] = None
9899
) -> List[MetricValueSQL]:
99100
"""
100101
Get the latest me for the national forecast
@@ -103,9 +104,12 @@ def read_latest_me_national(
103104
104105
:param session: database sessions
105106
:param metric_name: metric name, defaulted to "Half Hourly ME"
107+
:param model_name: model name, defaulted to None
106108
:return: list of MetricValueSQL for ME for each 'time_of_day' and 'forecast_horizon_minutes'
107109
"""
108110

111+
logger.debug(f"Reading latest {metric_name=} for {model_name=}")
112+
109113
# start main query
110114
query = session.query(MetricValueSQL)
111115
query = query.join(MetricSQL)
@@ -121,6 +125,10 @@ def read_latest_me_national(
121125
MetricValueSQL.created_utc.desc(),
122126
)
123127

128+
if model_name is not None:
129+
query = query.join(MLModelSQL)
130+
query = query.filter(MLModelSQL.name == model_name)
131+
124132
# get all results
125133
metric_values = query.all()
126134

nowcasting_datamodel/save/adjust.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,12 @@ def add_adjust_to_national_forecast(forecast: ForecastSQL, session):
4444
:return:
4545
"""
4646

47-
# get the target time for the first forecast
47+
# get the target time and model name
4848
datetime_now = forecast.forecast_values[0].target_time
49+
model_name = forecast.model.name
4950

5051
# 1. read metric values
51-
latest_me = read_latest_me_national(session=session)
52+
latest_me = read_latest_me_national(session=session, model_name=model_name)
5253
assert len(latest_me) > 0
5354

5455
# 2. filter value down to now onwards

0 commit comments

Comments
 (0)