change: set _current_job_name and base_tuning_job_name in HyperparameterTuner.attach() #1650

Merged: 1 commit, Jun 30, 2020
src/sagemaker/tuner.py (38 changes: 15 additions & 23 deletions)

@@ -37,7 +37,7 @@
 )
 from sagemaker.session import Session
 from sagemaker.session import s3_input
-from sagemaker.utils import base_name_from_image, name_from_base
+from sagemaker.utils import base_from_name, base_name_from_image, name_from_base

 AMAZON_ESTIMATOR_MODULE = "sagemaker"
 AMAZON_ESTIMATOR_CLS_NAMES = {
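
The newly imported base_from_name is the inverse of the name generation done by name_from_base: it strips the timestamp-style suffix that the SDK appends when creating unique job names. It is used further down, in _prepare_init_params_from_job_description. A rough illustration (the job name and the exact suffix format are only examples):

    from sagemaker.utils import base_from_name

    # "200630-1415" stands in for the short-form date-time suffix that
    # name_from_base() appends; base_from_name() strips it back off.
    assert base_from_name("xgb-tune-200630-1415") == "xgb-tune"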
@@ -587,18 +587,21 @@ def attach(cls, tuning_job_name, sagemaker_session=None, job_details=None, estimator_cls=None):
         )

         if "TrainingJobDefinition" in job_details:
-            return cls._attach_with_training_details(
-                tuning_job_name, sagemaker_session, estimator_cls, job_details
+            tuner = cls._attach_with_training_details(sagemaker_session, estimator_cls, job_details)
+        else:
+            tuner = cls._attach_with_training_details_list(
+                sagemaker_session, estimator_cls, job_details
             )

-        return cls._attach_with_training_details_list(
-            tuning_job_name, sagemaker_session, estimator_cls, job_details
+        tuner.latest_tuning_job = _TuningJob(
+            sagemaker_session=sagemaker_session, job_name=tuning_job_name
         )
+        tuner._current_job_name = tuning_job_name
+
+        return tuner

     @classmethod
-    def _attach_with_training_details(
-        cls, tuning_job_name, sagemaker_session, estimator_cls, job_details
-    ):
+    def _attach_with_training_details(cls, sagemaker_session, estimator_cls, job_details):
         """Create a HyperparameterTuner bound to an existing hyperparameter
         tuning job that has the ``TrainingJobDefinition`` field set."""
         estimator = cls._prepare_estimator(

@@ -609,17 +612,10 @@ def _attach_with_training_details(
         )
         init_params = cls._prepare_init_params_from_job_description(job_details)

-        tuner = cls(estimator=estimator, **init_params)
-        tuner.latest_tuning_job = _TuningJob(
-            sagemaker_session=sagemaker_session, job_name=tuning_job_name
-        )
-
-        return tuner
+        return cls(estimator=estimator, **init_params)

     @classmethod
-    def _attach_with_training_details_list(
-        cls, tuning_job_name, sagemaker_session, estimator_cls, job_details
-    ):
+    def _attach_with_training_details_list(cls, sagemaker_session, estimator_cls, job_details):
         """Create a HyperparameterTuner bound to an existing hyperparameter
         tuning job that has the ``TrainingJobDefinitions`` field set."""
         estimator_names = sorted(

@@ -664,18 +660,13 @@ def _attach_with_training_details_list(

         init_params = cls._prepare_init_params_from_job_description(job_details)

-        tuner = HyperparameterTuner.create(
+        return HyperparameterTuner.create(
             estimator_dict=estimator_dict,
             objective_metric_name_dict=objective_metric_name_dict,
             hyperparameter_ranges_dict=hyperparameter_ranges_dict,
             metric_definitions_dict=metric_definitions_dict,
             **init_params
         )
-        tuner.latest_tuning_job = _TuningJob(
-            sagemaker_session=sagemaker_session, job_name=tuning_job_name
-        )
-
-        return tuner

     def deploy(
         self,

@@ -941,6 +932,7 @@ def _prepare_init_params_from_job_description(cls, job_details):
                 job_details.get("WarmStartConfig", None)
             ),
             "early_stopping_type": tuning_config["TrainingJobEarlyStoppingType"],
+            "base_tuning_job_name": base_from_name(job_details["HyperParameterTuningJobName"]),
         }

         if "HyperParameterTuningJobObjective" in tuning_config:
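
Taken together, the tuner.py changes route both branches of attach() through one shared tail, so the returned tuner always has latest_tuning_job, _current_job_name, and (via _prepare_init_params_from_job_description) base_tuning_job_name populated. A minimal sketch of the resulting behavior, assuming a previously launched job with the hypothetical name "xgb-tune-200630-1415":

    from sagemaker.tuner import HyperparameterTuner

    # Attach to an existing hyperparameter tuning job by name.
    tuner = HyperparameterTuner.attach("xgb-tune-200630-1415")

    # Both attributes are now set by attach() itself, regardless of whether the
    # job was described with TrainingJobDefinition or TrainingJobDefinitions.
    assert tuner._current_job_name == "xgb-tune-200630-1415"
    assert tuner.latest_tuning_job.name == "xgb-tune-200630-1415"

    # The base name is recovered from the job name, so follow-up work driven by
    # this tuner can generate fresh, unique job names from the same base.
    assert tuner.base_tuning_job_name == "xgb-tune"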
tests/unit/test_tuner.py (23 changes: 19 additions & 4 deletions)

@@ -19,14 +19,12 @@
 import pytest
 from mock import Mock, patch

-from sagemaker import Predictor
+from sagemaker import Predictor, utils
 from sagemaker.amazon.amazon_estimator import RecordSet
 from sagemaker.estimator import Framework
 from sagemaker.mxnet import MXNet
-
-from sagemaker.session import s3_input
-
 from sagemaker.parameter import ParameterRange
+from sagemaker.session import s3_input
 from sagemaker.tuner import (
     _TuningJob,
     create_identical_dataset_and_algorithm_tuner,

@@ -498,6 +496,9 @@ def test_attach_tuning_job_with_estimator_from_hyperparameters(sagemaker_session):
     tuner = HyperparameterTuner.attach(JOB_NAME, sagemaker_session=sagemaker_session)

     assert tuner.latest_tuning_job.name == JOB_NAME
+    assert tuner.base_tuning_job_name == JOB_NAME
+    assert tuner._current_job_name == JOB_NAME
+
     assert tuner.objective_metric_name == OBJECTIVE_METRIC_NAME
     assert tuner.max_jobs == 1
     assert tuner.max_parallel_jobs == 1

@@ -580,6 +581,20 @@ def test_attach_with_no_specified_estimator(sagemaker_session):
     assert isinstance(tuner.estimator, Estimator)


+def test_attach_with_generated_job_name(sagemaker_session):
+    job_name = utils.name_from_base(BASE_JOB_NAME, max_length=32, short=True)
+
+    job_details = copy.deepcopy(TUNING_JOB_DETAILS)
+    job_details["HyperParameterTuningJobName"] = job_name
+
+    sagemaker_session.sagemaker_client.describe_hyper_parameter_tuning_job = Mock(
+        name="describe_tuning_job", return_value=job_details
+    )
+
+    tuner = HyperparameterTuner.attach(job_name, sagemaker_session=sagemaker_session)
+    assert BASE_JOB_NAME == tuner.base_tuning_job_name
+
+
 def test_attach_with_warm_start_config(sagemaker_session):
     warm_start_config = WarmStartConfig(
         warm_start_type=WarmStartTypes.TRANSFER_LEARNING, parents={"p1", "p2"}
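
Note that the two new assertions in the first test expect base_tuning_job_name to equal the full JOB_NAME: a hand-written name carries no generated suffix, so base_from_name should return it unchanged. The new test_attach_with_generated_job_name covers the complementary case, where the job name was produced by name_from_base (BASE_JOB_NAME and TUNING_JOB_DETAILS are fixtures defined elsewhere in this module). A small sketch of the two cases, with hypothetical names and an approximate suffix format:

    from sagemaker import utils

    # A hand-picked job name has no generated suffix; the base is the name itself.
    assert utils.base_from_name("my-tuning-job") == "my-tuning-job"

    # A generated name (e.g. "my-tuning-job-200630-1415") round-trips back to
    # its base, staying within SageMaker's 32-character limit for tuning job names.
    generated = utils.name_from_base("my-tuning-job", max_length=32, short=True)
    assert len(generated) <= 32
    assert utils.base_from_name(generated) == "my-tuning-job"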