Skip to content

Commit eb2cc11

Browse files
committed
Fix pylint + doc check
1 parent 99b7633 commit eb2cc11

File tree

2 files changed

+31
-41
lines changed

2 files changed

+31
-41
lines changed

src/sagemaker/mlflow/forward_sagemaker_metrics.py

Lines changed: 27 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -11,23 +11,24 @@
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
1313

14+
"""This module contains code related to forwarding SageMaker TrainingJob Metrics to MLflow."""
15+
16+
from __future__ import absolute_import
1417

15-
import boto3
1618
import os
19+
import platform
20+
import re
21+
from typing import Set, Tuple, List, Dict, Generator
22+
import boto3
1723
import mlflow
1824
from mlflow import MlflowClient
1925
from mlflow.entities import Metric, Param, RunTag
2026

2127
from packaging import version
22-
import platform
23-
import re
24-
25-
from typing import Set, Tuple, List, Dict
2628

2729

2830
def encode(name: str, existing_names: Set[str]) -> str:
29-
"""
30-
Encode a string to comply with MLflow naming restrictions and ensure uniqueness.
31+
"""Encode a string to comply with MLflow naming restrictions and ensure uniqueness.
3132
3233
Args:
3334
name (str): The original string to be encoded.
@@ -55,7 +56,8 @@ def encode_char(match):
5556

5657
if base_name in existing_names:
5758
suffix = 1
58-
# Edge case where even with suffix space there is a collision we will override one of the keys.
59+
# Edge case where even with suffix space there is a collision
60+
# we will override one of the keys.
5961
while f"{base_name}_{suffix}" in existing_names:
6062
suffix += 1
6163
encoded = f"{base_name}_{suffix}"
@@ -68,9 +70,7 @@ def encode_char(match):
6870

6971

7072
def decode(encoded_metric_name: str) -> str:
71-
72-
# TODO: Utilize the stored name mappings to get the original key mappings without having to decode.
73-
"""Decodes an encoded metric name by replacing hexadecimal representations with their corresponding characters.
73+
"""Decodes an encoded metric name by replacing hexadecimal representations with ASCII
7474
7575
This function reverses the encoding process by converting hexadecimal codes
7676
back to their original characters. It looks for patterns of the form "_XX_"
@@ -100,8 +100,7 @@ def replace_code(match):
100100

101101

102102
def get_training_job_details(job_arn: str) -> dict:
103-
"""
104-
Retrieve details of a SageMaker training job.
103+
"""Retrieve details of a SageMaker training job.
105104
106105
Args:
107106
job_arn (str): The ARN of the SageMaker training job.
@@ -118,8 +117,7 @@ def get_training_job_details(job_arn: str) -> dict:
118117

119118

120119
def create_metric_queries(job_arn: str, metric_definitions: list) -> list:
121-
"""
122-
Create metric queries for SageMaker metrics.
120+
"""Create metric queries for SageMaker metrics.
123121
124122
Args:
125123
job_arn (str): The ARN of the SageMaker training job.
@@ -142,8 +140,7 @@ def create_metric_queries(job_arn: str, metric_definitions: list) -> list:
142140

143141

