Skip to content

Commit 32b921f

Browse files
evakravi authored and claytonparnell committed
feat: jumpstart training metrics
1 parent 9e5dca6 commit 32b921f

File tree

7 files changed

+166
-1
lines changed

7 files changed

+166
-1
lines changed

src/sagemaker/jumpstart/artifacts.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,9 @@
1212
# language governing permissions and limitations under the License.
1313
"""This module contains functions for obtaining JumpStart ECR and S3 URIs."""
1414
from __future__ import absolute_import
15+
from copy import deepcopy
1516
import os
16-
from typing import Dict, Optional
17+
from typing import Dict, List, Optional
1718
from sagemaker import image_uris
1819
from sagemaker.jumpstart.constants import (
1920
ENV_VARIABLE_JUMPSTART_MODEL_ARTIFACT_BUCKET_OVERRIDE,
@@ -363,3 +364,32 @@ def _retrieve_default_environment_variables(
363364
for environment_variable in model_specs.inference_environment_variables:
364365
default_environment_variables[environment_variable.name] = str(environment_variable.default)
365366
return default_environment_variables
367+
368+
369+
def _retrieve_default_training_metric_definitions(
    model_id: str,
    model_version: str,
    region: Optional[str],
) -> Optional[List[Dict[str, str]]]:
    """Looks up the default training metric definitions of a JumpStart model.

    Args:
        model_id (str): ID of the JumpStart model whose default training metric
            definitions should be retrieved.
        model_version (str): Version of the JumpStart model whose default training
            metric definitions should be retrieved.
        region (Optional[str]): Region to query. ``JUMPSTART_DEFAULT_REGION_NAME``
            is used when this is ``None``.

    Returns:
        list: A deep copy of the metric definitions stored on the model specs, or
            ``None`` when the specs hold no metric definitions (missing or empty).
    """

    effective_region = JUMPSTART_DEFAULT_REGION_NAME if region is None else region

    specs = jumpstart_accessors.JumpStartModelsAccessor.get_model_specs(
        region=effective_region, model_id=model_id, version=model_version
    )

    metrics = specs.metrics
    if not metrics:
        return None

    # Hand back a copy so the caller never shares the list held by the specs object.
    return deepcopy(metrics)

src/sagemaker/jumpstart/types.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,7 @@ class JumpStartModelSpecs(JumpStartDataHolderType):
292292
"training_dependencies",
293293
"training_vulnerabilities",
294294
"deprecated",
295+
"metrics",
295296
]
296297

297298
def __init__(self, spec: Dict[str, Any]):
@@ -328,6 +329,7 @@ def from_json(self, json_obj: Dict[str, Any]) -> None:
328329
self.training_dependencies: List[str] = json_obj["training_dependencies"]
329330
self.training_vulnerabilities: List[str] = json_obj["training_vulnerabilities"]
330331
self.deprecated: bool = bool(json_obj["deprecated"])
332+
self.metrics: Optional[List[Dict[str, str]]] = json_obj.get("metrics", None)
331333

332334
if self.training_supported:
333335
self.training_ecr_specs: JumpStartECRSpecs = JumpStartECRSpecs(

src/sagemaker/metric_definitions.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Accessors to retrieve metric definition for training jobs."""
14+
15+
from __future__ import absolute_import
16+
17+
import logging
18+
from typing import Dict, Optional, List
19+
20+
from sagemaker.jumpstart import utils as jumpstart_utils
21+
from sagemaker.jumpstart import artifacts
22+
23+
logger = logging.getLogger(__name__)
24+
25+
26+
def retrieve_default(
    region: Optional[str] = None,
    model_id: Optional[str] = None,
    model_version: Optional[str] = None,
) -> Optional[List[Dict[str, str]]]:
    """Returns the default training metric definitions for the specified JumpStart model.

    Args:
        region (str): The AWS Region for which to retrieve the default training
            metric definitions. (Default: None).
        model_id (str): The ID of the model whose default training metric
            definitions are requested. (Default: None).
        model_version (str): The version of the model whose default training
            metric definitions are requested. (Default: None).

    Returns:
        list: The default metric definitions to use for the model or None.

    Raises:
        ValueError: If ``model_id`` and ``model_version`` are not accepted as
            JumpStart model input.
    """
    if jumpstart_utils.is_jumpstart_model_input(model_id, model_version):
        return artifacts._retrieve_default_training_metric_definitions(
            model_id, model_version, region
        )

    raise ValueError(
        "Must specify `model_id` and `model_version` when retrieving default training "
        "metric definitions."
    )

tests/unit/sagemaker/jumpstart/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1183,6 +1183,7 @@
11831183
"training_dependencies": [],
11841184
"training_vulnerabilities": [],
11851185
"deprecated": False,
1186+
"metrics": [{"Regex": "val_accuracy: ([0-9\\.]+)", "Name": "pytorch-ic:val-accuracy"}],
11861187
}
11871188

11881189
BASE_HEADER = {

tests/unit/sagemaker/metric_definitions/__init__.py

Whitespace-only changes.

tests/unit/sagemaker/metric_definitions/jumpstart/__init__.py

Whitespace-only changes.
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
16+
from mock.mock import patch
17+
import pytest
18+
19+
from sagemaker import metric_definitions
20+
21+
from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec
22+
23+
24+
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
def test_jumpstart_default_metric_definitions(patched_get_model_specs):
    """Checks ``metric_definitions.retrieve_default`` for valid and invalid inputs."""

    patched_get_model_specs.side_effect = get_spec_from_base_spec

    model_id = "pytorch-ic-mobilenet-v2"
    region = "us-west-2"
    expected_definitions = [
        {"Regex": "val_accuracy: ([0-9\\.]+)", "Name": "pytorch-ic:val-accuracy"}
    ]

    # Both version strings must yield the same definitions and forward the exact
    # arguments to the (patched) spec accessor exactly once.
    for version in ("*", "1.*"):
        definitions = metric_definitions.retrieve_default(
            region=region,
            model_id=model_id,
            model_version=version,
        )
        assert definitions == expected_definitions

        patched_get_model_specs.assert_called_once_with(
            region=region, model_id=model_id, version=version
        )

        patched_get_model_specs.reset_mock()

    # Unknown model ID.
    with pytest.raises(KeyError):
        metric_definitions.retrieve_default(
            region=region,
            model_id="blah",
            model_version="*",
        )

    # Unknown region.
    with pytest.raises(ValueError):
        metric_definitions.retrieve_default(
            region="mars-south-1",
            model_id=model_id,
            model_version="*",
        )

    # Model ID omitted.
    with pytest.raises(ValueError):
        metric_definitions.retrieve_default(
            model_version="*",
        )

    # Model version omitted.
    with pytest.raises(ValueError):
        metric_definitions.retrieve_default(
            model_id=model_id,
        )

0 commit comments

Comments
 (0)