11
11
# ANY KIND, either express or implied. See the License for the specific
12
12
# language governing permissions and limitations under the License.
13
13
14
+ """This module contains code related to forwarding SageMaker TrainingJob Metrics to MLflow."""
15
+
16
+ from __future__ import absolute_import
14
17
15
- import boto3
16
18
import os
19
+ import platform
20
+ import re
21
+ from typing import Set , Tuple , List , Dict , Generator
22
+ import boto3
17
23
import mlflow
18
24
from mlflow import MlflowClient
19
25
from mlflow .entities import Metric , Param , RunTag
20
26
21
27
from packaging import version
22
- import platform
23
- import re
24
-
25
- from typing import Set , Tuple , List , Dict
26
28
27
29
28
30
def encode (name : str , existing_names : Set [str ]) -> str :
29
- """
30
- Encode a string to comply with MLflow naming restrictions and ensure uniqueness.
31
+ """Encode a string to comply with MLflow naming restrictions and ensure uniqueness.
31
32
32
33
Args:
33
34
name (str): The original string to be encoded.
@@ -55,7 +56,8 @@ def encode_char(match):
55
56
56
57
if base_name in existing_names :
57
58
suffix = 1
58
- # Edge case where even with suffix space there is a collision we will override one of the keys.
59
+ # Edge case: if a collision remains even within the suffix space,
60
+ # we will override one of the keys.
59
61
while f"{ base_name } _{ suffix } " in existing_names :
60
62
suffix += 1
61
63
encoded = f"{ base_name } _{ suffix } "
@@ -68,9 +70,7 @@ def encode_char(match):
68
70
69
71
70
72
def decode (encoded_metric_name : str ) -> str :
71
-
72
- # TODO: Utilize the stored name mappings to get the original key mappings without having to decode.
73
- """Decodes an encoded metric name by replacing hexadecimal representations with their corresponding characters.
73
+ """Decodes an encoded metric name by replacing hexadecimal representations with ASCII
74
74
75
75
This function reverses the encoding process by converting hexadecimal codes
76
76
back to their original characters. It looks for patterns of the form "_XX_"
@@ -100,8 +100,7 @@ def replace_code(match):
100
100
101
101
102
102
def get_training_job_details (job_arn : str ) -> dict :
103
- """
104
- Retrieve details of a SageMaker training job.
103
+ """Retrieve details of a SageMaker training job.
105
104
106
105
Args:
107
106
job_arn (str): The ARN of the SageMaker training job.
@@ -118,8 +117,7 @@ def get_training_job_details(job_arn: str) -> dict:
118
117
119
118
120
119
def create_metric_queries (job_arn : str , metric_definitions : list ) -> list :
121
- """
122
- Create metric queries for SageMaker metrics.
120
+ """Create metric queries for SageMaker metrics.
123
121
124
122
Args:
125
123
job_arn (str): The ARN of the SageMaker training job.
@@ -142,8 +140,7 @@ def create_metric_queries(job_arn: str, metric_definitions: list) -> list:
142
140
143
141
144
142
def get_metric_data (metric_queries : list ) -> dict :
145
- """
146
- Retrieve metric data from SageMaker.
143
+ """Retrieve metric data from SageMaker.
147
144
148
145
Args:
149
146
metric_queries (list): A list of metric queries.
@@ -162,8 +159,7 @@ def get_metric_data(metric_queries: list) -> dict:
162
159
def prepare_mlflow_metrics (
163
160
metric_queries : list , metric_results : list
164
161
) -> Tuple [List [Metric ], Dict [str , str ]]:
165
- """
166
- Prepare metrics for MLflow logging, encoding metric names if necessary.
162
+ """Prepare metrics for MLflow logging, encoding metric names if necessary.
167
163
168
164
Args:
169
165
metric_queries (list): The original metric queries sent to SageMaker.
@@ -184,29 +180,26 @@ def prepare_mlflow_metrics(
184
180
encoded_name = encode (metric_name , existing_names )
185
181
metric_name_mapping [encoded_name ] = metric_name
186
182
187
- mlflow_metrics .extend (
188
- [
189
- Metric (key = encoded_name , value = value , timestamp = timestamp , step = step )
190
- for step , (timestamp , value ) in enumerate (
191
- zip (result ["XAxisValues" ], result ["MetricValues" ])
192
- )
193
- ]
194
- )
183
+ for step , (timestamp , value ) in enumerate (
184
+ zip (result ["XAxisValues" ], result ["MetricValues" ])
185
+ ):
186
+ metric = Metric (key = encoded_name , value = value , timestamp = timestamp , step = step )
187
+ mlflow_metrics .append (metric )
195
188
196
189
return mlflow_metrics , metric_name_mapping
197
190
198
191
199
192
def prepare_mlflow_params (hyperparameters : Dict [str , str ]) -> Tuple [List [Param ], Dict [str , str ]]:
200
- """
201
- Prepare hyperparameters for MLflow logging, encoding parameter names if necessary.
193
+ """Prepare hyperparameters for MLflow logging, encoding parameter names if necessary.
202
194
203
195
Args:
204
196
hyperparameters (Dict[str, str]): The hyperparameters from the SageMaker job.
205
197
206
198
Returns:
207
199
Tuple[List[Param], Dict[str, str]]:
208
200
- A list of Param objects with encoded names (if necessary)
209
- - A mapping of encoded to original names for hyperparameters (only for encoded parameters)
201
+ - A mapping of encoded to original names for
202
+ hyperparameters (only for encoded parameters)
210
203
"""
211
204
mlflow_params = []
212
205
param_name_mapping = {}
@@ -220,9 +213,8 @@ def prepare_mlflow_params(hyperparameters: Dict[str, str]) -> Tuple[List[Param],
220
213
return mlflow_params , param_name_mapping
221
214
222
215
223
- def batch_items (items : list , batch_size : int ) -> list :
224
- """
225
- Yield successive batch_size chunks from items.
216
+ def batch_items (items : list , batch_size : int ) -> Generator :
217
+ """Yield successive batch_size chunks from items.
226
218
227
219
Args:
228
220
items (list): The list of items to be batched.
@@ -236,8 +228,7 @@ def batch_items(items: list, batch_size: int) -> list:
236
228
237
229
238
230
def log_to_mlflow (metrics : list , params : list , tags : dict ) -> None :
239
- """
240
- Log metrics, parameters, and tags to MLflow.
231
+ """Log metrics, parameters, and tags to MLflow.
241
232
242
233
Args:
243
234
metrics (list): List of metrics to log.
@@ -278,8 +269,7 @@ def log_to_mlflow(metrics: list, params: list, tags: dict) -> None:
278
269
279
270
280
271
def log_sagemaker_job_to_mlflow (training_job_arn : str ) -> None :
281
- """
282
- Retrieve SageMaker metrics and hyperparameters and log them to MLflow.
272
+ """Retrieve SageMaker metrics and hyperparameters and log them to MLflow.
283
273
284
274
Args:
285
275
training_job_arn (str): The ARN of the SageMaker training job.
0 commit comments