144142
def get_metric_data(metric_queries: list) -> dict:
145-
"""
146-
Retrieve metric data from SageMaker.
143+
"""Retrieve metric data from SageMaker.
147144
148145
Args:
149146
metric_queries (list): A list of metric queries.
@@ -162,8 +159,7 @@ def get_metric_data(metric_queries: list) -> dict:
162159
def prepare_mlflow_metrics(
163160
metric_queries: list, metric_results: list
164161
) -> Tuple[List[Metric], Dict[str, str]]:
165-
"""
166-
Prepare metrics for MLflow logging, encoding metric names if necessary.
162+
"""Prepare metrics for MLflow logging, encoding metric names if necessary.
167163
168164
Args:
169165
metric_queries (list): The original metric queries sent to SageMaker.
@@ -184,29 +180,26 @@ def prepare_mlflow_metrics(
184180
encoded_name = encode(metric_name, existing_names)
185181
metric_name_mapping[encoded_name] = metric_name
186182

187-
mlflow_metrics.extend(
188-
[
189-
Metric(key=encoded_name, value=value, timestamp=timestamp, step=step)
190-
for step, (timestamp, value) in enumerate(
191-
zip(result["XAxisValues"], result["MetricValues"])
192-
)
193-
]
194-
)
183+
for step, (timestamp, value) in enumerate(
184+
zip(result["XAxisValues"], result["MetricValues"])
185+
):
186+
metric = Metric(key=encoded_name, value=value, timestamp=timestamp, step=step)
187+
mlflow_metrics.append(metric)
195188

196189
return mlflow_metrics, metric_name_mapping
197190

198191

199192
def prepare_mlflow_params(hyperparameters: Dict[str, str]) -> Tuple[List[Param], Dict[str, str]]:
200-
"""
201-
Prepare hyperparameters for MLflow logging, encoding parameter names if necessary.
193+
"""Prepare hyperparameters for MLflow logging, encoding parameter names if necessary.
202194
203195
Args:
204196
hyperparameters (Dict[str, str]): The hyperparameters from the SageMaker job.
205197
206198
Returns:
207199
Tuple[List[Param], Dict[str, str]]:
208200
- A list of Param objects with encoded names (if necessary)
209-
- A mapping of encoded to original names for hyperparameters (only for encoded parameters)
201+
- A mapping of encoded to original names for
202+
hyperparameters (only for encoded parameters)
210203
"""
211204
mlflow_params = []
212205
param_name_mapping = {}
@@ -220,9 +213,8 @@ def prepare_mlflow_params(hyperparameters: Dict[str, str]) -> Tuple[List[Param],
220213
return mlflow_params, param_name_mapping
221214

222215

223-
def batch_items(items: list, batch_size: int) -> list:
224-
"""
225-
Yield successive batch_size chunks from items.
216+
def batch_items(items: list, batch_size: int) -> Generator:
217+
"""Yield successive batch_size chunks from items.
226218
227219
Args:
228220
items (list): The list of items to be batched.
@@ -236,8 +228,7 @@ def batch_items(items: list, batch_size: int) -> list:
236228

237229

238230
def log_to_mlflow(metrics: list, params: list, tags: dict) -> None:
239-
"""
240-
Log metrics, parameters, and tags to MLflow.
231+
"""Log metrics, parameters, and tags to MLflow.
241232
242233
Args:
243234
metrics (list): List of metrics to log.
@@ -278,8 +269,7 @@ def log_to_mlflow(metrics: list, params: list, tags: dict) -> None:
278269

279270

280271
def log_sagemaker_job_to_mlflow(training_job_arn: str) -> None:
281-
"""
282-
Retrieve SageMaker metrics and hyperparameters and log them to MLflow.
272+
"""Retrieve SageMaker metrics and hyperparameters and log them to MLflow.
283273
284274
Args:
285275
training_job_arn (str): The ARN of the SageMaker training job.

tests/unit/sagemaker/mlflow/test_forward_sagemaker_metrics.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,13 @@
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
1313

14+
from __future__ import absolute_import
15+
from unittest.mock import patch, MagicMock, Mock
16+
import json
1417
import pytest
1518
from mlflow.entities import Metric, Param
16-
from unittest.mock import patch, MagicMock, Mock
1719
import requests
1820

19-
import json
2021

2122
from sagemaker.mlflow.forward_sagemaker_metrics import (
2223
encode,
@@ -161,10 +162,9 @@ def test_batch_items():
161162
assert batches == [[1, 2], [3, 4], [5]]
162163

163164

164-
@patch("mlflow.MlflowClient")
165165
@patch("os.getenv")
166166
@patch("requests.Session.request")
167-
def test_log_to_mlflow(mock_request, mock_getenv, mock_mlflow_client):
167+
def test_log_to_mlflow(mock_request, mock_getenv):
168168
# Set up return values for os.getenv calls
169169
def getenv_side_effect(arg, default=None):
170170
values = {

0 commit comments

Comments
 (0)