Skip to content

Commit 122d5a7

Browse files
committed
fix: jumpstart amt tracking
1 parent 8d84618 commit 122d5a7

File tree

2 files changed: +110 additions, −0 deletions

src/sagemaker/tuner.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from sagemaker.estimator import Framework
3333
from sagemaker.inputs import TrainingInput
3434
from sagemaker.job import _Job
35+
from sagemaker.jumpstart.utils import add_jumpstart_tags, get_jumpstart_base_name_if_jumpstart_model
3536
from sagemaker.parameter import (
3637
CategoricalParameter,
3738
ContinuousParameter,
@@ -318,6 +319,44 @@ def _prepare_for_tuning(self, job_name=None, include_cls_metadata=False):
318319
"""Prepare the tuner instance for tuning (fit)."""
319320
self._prepare_job_name_for_tuning(job_name=job_name)
320321
self._prepare_static_hyperparameters_for_tuning(include_cls_metadata=include_cls_metadata)
322+
self._prepare_tags_for_tuning()
323+
324+
def _get_model_uri(
325+
self,
326+
estimator,
327+
):
328+
"""Get model uri attribute of ``Estimator`` object.
329+
330+
This attribute can live in multiple places, and accessing the attribute can
331+
raise a TypeError, which needs to be handled.
332+
"""
333+
try:
334+
model_data = getattr(estimator, "model_data", None)
335+
except TypeError:
336+
model_data = None
337+
return model_data or getattr(estimator, "model_uri", None)
338+
339+
def _prepare_tags_for_tuning(self):
    """Add tags to the tuning job (from Estimator and JumpStart tags).

    Tags attached to the underlying estimator are merged into ``self.tags``
    without introducing duplicates, and then ``add_jumpstart_tags`` appends
    JumpStart-specific tags derived from the training script / model URIs.
    """

    # Add tags from the Estimator class. For multi-estimator tuners, pick the
    # estimator whose dict key sorts first (same choice as job-name derivation).
    estimator = self.estimator or self.estimator_dict[sorted(self.estimator_dict.keys())[0]]

    estimator_tags = getattr(estimator, "tags", []) or []

    # Only materialize an empty list when there is something to merge, so a
    # tuner with no tags anywhere keeps ``self.tags is None``.
    if self.tags is None and estimator_tags:
        self.tags = []

    for tag in estimator_tags:
        if tag not in self.tags:
            self.tags.append(tag)

    # Add JumpStart tags (expected to be a no-op when neither URI points at a
    # JumpStart bucket — TODO confirm against add_jumpstart_tags).
    self.tags = add_jumpstart_tags(
        tags=self.tags,
        training_script_uri=getattr(estimator, "source_dir", None),
        training_model_uri=self._get_model_uri(estimator),
    )
321360

322361
def _prepare_job_name_for_tuning(self, job_name=None):
323362
"""Set current job name before starting tuning."""
@@ -330,6 +369,12 @@ def _prepare_job_name_for_tuning(self, job_name=None):
330369
self.estimator or self.estimator_dict[sorted(self.estimator_dict.keys())[0]]
331370
)
332371
base_name = base_name_from_image(estimator.training_image_uri())
372+
373+
jumpstart_base_name = get_jumpstart_base_name_if_jumpstart_model(
374+
getattr(estimator, "source_dir", None),
375+
self._get_model_uri(estimator),
376+
)
377+
base_name = jumpstart_base_name or base_name
333378
self._current_job_name = name_from_base(
334379
base_name, max_length=self.TUNING_JOB_NAME_MAX_LENGTH, short=True
335380
)

