
Commit 3734314

change: fix attach for 1P algorithm estimators (#931)

Author: Dan
1 parent: 6adb29b

3 files changed: +38 additions, -11 deletions
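
For context, a minimal sketch of the flow this commit fixes. The tuning job name and session setup below are hypothetical, not taken from the diff; attach() rebuilds the estimator from the tuning job description, which is where first-party (1P) Amazon algorithm estimators previously broke when a constructor-required hyperparameter was a tuned range rather than a static value:

import sagemaker
from sagemaker.tuner import HyperparameterTuner

session = sagemaker.Session()

# Re-attach to a completed tuning job. Before this commit, attach() could fail
# for 1P estimators such as LDA, because constructor-required hyperparameters
# (e.g. num_topics) that were tuned as ranges never reached the constructor.
attached_tuner = HyperparameterTuner.attach(
    "example-lda-tuning-job",  # hypothetical tuning job name
    sagemaker_session=session,
)
print(attached_tuner.best_training_job())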

src/sagemaker/tuner.py

Lines changed: 29 additions & 5 deletions
@@ -19,7 +19,7 @@
 from enum import Enum
 
 import sagemaker
-from sagemaker.amazon.amazon_estimator import RecordSet
+from sagemaker.amazon.amazon_estimator import RecordSet, AmazonAlgorithmEstimatorBase
 from sagemaker.amazon.hyperparameter import Hyperparameter as hp  # noqa
 from sagemaker.analytics import HyperparameterTuningJobAnalytics
 from sagemaker.estimator import Framework
@@ -358,7 +358,7 @@ def attach(cls, tuning_job_name, sagemaker_session=None, job_details=None, estim
             estimator_cls, job_details["TrainingJobDefinition"]
         )
         estimator = cls._prepare_estimator_from_job_description(
-            estimator_cls, job_details["TrainingJobDefinition"], sagemaker_session
+            estimator_cls, job_details, sagemaker_session
         )
         init_params = cls._prepare_init_params_from_job_description(job_details)
 
@@ -497,16 +497,25 @@ def _prepare_estimator_cls(cls, estimator_cls, training_details):
         )
 
     @classmethod
-    def _prepare_estimator_from_job_description(
-        cls, estimator_cls, training_details, sagemaker_session
-    ):
+    def _prepare_estimator_from_job_description(cls, estimator_cls, job_details, sagemaker_session):
+        training_details = job_details["TrainingJobDefinition"]
+
         # Swap name for static hyperparameters to what an estimator would expect
         training_details["HyperParameters"] = training_details["StaticHyperParameters"]
         del training_details["StaticHyperParameters"]
 
         # Remove hyperparameter reserved by SageMaker for tuning jobs
         del training_details["HyperParameters"]["_tuning_objective_metric"]
 
+        # Add missing hyperparameters defined in the hyperparameter ranges,
+        # as potentially required in the Amazon algorithm estimator's constructor
+        if issubclass(estimator_cls, AmazonAlgorithmEstimatorBase):
+            parameter_ranges = job_details["HyperParameterTuningJobConfig"]["ParameterRanges"]
+            additional_hyperparameters = cls._extract_hyperparameters_from_parameter_ranges(
+                parameter_ranges
+            )
+            training_details["HyperParameters"].update(additional_hyperparameters)
+
         # Add items expected by the estimator (but aren't needed otherwise)
         training_details["TrainingJobName"] = ""
         if "KmsKeyId" not in training_details["OutputDataConfig"]:
@@ -559,6 +568,21 @@ def _prepare_parameter_ranges(cls, parameter_ranges):
 
         return ranges
 
+    @classmethod
+    def _extract_hyperparameters_from_parameter_ranges(cls, parameter_ranges):
+        hyperparameters = {}
+
+        for parameter in parameter_ranges["CategoricalParameterRanges"]:
+            hyperparameters[parameter["Name"]] = parameter["Values"][0]
+
+        for parameter in parameter_ranges["ContinuousParameterRanges"]:
+            hyperparameters[parameter["Name"]] = float(parameter["MinValue"])
+
+        for parameter in parameter_ranges["IntegerParameterRanges"]:
+            hyperparameters[parameter["Name"]] = int(parameter["MinValue"])
+
+        return hyperparameters
+
     def hyperparameter_ranges(self):
         """Return the hyperparameter ranges in a dictionary to be used as part of a request for creating a
         hyperparameter tuning job.
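
To make the new helper concrete, here is a stand-alone sketch of the mapping it performs. The ParameterRanges payload below is invented for illustration; only the extraction logic mirrors the diff:

# Invented example payload in the shape SageMaker returns for a tuning job's
# HyperParameterTuningJobConfig["ParameterRanges"].
parameter_ranges = {
    "CategoricalParameterRanges": [{"Name": "mini_batch_size", "Values": ["64", "128"]}],
    "ContinuousParameterRanges": [{"Name": "alpha0", "MinValue": "0.5", "MaxValue": "5.0"}],
    "IntegerParameterRanges": [{"Name": "num_topics", "MinValue": "1", "MaxValue": "10"}],
}

# Same logic as the new _extract_hyperparameters_from_parameter_ranges helper:
# take the first categorical value and the minimum of each numeric range, so the
# estimator constructor receives a valid placeholder for every tuned hyperparameter.
hyperparameters = {}
for parameter in parameter_ranges["CategoricalParameterRanges"]:
    hyperparameters[parameter["Name"]] = parameter["Values"][0]
for parameter in parameter_ranges["ContinuousParameterRanges"]:
    hyperparameters[parameter["Name"]] = float(parameter["MinValue"])
for parameter in parameter_ranges["IntegerParameterRanges"]:
    hyperparameters[parameter["Name"]] = int(parameter["MinValue"])

print(hyperparameters)  # {'mini_batch_size': '64', 'alpha0': 0.5, 'num_topics': 1}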

tests/integ/test_tuner.py

Lines changed: 7 additions & 4 deletions
@@ -460,12 +460,15 @@ def test_tuning_lda(sagemaker_session):
         time.sleep(15)
     tuner.wait()
 
-    desc = tuner.latest_tuning_job.sagemaker_session.sagemaker_client.describe_hyper_parameter_tuning_job(
-        HyperParameterTuningJobName=latest_tuning_job_name
+    attached_tuner = HyperparameterTuner.attach(
+        tuning_job_name, sagemaker_session=sagemaker_session
     )
-    assert desc["HyperParameterTuningJobConfig"]["TrainingJobEarlyStoppingType"] == "Auto"
+    assert attached_tuner.early_stopping_type == "Auto"
+    assert attached_tuner.estimator.alpha0 == 1.0
+    assert attached_tuner.estimator.num_topics == 1
+
+    best_training_job = attached_tuner.best_training_job()
 
-    best_training_job = tuner.best_training_job()
     with timeout_and_delete_endpoint_by_name(best_training_job, sagemaker_session):
         predictor = tuner.deploy(1, "ml.c4.xlarge")
         predict_input = np.random.rand(1, feature_num)

tests/unit/test_tuner.py

Lines changed: 2 additions & 2 deletions
@@ -78,7 +78,7 @@
         "IntegerParameterRanges": [
             {
                 "MaxValue": "100",
-                "Name": "mini_batch_size",
+                "Name": "num_components",
                 "MinValue": "10",
                 "ScalingType": "Auto",
             }
@@ -416,7 +416,7 @@ def test_attach_tuning_job_with_estimator_from_hyperparameters(sagemaker_session
     assert tuner.estimator.output_kms_key == ""
 
     assert "_tuning_objective_metric" not in tuner.estimator.hyperparameters()
-    assert tuner.estimator.hyperparameters()["num_components"] == "1"
+    assert tuner.estimator.hyperparameters()["num_components"] == "10"
 
 
 def test_attach_tuning_job_with_estimator_from_hyperparameters_with_early_stopping(
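
Why the expected value changes: with the fix, num_components is backfilled from the integer range's minimum in the fixture above instead of being read from a static hyperparameter, and the attached estimator reports hyperparameters as strings (the diff's own assertion compares against "10"). A minimal sketch of that value flow, using the fixture values from this diff:

# Fixture values taken from this diff; the new helper uses the range minimum.
integer_range = {"MaxValue": "100", "Name": "num_components", "MinValue": "10", "ScalingType": "Auto"}

num_components = int(integer_range["MinValue"])  # 10, passed to the estimator constructor
assert str(num_components) == "10"  # hyperparameters() then reports it as the string "10"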
