@@ -12,6 +12,8 @@
 # language governing permissions and limitations under the License.
 from __future__ import absolute_import
 
+import inspect
+
 from sagemaker.job import _Job
 from sagemaker.utils import base_name_from_image, name_from_base
 
@@ -64,16 +66,18 @@ def __init__(self, estimator, objective_metric_name, hyperparameter_ranges, metr
                  objective_type='Maximize', max_jobs=1, max_parallel_jobs=1, base_tuning_job_name=None):
         if objective_type not in HyperparameterTuner.__objectives__:
             raise ValueError("Unsupported 'objective' values")
+
         self.estimator = estimator
-        self.metric_name = objective_metric_name
+        self.objective_metric_name = objective_metric_name
         self._hyperparameter_ranges = hyperparameter_ranges
         self.strategy = strategy
-        self.objective = objective_type
+        self.objective_type = objective_type
         self.max_jobs = max_jobs
         self.max_parallel_jobs = max_parallel_jobs
         self.tuning_job_name = base_tuning_job_name
         self.metric_definitions = metric_definitions
         self.latest_tuning_job = None
+        self._validate_parameter_ranges()
 
     def fit(self, inputs):
         """Create tuning job
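For reference, a minimal usage sketch of the constructor after these renames. The estimator, input channel, metric regex, and `ContinuousParameter` range class are illustrative assumptions, not part of this diff:

```python
from sagemaker.tuner import ContinuousParameter, HyperparameterTuner  # ContinuousParameter assumed

# `my_estimator` stands in for any configured SageMaker estimator (placeholder).
tuner = HyperparameterTuner(
    estimator=my_estimator,
    objective_metric_name='validation:accuracy',  # stored as self.objective_metric_name (was self.metric_name)
    hyperparameter_ranges={'learning_rate': ContinuousParameter(0.01, 0.2)},
    metric_definitions=[{'Name': 'validation:accuracy', 'Regex': 'accuracy=([0-9\\.]+)'}],
    objective_type='Maximize',                    # stored as self.objective_type (was self.objective)
    max_jobs=10,
    max_parallel_jobs=2)

# __init__ now ends with self._validate_parameter_ranges(), so invalid range
# values fail fast at construction time rather than at job submission.
tuner.fit(inputs)  # `inputs` is a placeholder for the training channel(s)
```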
@@ -108,6 +112,30 @@ def hyperparameter_ranges(self):
             hyperparameter_ranges[range_type + 'ParameterRanges'] = parameter_range
         return hyperparameter_ranges
 
+    def _validate_parameter_ranges(self):
+        from sagemaker.amazon.hyperparameter import Hyperparameter as hp  # noqa
+
+        for kls in inspect.getmro(self.estimator.__class__)[::-1]:
+            for attribute, value in kls.__dict__.items():
+                if isinstance(value, hp):
+                    try:
+                        # A hyperparameter's name may differ from the class attribute that holds it
+                        # (e.g. local_lloyd_init_method is stored as local_init_method), so look up
+                        # the range by the hyperparameter's own name, not the attribute name.
+                        parameter_range = self._hyperparameter_ranges[value.name]
+
+                        if isinstance(parameter_range, _ParameterRange):
+                            for parameter_range_attribute, parameter_range_value in parameter_range.__dict__.items():
+                                # Categorical ranges hold a list of values: validate each one
+                                if isinstance(parameter_range_value, list):
+                                    for categorical_value in parameter_range_value:
+                                        value.validate(categorical_value)
+                                # Continuous and Integer ranges hold scalar bounds
+                                else:
+                                    value.validate(parameter_range_value)
+                    except KeyError:
+                        pass
+
 
 class _TuningJob(_Job):
     SAGEMAKER_ESTIMATOR_CLASS_NAME = 'sagemaker_estimator_class_name'
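The MRO walk above can be hard to picture, so here is a self-contained sketch of the same pattern with toy `Hyperparameter` and range stand-ins (none of these classes are the SDK's own):

```python
import inspect


class Hyperparameter(object):
    """Toy stand-in for the SDK descriptor: knows its tunable name and bounds."""

    def __init__(self, name, min_value, max_value):
        self.name = name
        self.min_value = min_value
        self.max_value = max_value

    def validate(self, value):
        if not (self.min_value <= value <= self.max_value):
            raise ValueError('%s=%s is outside [%s, %s]'
                             % (self.name, value, self.min_value, self.max_value))


class BaseAlgorithm(object):
    # The attribute name differs from the hyperparameter's own name, mirroring
    # local_init_method vs. local_lloyd_init_method in the SDK.
    local_init_method = Hyperparameter('local_lloyd_init_method', 0, 10)


class KMeansLike(BaseAlgorithm):
    k = Hyperparameter('k', 2, 100)


ranges = {'k': (2, 10), 'local_lloyd_init_method': (0, 5)}

# Walk the MRO from the root class down, as _validate_parameter_ranges does,
# so hyperparameters inherited from base classes are validated too.
for kls in inspect.getmro(KMeansLike)[::-1]:
    for attribute, value in kls.__dict__.items():
        if isinstance(value, Hyperparameter):
            try:
                low, high = ranges[value.name]  # look up by hyperparameter name, not attribute name
                value.validate(low)
                value.validate(high)
            except KeyError:
                pass  # no range supplied for this hyperparameter; nothing to check
```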
@@ -140,7 +168,7 @@ def start_new(cls, tuner, inputs):
         tuning_job_name = name_from_base(base_name)
 
         tuner.estimator.sagemaker_session.tune(job_name=tuning_job_name, strategy=tuner.strategy,
-                                               objective=tuner.objective, metric_name=tuner.metric_name,
+                                               objective=tuner.objective_type, metric_name=tuner.objective_metric_name,
                                                max_jobs=tuner.max_jobs, max_parallel_jobs=tuner.max_parallel_jobs,
                                                parameter_ranges=tuner.hyperparameter_ranges(),
                                                static_hp=static_hyperparameters,