Skip to content

Commit 479c9c4

Browse files
authored
fix: ignore private Automatic Model Tuning hyperparameter when attaching AlgorithmEstimator (#1230)
1 parent 1a5574c commit 479c9c4

File tree

2 files changed

+97
-0
lines changed

2 files changed

+97
-0
lines changed

src/sagemaker/algorithm.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -553,3 +553,28 @@ def _algorithm_training_input_modes(self, training_channels):
553553
current_input_modes = current_input_modes & supported_input_modes
554554

555555
return current_input_modes
556+
557+
@classmethod
558+
def _prepare_init_params_from_job_description(cls, job_details, model_channel_name=None):
559+
"""Convert the job description to init params that can be handled by the
560+
class constructor
561+
562+
Args:
563+
job_details (dict): the returned job details from a DescribeTrainingJob
564+
API call.
565+
model_channel_name (str): Name of the channel where pre-trained
566+
model data will be downloaded.
567+
568+
Returns:
569+
dict: The transformed init_params
570+
"""
571+
init_params = super(AlgorithmEstimator, cls)._prepare_init_params_from_job_description(
572+
job_details, model_channel_name
573+
)
574+
575+
# This hyperparameter is added by Amazon SageMaker Automatic Model Tuning.
576+
# It cannot be set through instantiating an estimator.
577+
if "_tuning_objective_metric" in init_params["hyperparameters"]:
578+
del init_params["hyperparameters"]["_tuning_objective_metric"]
579+
580+
return init_params

tests/unit/test_algorithm.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -943,3 +943,75 @@ def test_algorithm_no_required_hyperparameters(session):
943943
train_instance_count=1,
944944
sagemaker_session=session,
945945
)
946+
947+
948+
def test_algorithm_attach_from_hyperparameter_tuning():
949+
session = Mock()
950+
job_name = "training-job-that-is-part-of-a-tuning-job"
951+
algo_arn = "arn:aws:sagemaker:us-east-2:000000000000:algorithm/scikit-decision-trees"
952+
role_arn = "arn:aws:iam::123412341234:role/SageMakerRole"
953+
instance_count = 1
954+
instance_type = "ml.m4.xlarge"
955+
train_volume_size = 30
956+
input_mode = "File"
957+
958+
session.sagemaker_client.list_tags.return_value = {"Tags": []}
959+
session.sagemaker_client.describe_algorithm.return_value = DESCRIBE_ALGORITHM_RESPONSE
960+
session.sagemaker_client.describe_training_job.return_value = {
961+
"TrainingJobName": job_name,
962+
"TrainingJobArn": "arn:aws:sagemaker:us-east-2:123412341234:training-job/%s" % job_name,
963+
"TuningJobArn": "arn:aws:sagemaker:us-east-2:123412341234:hyper-parameter-tuning-job/%s"
964+
% job_name,
965+
"ModelArtifacts": {
966+
"S3ModelArtifacts": "s3://sagemaker-us-east-2-123412341234/output/model.tar.gz"
967+
},
968+
"TrainingJobOutput": {
969+
"S3TrainingJobOutput": "s3://sagemaker-us-east-2-123412341234/output/output.tar.gz"
970+
},
971+
"TrainingJobStatus": "Succeeded",
972+
"HyperParameters": {
973+
"_tuning_objective_metric": "validation:accuracy",
974+
"max_leaf_nodes": 1,
975+
"free_text_hp1": "foo",
976+
},
977+
"AlgorithmSpecification": {"AlgorithmName": algo_arn, "TrainingInputMode": input_mode},
978+
"MetricDefinitions": [
979+
{"Name": "validation:accuracy", "Regex": "validation-accuracy: (\\S+)"}
980+
],
981+
"RoleArn": role_arn,
982+
"InputDataConfig": [
983+
{
984+
"ChannelName": "training",
985+
"DataSource": {
986+
"S3DataSource": {
987+
"S3DataType": "S3Prefix",
988+
"S3Uri": "s3://sagemaker-us-east-2-123412341234/input/training.csv",
989+
"S3DataDistributionType": "FullyReplicated",
990+
}
991+
},
992+
"CompressionType": "None",
993+
"RecordWrapperType": "None",
994+
}
995+
],
996+
"OutputDataConfig": {
997+
"KmsKeyId": "",
998+
"S3OutputPath": "s3://sagemaker-us-east-2-123412341234/output",
999+
"RemoveJobNameFromS3OutputPath": False,
1000+
},
1001+
"ResourceConfig": {
1002+
"InstanceType": instance_type,
1003+
"InstanceCount": instance_count,
1004+
"VolumeSizeInGB": train_volume_size,
1005+
},
1006+
"StoppingCondition": {"MaxRuntimeInSeconds": 86400},
1007+
}
1008+
1009+
estimator = AlgorithmEstimator.attach(job_name, sagemaker_session=session)
1010+
assert estimator.hyperparameters() == {"max_leaf_nodes": 1, "free_text_hp1": "foo"}
1011+
assert estimator.algorithm_arn == algo_arn
1012+
assert estimator.role == role_arn
1013+
assert estimator.train_instance_count == instance_count
1014+
assert estimator.train_instance_type == instance_type
1015+
assert estimator.train_volume_size == train_volume_size
1016+
assert estimator.input_mode == input_mode
1017+
assert estimator.sagemaker_session == session

0 commit comments

Comments
 (0)