Skip to content

Commit ef986c5

Browse files
authored
fix: jumpstart amt tracking (#3077)
1 parent c53bbbc commit ef986c5

File tree

2 files changed

+299
-0
lines changed

2 files changed

+299
-0
lines changed

src/sagemaker/tuner.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from sagemaker.estimator import Framework
3333
from sagemaker.inputs import TrainingInput
3434
from sagemaker.job import _Job
35+
from sagemaker.jumpstart.utils import add_jumpstart_tags, get_jumpstart_base_name_if_jumpstart_model
3536
from sagemaker.parameter import (
3637
CategoricalParameter,
3738
ContinuousParameter,
@@ -319,6 +320,42 @@ def _prepare_for_tuning(self, job_name=None, include_cls_metadata=False):
319320
"""Prepare the tuner instance for tuning (fit)."""
320321
self._prepare_job_name_for_tuning(job_name=job_name)
321322
self._prepare_static_hyperparameters_for_tuning(include_cls_metadata=include_cls_metadata)
323+
self._prepare_tags_for_tuning()
324+
325+
def _get_model_uri(
326+
self,
327+
estimator,
328+
):
329+
"""Return the model artifact URI used by the Estimator instance.
330+
331+
This attribute can live in multiple places, and accessing the attribute can
332+
raise a TypeError, which needs to be handled.
333+
"""
334+
try:
335+
return getattr(estimator, "model_data", None)
336+
except TypeError:
337+
return getattr(estimator, "model_uri", None)
338+
339+
def _prepare_tags_for_tuning(self):
340+
"""Add tags to tuning job (from Estimator and JumpStart tags)."""
341+
342+
# Add tags from Estimator class
343+
estimator = self.estimator or self.estimator_dict[sorted(self.estimator_dict.keys())[0]]
344+
345+
estimator_tags = getattr(estimator, "tags", []) or []
346+
347+
if self.tags is None and len(estimator_tags) > 0:
348+
self.tags = []
349+
350+
for tag in estimator_tags:
351+
if tag not in self.tags:
352+
self.tags.append(tag)
353+
354+
self.tags = add_jumpstart_tags(
355+
tags=self.tags,
356+
training_script_uri=getattr(estimator, "source_dir", None),
357+
training_model_uri=self._get_model_uri(estimator),
358+
)
322359

323360
def _prepare_job_name_for_tuning(self, job_name=None):
324361
"""Set current job name before starting tuning."""
@@ -331,6 +368,12 @@ def _prepare_job_name_for_tuning(self, job_name=None):
331368
self.estimator or self.estimator_dict[sorted(self.estimator_dict.keys())[0]]
332369
)
333370
base_name = base_name_from_image(estimator.training_image_uri())
371+
372+
jumpstart_base_name = get_jumpstart_base_name_if_jumpstart_model(
373+
getattr(estimator, "source_dir", None),
374+
self._get_model_uri(estimator),
375+
)
376+
base_name = jumpstart_base_name or base_name
334377
self._current_job_name = name_from_base(
335378
base_name, max_length=self.TUNING_JOB_NAME_MAX_LENGTH, short=True
336379
)

tests/unit/test_tuner.py

Lines changed: 256 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@
2222
from sagemaker import Predictor, TrainingInput, utils
2323
from sagemaker.amazon.amazon_estimator import RecordSet
2424
from sagemaker.estimator import Framework
25+
from sagemaker.fw_utils import UploadedCode
26+
from sagemaker.jumpstart.constants import JUMPSTART_BUCKET_NAME_SET, JUMPSTART_RESOURCE_BASE_NAME
27+
from sagemaker.jumpstart.enums import JumpStartTag
2528
from sagemaker.mxnet import MXNet
2629
from sagemaker.parameter import ParameterRange
2730
from sagemaker.tuner import (
@@ -1518,3 +1521,256 @@ def _convert_tuning_job_details(job_details, estimator_name):
15181521
job_details_copy["TrainingJobDefinitions"] = [training_details]
15191522

15201523
return job_details_copy
1524+
1525+
1526+
@patch("time.time", return_value=510006209.073025)
1527+
@patch("sagemaker.estimator.tar_and_upload_dir")
1528+
@patch("sagemaker.model.Model._upload_code")
1529+
def test_tags_prefixes_jumpstart_models(
1530+
patched_upload_code, patched_tar_and_upload_dir, sagemaker_session
1531+
):
1532+
1533+
jumpstart_source_dir = f"s3://{list(JUMPSTART_BUCKET_NAME_SET)[0]}/source_dirs/source.tar.gz"
1534+
jumpstart_source_dir_2 = f"s3://{list(JUMPSTART_BUCKET_NAME_SET)[1]}/source_dirs/source.tar.gz"
1535+
jumpstart_source_dir_3 = f"s3://{list(JUMPSTART_BUCKET_NAME_SET)[2]}/source_dirs/source.tar.gz"
1536+
1537+
estimator_tag = {"Key": "estimator-tag-key", "Value": "estimator-tag-value"}
1538+
hp_tag = {"Key": "hp-tuner-tag-key", "Value": "hp-tuner-estimator-tag-value"}
1539+
training_model_uri_tag = {
1540+
"Key": JumpStartTag.TRAINING_MODEL_URI.value,
1541+
"Value": jumpstart_source_dir_2,
1542+
}
1543+
training_script_uri_tag = {
1544+
"Key": JumpStartTag.TRAINING_SCRIPT_URI.value,
1545+
"Value": jumpstart_source_dir,
1546+
}
1547+
inference_script_uri_tag = {
1548+
"Key": JumpStartTag.INFERENCE_SCRIPT_URI.value,
1549+
"Value": jumpstart_source_dir_3,
1550+
}
1551+
1552+
patched_tar_and_upload_dir.return_value = UploadedCode(
1553+
s3_prefix="s3://%s/%s" % ("bucket", "key"), script_name="script_name"
1554+
)
1555+
sagemaker_session.boto_region_name = REGION
1556+
1557+
sagemaker_session.sagemaker_client.describe_training_job.return_value = {
1558+
"AlgorithmSpecification": {
1559+
"TrainingInputMode": "File",
1560+
"TrainingImage": "1.dkr.ecr.us-west-2.amazonaws.com/sagemaker-other:1.0.4",
1561+
},
1562+
"HyperParameters": {
1563+
"sagemaker_submit_directory": '"s3://some/sourcedir.tar.gz"',
1564+
"checkpoint_path": '"s3://other/1508872349"',
1565+
"sagemaker_program": '"iris-dnn-classifier.py"',
1566+
"sagemaker_container_log_level": '"logging.INFO"',
1567+
"sagemaker_job_name": '"neo"',
1568+
"training_steps": "100",
1569+
},
1570+
"RoleArn": "arn:aws:iam::366:role/SageMakerRole",
1571+
"ResourceConfig": {
1572+
"VolumeSizeInGB": 30,
1573+
"InstanceCount": 1,
1574+
"InstanceType": "ml.c4.xlarge",
1575+
},
1576+
"EnableNetworkIsolation": False,
1577+
"StoppingCondition": {"MaxRuntimeInSeconds": 24 * 60 * 60},
1578+
"TrainingJobName": "neo",
1579+
"TrainingJobStatus": "Completed",
1580+
"TrainingJobArn": "arn:aws:sagemaker:us-west-2:336:training-job/neo",
1581+
"OutputDataConfig": {"KmsKeyId": "", "S3OutputPath": "s3://place/output/neo"},
1582+
"TrainingJobOutput": {"S3TrainingJobOutput": "s3://here/output.tar.gz"},
1583+
"EnableInterContainerTrafficEncryption": False,
1584+
"ModelArtifacts": {"S3ModelArtifacts": "blah"},
1585+
}
1586+
1587+
sagemaker_session.sagemaker_client.list_tags.return_value = {
1588+
"Tags": [
1589+
estimator_tag,
1590+
hp_tag,
1591+
training_model_uri_tag,
1592+
training_script_uri_tag,
1593+
]
1594+
}
1595+
1596+
instance_type = "ml.p2.xlarge"
1597+
instance_count = 1
1598+
1599+
training_data_uri = "s3://bucket/mydata"
1600+
1601+
image_uri = "fake-image-uri"
1602+
1603+
generic_estimator = Estimator(
1604+
entry_point="transfer_learning.py",
1605+
role=ROLE,
1606+
region=REGION,
1607+
sagemaker_session=sagemaker_session,
1608+
instance_count=instance_count,
1609+
instance_type=instance_type,
1610+
source_dir=jumpstart_source_dir,
1611+
image_uri=image_uri,
1612+
model_uri=jumpstart_source_dir_2,
1613+
tags=[estimator_tag],
1614+
)
1615+
1616+
hp_tuner = HyperparameterTuner(
1617+
generic_estimator,
1618+
OBJECTIVE_METRIC_NAME,
1619+
HYPERPARAMETER_RANGES,
1620+
tags=[hp_tag],
1621+
)
1622+
1623+
hp_tuner.fit({"training": training_data_uri})
1624+
1625+
assert [
1626+
hp_tag,
1627+
estimator_tag,
1628+
training_model_uri_tag,
1629+
training_script_uri_tag,
1630+
] == sagemaker_session.create_tuning_job.call_args_list[0][1]["tags"]
1631+
1632+
assert sagemaker_session.create_tuning_job.call_args_list[0][1]["job_name"].startswith(
1633+
JUMPSTART_RESOURCE_BASE_NAME
1634+
)
1635+
1636+
hp_tuner.deploy(
1637+
initial_instance_count=INSTANCE_COUNT,
1638+
instance_type=INSTANCE_TYPE,
1639+
image_uri=image_uri,
1640+
source_dir=jumpstart_source_dir_3,
1641+
entry_point="inference.py",
1642+
role=ROLE,
1643+
)
1644+
1645+
assert sagemaker_session.create_model.call_args_list[0][1]["name"].startswith(
1646+
JUMPSTART_RESOURCE_BASE_NAME
1647+
)
1648+
1649+
assert sagemaker_session.endpoint_from_production_variants.call_args_list[0].startswith(
1650+
JUMPSTART_RESOURCE_BASE_NAME
1651+
)
1652+
1653+
assert sagemaker_session.create_model.call_args_list[0][1]["tags"] == [
1654+
training_model_uri_tag,
1655+
training_script_uri_tag,
1656+
inference_script_uri_tag,
1657+
]
1658+
1659+
assert sagemaker_session.endpoint_from_production_variants.call_args_list[0][1]["tags"] == [
1660+
training_model_uri_tag,
1661+
training_script_uri_tag,
1662+
inference_script_uri_tag,
1663+
]
1664+
1665+
1666+
@patch("time.time", return_value=510006209.073025)
1667+
@patch("sagemaker.estimator.tar_and_upload_dir")
1668+
@patch("sagemaker.model.Model._upload_code")
1669+
def test_no_tags_prefixes_non_jumpstart_models(
1670+
patched_upload_code, patched_tar_and_upload_dir, sagemaker_session
1671+
):
1672+
1673+
non_jumpstart_source_dir = "s3://blah1/source_dirs/source.tar.gz"
1674+
non_jumpstart_source_dir_2 = "s3://blah2/source_dirs/source.tar.gz"
1675+
non_jumpstart_source_dir_3 = "s3://blah3/source_dirs/source.tar.gz"
1676+
1677+
estimator_tag = {"Key": "estimator-tag-key", "Value": "estimator-tag-value"}
1678+
hp_tag = {"Key": "hp-tuner-tag-key", "Value": "hp-tuner-estimator-tag-value"}
1679+
1680+
patched_tar_and_upload_dir.return_value = UploadedCode(
1681+
s3_prefix="s3://%s/%s" % ("bucket", "key"), script_name="script_name"
1682+
)
1683+
sagemaker_session.boto_region_name = REGION
1684+
1685+
sagemaker_session.sagemaker_client.describe_training_job.return_value = {
1686+
"AlgorithmSpecification": {
1687+
"TrainingInputMode": "File",
1688+
"TrainingImage": "1.dkr.ecr.us-west-2.amazonaws.com/sagemaker-other:1.0.4",
1689+
},
1690+
"HyperParameters": {
1691+
"sagemaker_submit_directory": '"s3://some/sourcedir.tar.gz"',
1692+
"checkpoint_path": '"s3://other/1508872349"',
1693+
"sagemaker_program": '"iris-dnn-classifier.py"',
1694+
"sagemaker_container_log_level": '"logging.INFO"',
1695+
"sagemaker_job_name": '"neo"',
1696+
"training_steps": "100",
1697+
},
1698+
"RoleArn": "arn:aws:iam::366:role/SageMakerRole",
1699+
"ResourceConfig": {
1700+
"VolumeSizeInGB": 30,
1701+
"InstanceCount": 1,
1702+
"InstanceType": "ml.c4.xlarge",
1703+
},
1704+
"EnableNetworkIsolation": False,
1705+
"StoppingCondition": {"MaxRuntimeInSeconds": 24 * 60 * 60},
1706+
"TrainingJobName": "neo",
1707+
"TrainingJobStatus": "Completed",
1708+
"TrainingJobArn": "arn:aws:sagemaker:us-west-2:336:training-job/neo",
1709+
"OutputDataConfig": {"KmsKeyId": "", "S3OutputPath": "s3://place/output/neo"},
1710+
"TrainingJobOutput": {"S3TrainingJobOutput": "s3://here/output.tar.gz"},
1711+
"EnableInterContainerTrafficEncryption": False,
1712+
"ModelArtifacts": {"S3ModelArtifacts": "blah"},
1713+
}
1714+
1715+
sagemaker_session.sagemaker_client.list_tags.return_value = {"Tags": []}
1716+
1717+
sagemaker_session.sagemaker_client.describe_hyper_parameter_tuning_job.return_value = {
1718+
"BestTrainingJob": {"TrainingJobName": "some-name"}
1719+
}
1720+
instance_type = "ml.p2.xlarge"
1721+
instance_count = 1
1722+
1723+
training_data_uri = "s3://bucket/mydata"
1724+
1725+
image_uri = "fake-image-uri"
1726+
1727+
generic_estimator = Estimator(
1728+
entry_point="transfer_learning.py",
1729+
role=ROLE,
1730+
region=REGION,
1731+
sagemaker_session=sagemaker_session,
1732+
instance_count=instance_count,
1733+
instance_type=instance_type,
1734+
source_dir=non_jumpstart_source_dir,
1735+
image_uri=image_uri,
1736+
model_uri=non_jumpstart_source_dir_2,
1737+
tags=[estimator_tag],
1738+
)
1739+
1740+
hp_tuner = HyperparameterTuner(
1741+
generic_estimator,
1742+
OBJECTIVE_METRIC_NAME,
1743+
HYPERPARAMETER_RANGES,
1744+
tags=[hp_tag],
1745+
)
1746+
1747+
hp_tuner.fit({"training": training_data_uri})
1748+
1749+
assert [hp_tag, estimator_tag] == sagemaker_session.create_tuning_job.call_args_list[0][1][
1750+
"tags"
1751+
]
1752+
1753+
assert not sagemaker_session.create_tuning_job.call_args_list[0][1]["job_name"].startswith(
1754+
JUMPSTART_RESOURCE_BASE_NAME
1755+
)
1756+
1757+
hp_tuner.deploy(
1758+
initial_instance_count=INSTANCE_COUNT,
1759+
instance_type=INSTANCE_TYPE,
1760+
image_uri=image_uri,
1761+
source_dir=non_jumpstart_source_dir_3,
1762+
entry_point="inference.py",
1763+
role=ROLE,
1764+
)
1765+
1766+
assert not sagemaker_session.create_model.call_args_list[0][1]["name"].startswith(
1767+
JUMPSTART_RESOURCE_BASE_NAME
1768+
)
1769+
1770+
assert not sagemaker_session.endpoint_from_production_variants.call_args_list[0][1][
1771+
"name"
1772+
].startswith(JUMPSTART_RESOURCE_BASE_NAME)
1773+
1774+
assert sagemaker_session.create_model.call_args_list[0][1]["tags"] == []
1775+
1776+
assert sagemaker_session.endpoint_from_production_variants.call_args_list[0][1]["tags"] == []

0 commit comments

Comments
 (0)