Skip to content

Commit a00de67

Browse files
committed
chore: address git comments
1 parent 122d5a7 commit a00de67

File tree

2 files changed

+57
-5
lines changed

2 files changed

+57
-5
lines changed

src/sagemaker/tuner.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -325,16 +325,15 @@ def _get_model_uri(
325325
self,
326326
estimator,
327327
):
328-
"""Get model uri attribute of ``Estimator`` object.
328+
"""Return the model artifact URI used by the Estimator instance.
329329
330330
This attribute can live in multiple places, and accessing the attribute can
331331
raise a TypeError, which needs to be handled.
332332
"""
333333
try:
334-
model_data = getattr(estimator, "model_data", None)
334+
return getattr(estimator, "model_data", None)
335335
except TypeError:
336-
model_data = None
337-
return model_data or getattr(estimator, "model_uri", None)
336+
return getattr(estimator, "model_uri", None)
338337

339338
def _prepare_tags_for_tuning(self):
340339
"""Add tags to tuning job (from Estimator and JumpStart tags)."""
@@ -351,7 +350,6 @@ def _prepare_tags_for_tuning(self):
351350
if tag not in self.tags:
352351
self.tags.append(tag)
353352

354-
# Add JumpStart tags
355353
self.tags = add_jumpstart_tags(
356354
tags=self.tags,
357355
training_script_uri=getattr(estimator, "source_dir", None),

tests/unit/test_tuner.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1583,3 +1583,57 @@ def test_tags_prefixes_jumpstart_models(
15831583
assert sagemaker_session.create_tuning_job.call_args_list[0][1]["job_name"].startswith(
15841584
JUMPSTART_RESOURCE_BASE_NAME
15851585
)
1586+
1587+
1588+
@patch("time.time", return_value=510006209.073025)
1589+
@patch("sagemaker.estimator.tar_and_upload_dir")
1590+
@patch("sagemaker.model.Model._upload_code")
1591+
def test_no_tags_prefixes_non_jumpstart_models(
1592+
patched_upload_code, patched_tar_and_upload_dir, sagemaker_session
1593+
):
1594+
1595+
patched_tar_and_upload_dir.return_value = UploadedCode(
1596+
s3_prefix="s3://%s/%s" % ("bucket", "key"), script_name="script_name"
1597+
)
1598+
sagemaker_session.boto_region_name = REGION
1599+
1600+
instance_type = "ml.p2.xlarge"
1601+
instance_count = 1
1602+
1603+
training_data_uri = "s3://bucket/mydata"
1604+
1605+
non_jumpstart_source_dir = "s3://non-js-bucket/sdfsdfs"
1606+
non_jumpstart_source_dir_2 = "s3://non-js-bucket/sdfsdsfsdfsddfs"
1607+
1608+
image_uri = "fake-image-uri"
1609+
1610+
generic_estimator = Estimator(
1611+
entry_point="transfer_learning.py",
1612+
role=ROLE,
1613+
region=REGION,
1614+
sagemaker_session=sagemaker_session,
1615+
instance_count=instance_count,
1616+
instance_type=instance_type,
1617+
source_dir=non_jumpstart_source_dir,
1618+
image_uri=image_uri,
1619+
model_uri=non_jumpstart_source_dir_2,
1620+
tags=[{"Key": "estimator-tag-key", "Value": "estimator-tag-value"}],
1621+
)
1622+
1623+
hp_tuner = HyperparameterTuner(
1624+
generic_estimator,
1625+
OBJECTIVE_METRIC_NAME,
1626+
HYPERPARAMETER_RANGES,
1627+
tags=[{"Key": "hp-tuner-tag-key", "Value": "hp-tuner-estimator-tag-value"}],
1628+
)
1629+
1630+
hp_tuner.fit({"training": training_data_uri})
1631+
1632+
assert [
1633+
{"Key": "hp-tuner-tag-key", "Value": "hp-tuner-estimator-tag-value"},
1634+
{"Key": "estimator-tag-key", "Value": "estimator-tag-value"},
1635+
] == sagemaker_session.create_tuning_job.call_args_list[0][1]["tags"]
1636+
1637+
assert not sagemaker_session.create_tuning_job.call_args_list[0][1]["job_name"].startswith(
1638+
JUMPSTART_RESOURCE_BASE_NAME
1639+
)

0 commit comments

Comments
 (0)