|
19 | 19 | from enum import Enum
|
20 | 20 |
|
21 | 21 | import sagemaker
|
22 |
| -from sagemaker.amazon.amazon_estimator import RecordSet |
| 22 | +from sagemaker.amazon.amazon_estimator import RecordSet, AmazonAlgorithmEstimatorBase |
23 | 23 | from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa
|
24 | 24 | from sagemaker.analytics import HyperparameterTuningJobAnalytics
|
25 | 25 | from sagemaker.estimator import Framework
|
@@ -358,7 +358,7 @@ def attach(cls, tuning_job_name, sagemaker_session=None, job_details=None, estim
|
358 | 358 | estimator_cls, job_details["TrainingJobDefinition"]
|
359 | 359 | )
|
360 | 360 | estimator = cls._prepare_estimator_from_job_description(
|
361 |
| - estimator_cls, job_details["TrainingJobDefinition"], sagemaker_session |
| 361 | + estimator_cls, job_details, sagemaker_session |
362 | 362 | )
|
363 | 363 | init_params = cls._prepare_init_params_from_job_description(job_details)
|
364 | 364 |
|
@@ -497,16 +497,25 @@ def _prepare_estimator_cls(cls, estimator_cls, training_details):
|
497 | 497 | )
|
498 | 498 |
|
499 | 499 | @classmethod
|
500 |
| - def _prepare_estimator_from_job_description( |
501 |
| - cls, estimator_cls, training_details, sagemaker_session |
502 |
| - ): |
| 500 | + def _prepare_estimator_from_job_description(cls, estimator_cls, job_details, sagemaker_session): |
| 501 | + training_details = job_details["TrainingJobDefinition"] |
| 502 | + |
503 | 503 | # Swap name for static hyperparameters to what an estimator would expect
|
504 | 504 | training_details["HyperParameters"] = training_details["StaticHyperParameters"]
|
505 | 505 | del training_details["StaticHyperParameters"]
|
506 | 506 |
|
507 | 507 | # Remove hyperparameter reserved by SageMaker for tuning jobs
|
508 | 508 | del training_details["HyperParameters"]["_tuning_objective_metric"]
|
509 | 509 |
|
| 510 | + # Add missing hyperparameters defined in the hyperparameter ranges, |
| 511 | + # as potentially required in the Amazon algorithm estimator's constructor |
| 512 | + if issubclass(estimator_cls, AmazonAlgorithmEstimatorBase): |
| 513 | + parameter_ranges = job_details["HyperParameterTuningJobConfig"]["ParameterRanges"] |
| 514 | + additional_hyperparameters = cls._extract_hyperparameters_from_parameter_ranges( |
| 515 | + parameter_ranges |
| 516 | + ) |
| 517 | + training_details["HyperParameters"].update(additional_hyperparameters) |
| 518 | + |
510 | 519 | # Add items expected by the estimator (but aren't needed otherwise)
|
511 | 520 | training_details["TrainingJobName"] = ""
|
512 | 521 | if "KmsKeyId" not in training_details["OutputDataConfig"]:
|
@@ -559,6 +568,21 @@ def _prepare_parameter_ranges(cls, parameter_ranges):
|
559 | 568 |
|
560 | 569 | return ranges
|
561 | 570 |
|
| 571 | + @classmethod |
| 572 | + def _extract_hyperparameters_from_parameter_ranges(cls, parameter_ranges): |
| 573 | + hyperparameters = {} |
| 574 | + |
| 575 | + for parameter in parameter_ranges["CategoricalParameterRanges"]: |
| 576 | + hyperparameters[parameter["Name"]] = parameter["Values"][0] |
| 577 | + |
| 578 | + for parameter in parameter_ranges["ContinuousParameterRanges"]: |
| 579 | + hyperparameters[parameter["Name"]] = float(parameter["MinValue"]) |
| 580 | + |
| 581 | + for parameter in parameter_ranges["IntegerParameterRanges"]: |
| 582 | + hyperparameters[parameter["Name"]] = int(parameter["MinValue"]) |
| 583 | + |
| 584 | + return hyperparameters |
| 585 | + |
562 | 586 | def hyperparameter_ranges(self):
|
563 | 587 | """Return the hyperparameter ranges in a dictionary to be used as part of a request for creating a
|
564 | 588 | hyperparameter tuning job.
|
|
0 commit comments