tests/unit/test_tuner.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@
2222
from sagemaker import Predictor, TrainingInput, utils
2323
from sagemaker.amazon.amazon_estimator import RecordSet
2424
from sagemaker.estimator import Framework
25+
from sagemaker.fw_utils import UploadedCode
26+
from sagemaker.jumpstart.constants import JUMPSTART_BUCKET_NAME_SET, JUMPSTART_RESOURCE_BASE_NAME
27+
from sagemaker.jumpstart.enums import JumpStartTag
2528
from sagemaker.mxnet import MXNet
2629
from sagemaker.parameter import ParameterRange
2730
from sagemaker.tuner import (
@@ -1518,3 +1521,65 @@ def _convert_tuning_job_details(job_details, estimator_name):
15181521
job_details_copy["TrainingJobDefinitions"] = [training_details]
15191522

15201523
return job_details_copy
1524+
1525+
1526+
@patch("time.time", return_value=510006209.073025)
@patch("sagemaker.estimator.tar_and_upload_dir")
@patch("sagemaker.model.Model._upload_code")
def test_tags_prefixes_jumpstart_models(
    patched_upload_code, patched_tar_and_upload_dir, sagemaker_session
):
    """Tuning jobs for JumpStart estimators get JumpStart tags and job-name prefix.

    NOTE(review): three stacked ``@patch`` decorators inject three mocks
    (bottom-up), but only two mock parameters are declared — presumably
    ``sagemaker_session`` ends up bound to the ``time.time`` mock rather than a
    pytest fixture, and the assertions below work because it is a MagicMock
    recording ``create_tuning_job`` calls. Verify before relying on the fixture.
    """

    # Avoid real S3 uploads when the estimator packages its source dir.
    patched_tar_and_upload_dir.return_value = UploadedCode(
        s3_prefix="s3://%s/%s" % ("bucket", "key"), script_name="script_name"
    )
    sagemaker_session.boto_region_name = REGION

    instance_type = "ml.p2.xlarge"
    instance_count = 1

    training_data_uri = "s3://bucket/mydata"

    # Two distinct JumpStart-hosted URIs: one for the training script, one for
    # the model artifact. (Indexing a set is order-unstable across runs, but
    # any two distinct bucket names suffice here.)
    jumpstart_source_dir = f"s3://{list(JUMPSTART_BUCKET_NAME_SET)[0]}/source_dirs/source.tar.gz"
    jumpstart_source_dir_2 = f"s3://{list(JUMPSTART_BUCKET_NAME_SET)[1]}/source_dirs/source.tar.gz"

    image_uri = "fake-image-uri"

    generic_estimator = Estimator(
        entry_point="transfer_learning.py",
        role=ROLE,
        region=REGION,
        sagemaker_session=sagemaker_session,
        instance_count=instance_count,
        instance_type=instance_type,
        source_dir=jumpstart_source_dir,
        image_uri=image_uri,
        model_uri=jumpstart_source_dir_2,
        tags=[{"Key": "estimator-tag-key", "Value": "estimator-tag-value"}],
    )

    hp_tuner = HyperparameterTuner(
        generic_estimator,
        OBJECTIVE_METRIC_NAME,
        HYPERPARAMETER_RANGES,
        tags=[{"Key": "hp-tuner-tag-key", "Value": "hp-tuner-estimator-tag-value"}],
    )

    hp_tuner.fit({"training": training_data_uri})

    # Expected merge order: tuner tags first, then estimator tags, then the
    # JumpStart model-URI and script-URI tags appended by _prepare_tags_for_tuning.
    assert [
        {"Key": "hp-tuner-tag-key", "Value": "hp-tuner-estimator-tag-value"},
        {"Key": "estimator-tag-key", "Value": "estimator-tag-value"},
        {
            "Key": JumpStartTag.TRAINING_MODEL_URI.value,
            "Value": jumpstart_source_dir_2,
        },
        {
            "Key": JumpStartTag.TRAINING_SCRIPT_URI.value,
            "Value": jumpstart_source_dir,
        },
    ] == sagemaker_session.create_tuning_job.call_args_list[0][1]["tags"]

    # JumpStart estimators also override the tuning-job base name.
    assert sagemaker_session.create_tuning_job.call_args_list[0][1]["job_name"].startswith(
        JUMPSTART_RESOURCE_BASE_NAME
    )

0 commit comments

Comments
 (0)