Skip to content

fix: ignore private Automatic Model Tuning hyperparameter when attaching AlgorithmEstimator #1230

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 15, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions src/sagemaker/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,3 +553,28 @@ def _algorithm_training_input_modes(self, training_channels):
current_input_modes = current_input_modes & supported_input_modes

return current_input_modes

@classmethod
def _prepare_init_params_from_job_description(cls, job_details, model_channel_name=None):
    """Convert the job description to init params that can be handled by the
    class constructor

    Args:
        job_details (dict): the returned job details from a DescribeTrainingJob
            API call.
        model_channel_name (str): Name of the channel where pre-trained
            model data will be downloaded.

    Returns:
        dict: The transformed init_params
    """
    init_params = super(AlgorithmEstimator, cls)._prepare_init_params_from_job_description(
        job_details, model_channel_name
    )

    # Amazon SageMaker Automatic Model Tuning injects this private hyperparameter
    # into training jobs it launches; it is not a valid constructor argument, so
    # drop it (pop with a default is a no-op when the key is absent).
    init_params["hyperparameters"].pop("_tuning_objective_metric", None)

    return init_params
72 changes: 72 additions & 0 deletions tests/unit/test_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -943,3 +943,75 @@ def test_algorithm_no_required_hyperparameters(session):
train_instance_count=1,
sagemaker_session=session,
)


def test_algorithm_attach_from_hyperparameter_tuning():
    """Attaching to a training job created by Automatic Model Tuning should
    strip the private ``_tuning_objective_metric`` hyperparameter while keeping
    all user-supplied hyperparameters and job settings intact.
    """
    mock_session = Mock()
    job_name = "training-job-that-is-part-of-a-tuning-job"
    algo_arn = "arn:aws:sagemaker:us-east-2:000000000000:algorithm/scikit-decision-trees"
    role_arn = "arn:aws:iam::123412341234:role/SageMakerRole"
    instance_count = 1
    instance_type = "ml.m4.xlarge"
    train_volume_size = 30
    input_mode = "File"

    # Canned DescribeTrainingJob response for a job launched by a tuning job;
    # note the injected "_tuning_objective_metric" hyperparameter.
    describe_training_job_response = {
        "TrainingJobName": job_name,
        "TrainingJobArn": "arn:aws:sagemaker:us-east-2:123412341234:training-job/%s" % job_name,
        "TuningJobArn": "arn:aws:sagemaker:us-east-2:123412341234:hyper-parameter-tuning-job/%s"
        % job_name,
        "ModelArtifacts": {
            "S3ModelArtifacts": "s3://sagemaker-us-east-2-123412341234/output/model.tar.gz"
        },
        "TrainingJobOutput": {
            "S3TrainingJobOutput": "s3://sagemaker-us-east-2-123412341234/output/output.tar.gz"
        },
        "TrainingJobStatus": "Succeeded",
        "HyperParameters": {
            "_tuning_objective_metric": "validation:accuracy",
            "max_leaf_nodes": 1,
            "free_text_hp1": "foo",
        },
        "AlgorithmSpecification": {"AlgorithmName": algo_arn, "TrainingInputMode": input_mode},
        "MetricDefinitions": [
            {"Name": "validation:accuracy", "Regex": "validation-accuracy: (\\S+)"}
        ],
        "RoleArn": role_arn,
        "InputDataConfig": [
            {
                "ChannelName": "training",
                "DataSource": {
                    "S3DataSource": {
                        "S3DataType": "S3Prefix",
                        "S3Uri": "s3://sagemaker-us-east-2-123412341234/input/training.csv",
                        "S3DataDistributionType": "FullyReplicated",
                    }
                },
                "CompressionType": "None",
                "RecordWrapperType": "None",
            }
        ],
        "OutputDataConfig": {
            "KmsKeyId": "",
            "S3OutputPath": "s3://sagemaker-us-east-2-123412341234/output",
            "RemoveJobNameFromS3OutputPath": False,
        },
        "ResourceConfig": {
            "InstanceType": instance_type,
            "InstanceCount": instance_count,
            "VolumeSizeInGB": train_volume_size,
        },
        "StoppingCondition": {"MaxRuntimeInSeconds": 86400},
    }

    mock_session.sagemaker_client.list_tags.return_value = {"Tags": []}
    mock_session.sagemaker_client.describe_algorithm.return_value = DESCRIBE_ALGORITHM_RESPONSE
    mock_session.sagemaker_client.describe_training_job.return_value = (
        describe_training_job_response
    )

    estimator = AlgorithmEstimator.attach(job_name, sagemaker_session=mock_session)

    # The tuning-only hyperparameter must not survive attach().
    assert estimator.hyperparameters() == {"max_leaf_nodes": 1, "free_text_hp1": "foo"}
    # Everything else from the job description round-trips unchanged.
    assert estimator.algorithm_arn == algo_arn
    assert estimator.role == role_arn
    assert estimator.train_instance_count == instance_count
    assert estimator.train_instance_type == instance_type
    assert estimator.train_volume_size == train_volume_size
    assert estimator.input_mode == input_mode
    assert estimator.sagemaker_session == mock_session