Skip to content

Commit 99b7633

Browse files
committed
feature: Support logging of MLflow metrics when network_isolation mode is enabled
1 parent 66d5fdf commit 99b7633

File tree

5 files changed

+637
-0
lines changed

5 files changed

+637
-0
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ dependencies = [
4848
"PyYAML~=6.0",
4949
"requests",
5050
"sagemaker-core>=1.0.0,<2.0.0",
51+
"sagemaker-mlflow",
5152
"schema",
5253
"smdebug_rulesconfig==1.0.1",
5354
"tblib>=1.7.0,<4",

src/sagemaker/estimator.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,8 @@
107107
from sagemaker.workflow.parameters import ParameterString
108108
from sagemaker.workflow.pipeline_context import PipelineSession, runnable_by_pipeline
109109

110+
from sagemaker.mlflow.forward_sagemaker_metrics import log_sagemaker_job_to_mlflow
111+
110112
logger = logging.getLogger(__name__)
111113

112114

@@ -1366,8 +1368,14 @@ def fit(
13661368
experiment_config = check_and_get_run_experiment_config(experiment_config)
13671369
self.latest_training_job = _TrainingJob.start_new(self, inputs, experiment_config)
13681370
self.jobs.append(self.latest_training_job)
1371+
forward_to_mlflow_tracking_server = False
1372+
if os.environ.get("MLFLOW_TRACKING_URI") and self.enable_network_isolation():
1373+
wait = True
1374+
forward_to_mlflow_tracking_server = True
13691375
if wait:
13701376
self.latest_training_job.wait(logs=logs)
1377+
if forward_to_mlflow_tracking_server:
1378+
log_sagemaker_job_to_mlflow(self.latest_training_job.name)
13711379

13721380
def _compilation_job_name(self):
13731381
"""Placeholder docstring"""
Lines changed: 321 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,321 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
14+
15+
import os
import platform
import re
from typing import Dict, Iterator, List, Set, Tuple

import boto3
import mlflow
from mlflow import MlflowClient
from mlflow.entities import Metric, Param, RunTag
from packaging import version
26+
27+
28+
def encode(name: str, existing_names: Set[str]) -> str:
    """
    Encode a string to comply with MLflow naming restrictions and ensure uniqueness.

    Args:
        name (str): The original string to be encoded.
        existing_names (Set[str]): Set of existing encoded names used to avoid
            collisions. Mutated in place: the returned name is added to it.

    Returns:
        str: The encoded string if changes were necessary, otherwise the original string.
    """

    def encode_char(match):
        # Replace a disallowed character with "_<two-digit hex code point>_".
        return f"_{ord(match.group(0)):02x}_"

    # MLflow 2.16.0+ on Mac/Unix additionally allows ":" in metric/param keys.
    is_unix = platform.system() != "Windows"
    mlflow_version = version.parse(mlflow.__version__)
    allow_colon = is_unix and mlflow_version >= version.parse("2.16.0")

    if allow_colon:
        pattern = r"[^\w\-./:\s]"
    else:
        pattern = r"[^\w\-./\s]"

    encoded = re.sub(pattern, encode_char, name)
    # Truncate to 240 to leave room for a "_<suffix>" disambiguator while
    # staying under MLflow's 250-character key limit.
    base_name = encoded[:240]

    # BUG FIX: always derive the final name from base_name. Previously a
    # non-colliding long name was stored truncated to 250 characters while
    # collision checks compared only the 240-character prefix, so two long
    # names sharing a 240-character prefix could silently produce duplicate
    # keys.
    encoded = base_name
    if base_name in existing_names:
        suffix = 1
        # Edge case: if even the suffixed namespace is exhausted we would
        # still override one of the keys, but that needs massive collisions.
        while f"{base_name}_{suffix}" in existing_names:
            suffix += 1
        encoded = f"{base_name}_{suffix}"

    # Max length is 250 for mlflow metric/params
    encoded = encoded[:250]

    existing_names.add(encoded)
    return encoded
68+
69+
70+
def decode(encoded_metric_name: str) -> str:
    # TODO: Utilize the stored name mappings to get the original key mappings without having to decode.
    """Reverse the encoding applied by ``encode``.

    Scans the input for "_XX_" patterns, where XX is a two-digit lowercase
    hexadecimal code, and substitutes each with the character at that code
    point, restoring the original metric name.

    Args:
        encoded_metric_name (str): The encoded metric name to be decoded.

    Returns:
        str: The decoded metric name with hexadecimal codes replaced by their
            corresponding characters.

    Example:
        >>> decode("loss_3a_val")
        "loss:val"
    """
    return re.sub(
        r"_([0-9a-f]{2})_",
        lambda match: chr(int(match.group(1), 16)),
        encoded_metric_name,
    )
100+
101+
102+
def get_training_job_details(job_arn: str) -> dict:
    """
    Retrieve details of a SageMaker training job.

    Args:
        job_arn (str): The ARN of the SageMaker training job.

    Returns:
        dict: A dictionary containing the details of the training job.

    Raises:
        boto3.exceptions.Boto3Error: If there's an issue with the AWS API call.
    """
    client = boto3.client("sagemaker")
    # The job name is the final path component of the ARN.
    name = job_arn.rsplit("/", 1)[-1]
    return client.describe_training_job(TrainingJobName=name)
118+
119+
120+
def create_metric_queries(job_arn: str, metric_definitions: list) -> list:
    """
    Create metric queries for SageMaker metrics.

    Builds one query per metric definition, requesting per-minute averages
    keyed by timestamp for the given training job.

    Args:
        job_arn (str): The ARN of the SageMaker training job.
        metric_definitions (list): List of metric definitions from the training job.

    Returns:
        list: A list of dictionaries, each representing a metric query.
    """
    return [
        {
            "MetricName": definition["Name"],
            "XAxisType": "Timestamp",
            "MetricStat": "Avg",
            "Period": "OneMinute",
            "ResourceArn": job_arn,
        }
        for definition in metric_definitions
    ]
142+
143+
144+
def get_metric_data(metric_queries: list) -> dict:
    """
    Retrieve metric data from SageMaker.

    Args:
        metric_queries (list): A list of metric queries.

    Returns:
        dict: A dictionary containing the metric data results.

    Raises:
        boto3.exceptions.Boto3Error: If there's an issue with the AWS API call.
    """
    client = boto3.client("sagemaker-metrics")
    return client.batch_get_metrics(MetricQueries=metric_queries)
160+
161+
162+
def prepare_mlflow_metrics(
    metric_queries: list, metric_results: list
) -> Tuple[List[Metric], Dict[str, str]]:
    """
    Prepare metrics for MLflow logging, encoding metric names if necessary.

    Args:
        metric_queries (list): The original metric queries sent to SageMaker.
        metric_results (list): The metric results from SageMaker batch_get_metrics.

    Returns:
        Tuple[List[Metric], Dict[str, str]]:
            - A list of Metric objects with encoded names (if necessary)
            - A mapping of encoded to original names for metrics (only for encoded metrics)
    """
    mlflow_metrics: List[Metric] = []
    metric_name_mapping: Dict[str, str] = {}
    taken_names: Set[str] = set()

    for query, result in zip(metric_queries, metric_results):
        # Skip queries SageMaker did not resolve successfully.
        if result["Status"] != "Complete":
            continue

        original_name = query["MetricName"]
        safe_name = encode(original_name, taken_names)
        metric_name_mapping[safe_name] = original_name

        # Each (timestamp, value) pair becomes one datapoint; the position
        # in the series doubles as the MLflow step.
        for step, (timestamp, value) in enumerate(
            zip(result["XAxisValues"], result["MetricValues"])
        ):
            mlflow_metrics.append(
                Metric(key=safe_name, value=value, timestamp=timestamp, step=step)
            )

    return mlflow_metrics, metric_name_mapping
197+
198+
199+
def prepare_mlflow_params(hyperparameters: Dict[str, str]) -> Tuple[List[Param], Dict[str, str]]:
    """
    Prepare hyperparameters for MLflow logging, encoding parameter names if necessary.

    Args:
        hyperparameters (Dict[str, str]): The hyperparameters from the SageMaker job.

    Returns:
        Tuple[List[Param], Dict[str, str]]:
            - A list of Param objects with encoded names (if necessary)
            - A mapping of encoded to original names for hyperparameters (only for encoded parameters)
    """
    taken_names: Set[str] = set()
    param_name_mapping: Dict[str, str] = {}
    mlflow_params: List[Param] = []

    for original_key, value in hyperparameters.items():
        safe_key = encode(original_key, taken_names)
        param_name_mapping[safe_key] = original_key
        # MLflow param values are strings; stringify defensively.
        mlflow_params.append(Param(safe_key, str(value)))

    return mlflow_params, param_name_mapping
221+
222+
223+
def batch_items(items: list, batch_size: int) -> Iterator[list]:
    """
    Yield successive batch_size chunks from items.

    Args:
        items (list): The list of items to be batched.
        batch_size (int): The size of each batch; must be positive.

    Yields:
        list: A batch of at most ``batch_size`` items; the final batch may
            be shorter.
    """
    # Fixed the return annotation: this is a generator function, so it
    # returns an iterator of lists, not a list.
    for start in range(0, len(items), batch_size):
        yield items[start : start + batch_size]
236+
237+
238+
def log_to_mlflow(metrics: list, params: list, tags: dict) -> None:
    """
    Log metrics, parameters, and tags to MLflow.

    Creates (or reuses) the experiment named by MLFLOW_EXPERIMENT_NAME,
    starts a new run, batch-logs everything, and terminates the run.

    Args:
        metrics (list): List of metrics to log.
        params (list): List of parameters to log.
        tags (dict): Dictionary of tags to set.

    Raises:
        mlflow.exceptions.MlflowException: If there's an issue with MLflow logging.
    """
    client = MlflowClient()

    experiment_name = os.getenv("MLFLOW_EXPERIMENT_NAME")
    if experiment_name is None or experiment_name.strip() == "":
        experiment_name = "Default"
        print("MLFLOW_EXPERIMENT_NAME not set. Using Default")

    experiment = client.get_experiment_by_name(experiment_name)
    if experiment is None:
        experiment_id = client.create_experiment(experiment_name)
    else:
        experiment_id = experiment.experiment_id

    run_id = client.create_run(experiment_id).info.run_id

    # Batch everything in chunks of 1000 to stay within API request limits.
    for metric_batch in batch_items(metrics, 1000):
        client.log_batch(run_id, metrics=metric_batch)
    for param_batch in batch_items(params, 1000):
        client.log_batch(run_id, params=param_batch)
    for tag_batch in batch_items(list(tags.items()), 1000):
        client.log_batch(run_id, tags=[RunTag(key, str(value)) for key, value in tag_batch])

    client.set_terminated(run_id)
278+
279+
280+
def log_sagemaker_job_to_mlflow(training_job_arn: str) -> None:
    """
    Retrieve SageMaker metrics and hyperparameters and log them to MLflow.

    Args:
        training_job_arn (str): The ARN of the SageMaker training job.

    Raises:
        Exception: If there's any error during the process.
    """
    # Point the MLflow client at the configured tracking server.
    mlflow.set_tracking_uri(os.getenv("MLFLOW_TRACKING_URI"))

    # Pull the job description, then its hyperparameters and metric definitions.
    details = get_training_job_details(training_job_arn)
    hyperparameters = details["HyperParameters"]
    metric_definitions = details["AlgorithmSpecification"]["MetricDefinitions"]

    # Query SageMaker Metrics for the raw datapoints.
    queries = create_metric_queries(details["TrainingJobArn"], metric_definitions)
    metric_data = get_metric_data(queries)

    # Convert to MLflow entities, keeping encoded->original name mappings
    # so the sanitized keys remain traceable.
    mlflow_metrics, metric_name_mapping = prepare_mlflow_metrics(
        queries, metric_data["MetricQueryResults"]
    )
    mlflow_params, param_name_mapping = prepare_mlflow_params(hyperparameters)

    mlflow_tags = {
        "training_job_arn": training_job_arn,
        "metric_name_mapping": str(metric_name_mapping),
        "param_name_mapping": str(param_name_mapping),
    }

    log_to_mlflow(mlflow_metrics, mlflow_params, mlflow_tags)
    print(f"Logged {len(mlflow_metrics)} metric datapoints to MLflow")
    print(f"Logged {len(mlflow_params)} hyperparameters to MLflow")

0 commit comments

Comments
 (0)