-
Notifications
You must be signed in to change notification settings - Fork 1.2k
fix: jumpstart amt tracking #3077
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
122d5a7
a00de67
29ee04a
a2cda63
d7cd740
d984c25
1aab22b
13efaef
85d372d
a26590b
897365f
383b2d6
ab7a60e
eef2ed7
21a5357
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -32,6 +32,7 @@ | |
from sagemaker.estimator import Framework | ||
from sagemaker.inputs import TrainingInput | ||
from sagemaker.job import _Job | ||
from sagemaker.jumpstart.utils import add_jumpstart_tags, get_jumpstart_base_name_if_jumpstart_model | ||
from sagemaker.parameter import ( | ||
CategoricalParameter, | ||
ContinuousParameter, | ||
|
@@ -319,6 +320,42 @@ def _prepare_for_tuning(self, job_name=None, include_cls_metadata=False): | |
"""Prepare the tuner instance for tuning (fit).""" | ||
self._prepare_job_name_for_tuning(job_name=job_name) | ||
self._prepare_static_hyperparameters_for_tuning(include_cls_metadata=include_cls_metadata) | ||
self._prepare_tags_for_tuning() | ||
|
||
def _get_model_uri( | ||
self, | ||
estimator, | ||
): | ||
"""Return the model artifact URI used by the Estimator instance. | ||
|
||
This attribute can live in multiple places, and accessing the attribute can | ||
raise a TypeError, which needs to be handled. | ||
""" | ||
try: | ||
return getattr(estimator, "model_data", None) | ||
except TypeError: | ||
return getattr(estimator, "model_uri", None) | ||
|
||
def _prepare_tags_for_tuning(self): | ||
"""Add tags to tuning job (from Estimator and JumpStart tags).""" | ||
|
||
# Add tags from Estimator class | ||
estimator = self.estimator or self.estimator_dict[sorted(self.estimator_dict.keys())[0]] | ||
|
||
estimator_tags = getattr(estimator, "tags", []) or [] | ||
|
||
if self.tags is None and len(estimator_tags) > 0: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Recommendation generated by Amazon CodeGuru Reviewer. Leave feedback on this recommendation by replying to the comment or by reacting to the comment using emoji. To check if a container or sequence (string, list, tuple) is empty, use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. +1 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sorry, I disagree with this suggestion. Because python has no type checking, |
||
self.tags = [] | ||
|
||
for tag in estimator_tags: | ||
if tag not in self.tags: | ||
self.tags.append(tag) | ||
|
||
self.tags = add_jumpstart_tags( | ||
evakravi marked this conversation as resolved.
Show resolved
Hide resolved
|
||
tags=self.tags, | ||
training_script_uri=getattr(estimator, "source_dir", None), | ||
training_model_uri=self._get_model_uri(estimator), | ||
) | ||
|
||
def _prepare_job_name_for_tuning(self, job_name=None): | ||
"""Set current job name before starting tuning.""" | ||
|
@@ -331,6 +368,12 @@ def _prepare_job_name_for_tuning(self, job_name=None): | |
self.estimator or self.estimator_dict[sorted(self.estimator_dict.keys())[0]] | ||
) | ||
base_name = base_name_from_image(estimator.training_image_uri()) | ||
|
||
jumpstart_base_name = get_jumpstart_base_name_if_jumpstart_model( | ||
getattr(estimator, "source_dir", None), | ||
self._get_model_uri(estimator), | ||
) | ||
base_name = jumpstart_base_name or base_name | ||
self._current_job_name = name_from_base( | ||
base_name, max_length=self.TUNING_JOB_NAME_MAX_LENGTH, short=True | ||
) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
will 'estimator_dict' always have a value
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It should always have a value. I used the same code as here:
sagemaker-python-sdk/src/sagemaker/tuner.py
Line 330 in 8d84618