Skip to content

Commit d62a328

Browse files
authored
fix: fix HyperparameterTuner.attach for Marketplace algorithms (#1291)
The AlgorithmSpecification for a training job that is part of a hyperparameter tuning job will have TrainingImage set to "", rather than not being present in the training job description. Looking for AlgorithmName first, instead of TrainingImage, fixes the issue of algorithm_arn not being populated when attaching a hyperparameter tuning job.
1 parent b967a44 commit d62a328

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

src/sagemaker/estimator.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -753,10 +753,10 @@ class constructor
753753
has_hps = "HyperParameters" in job_details
754754
init_params["hyperparameters"] = job_details["HyperParameters"] if has_hps else {}
755755

756-
if "TrainingImage" in job_details["AlgorithmSpecification"]:
757-
init_params["image"] = job_details["AlgorithmSpecification"]["TrainingImage"]
758-
elif "AlgorithmName" in job_details["AlgorithmSpecification"]:
756+
if "AlgorithmName" in job_details["AlgorithmSpecification"]:
759757
init_params["algorithm_arn"] = job_details["AlgorithmSpecification"]["AlgorithmName"]
758+
elif "TrainingImage" in job_details["AlgorithmSpecification"]:
759+
init_params["image"] = job_details["AlgorithmSpecification"]["TrainingImage"]
760760
else:
761761
raise RuntimeError(
762762
"Invalid AlgorithmSpecification. Either TrainingImage or "

tests/unit/test_estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2289,11 +2289,11 @@ def test_prepare_init_params_from_job_description_with_image_training_job():
22892289

22902290

22912291
def test_prepare_init_params_from_job_description_with_algorithm_training_job():
2292-
22932292
algorithm_job_description = RETURNED_JOB_DESCRIPTION.copy()
22942293
algorithm_job_description["AlgorithmSpecification"] = {
22952294
"TrainingInputMode": "File",
22962295
"AlgorithmName": "arn:aws:sagemaker:us-east-2:1234:algorithm/scikit-decision-trees",
2296+
"TrainingImage": "",
22972297
}
22982298

22992299
init_params = EstimatorBase._prepare_init_params_from_job_description(

0 commit comments

Comments
 (0)