Skip to content

Commit f19f4e7

Browse files
Add hyperparameter range validation (aws#31)
1 parent f2209de commit f19f4e7

File tree

2 files changed

+61
-7
lines changed

2 files changed

+61
-7
lines changed

src/sagemaker/tuner.py

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
1414

15+
import inspect
16+
1517
from sagemaker.job import _Job
1618
from sagemaker.utils import base_name_from_image, name_from_base
1719

@@ -64,16 +66,18 @@ def __init__(self, estimator, objective_metric_name, hyperparameter_ranges, metr
6466
objective_type='Maximize', max_jobs=1, max_parallel_jobs=1, base_tuning_job_name=None):
6567
if objective_type not in HyperparameterTuner.__objectives__:
6668
raise ValueError("Unsupported 'objective' values")
69+
6770
self.estimator = estimator
68-
self.metric_name = objective_metric_name
71+
self.objective_metric_name = objective_metric_name
6972
self._hyperparameter_ranges = hyperparameter_ranges
7073
self.strategy = strategy
71-
self.objective = objective_type
74+
self.objective_type = objective_type
7275
self.max_jobs = max_jobs
7376
self.max_parallel_jobs = max_parallel_jobs
7477
self.tuning_job_name = base_tuning_job_name
7578
self.metric_definitions = metric_definitions
7679
self.latest_tuning_job = None
80+
self._validate_parameter_ranges()
7781

7882
def fit(self, inputs):
7983
"""Create tuning job
@@ -108,6 +112,30 @@ def hyperparameter_ranges(self):
108112
hyperparameter_ranges[range_type + 'ParameterRanges'] = parameter_range
109113
return hyperparameter_ranges
110114

115+
def _validate_parameter_ranges(self):
    """Validate the tuner's hyperparameter ranges against the estimator's hyperparameters.

    Walks the MRO of the estimator's class in reverse order and inspects every class
    attribute that is a ``Hyperparameter`` instance. When a range is registered in
    ``self._hyperparameter_ranges`` under that hyperparameter's ``name`` and the range is a
    ``_ParameterRange``, every value stored on the range object is passed to the
    hyperparameter's ``validate`` method. List values (categorical ranges) are validated
    element by element; any other value (continuous/integer ranges) is validated as is.

    Hyperparameters without a registered range, and registered entries that are not
    ``_ParameterRange`` instances, are skipped.

    Raises:
        Any exception raised by ``Hyperparameter.validate`` for a rejected value
        (the unit tests accompanying this change expect ``ValueError``).
    """
    from sagemaker.amazon.hyperparameter import Hyperparameter as hp  # noqa

    for kls in inspect.getmro(self.estimator.__class__)[::-1]:
        for value in kls.__dict__.values():
            if not isinstance(value, hp):
                continue

            # The hyperparam names may not be the same as the class attribute that holds them,
            # for instance: local_lloyd_init_method is called local_init_method. The lookup
            # therefore uses the hyperparameter's own name, not the attribute name.
            # Only the lookup is guarded: a KeyError raised while validating must not be
            # mistaken for "no range registered" and silently dropped.
            try:
                parameter_range = self._hyperparameter_ranges[value.name]
            except KeyError:
                continue

            if not isinstance(parameter_range, _ParameterRange):
                continue

            for parameter_range_value in parameter_range.__dict__.values():
                if isinstance(parameter_range_value, list):
                    # Categorical ranges
                    for categorical_value in parameter_range_value:
                        value.validate(categorical_value)
                else:
                    # Continuous, Integer ranges
                    value.validate(parameter_range_value)
138+
111139

112140
class _TuningJob(_Job):
113141
SAGEMAKER_ESTIMATOR_CLASS_NAME = 'sagemaker_estimator_class_name'
@@ -140,7 +168,7 @@ def start_new(cls, tuner, inputs):
140168
tuning_job_name = name_from_base(base_name)
141169

142170
tuner.estimator.sagemaker_session.tune(job_name=tuning_job_name, strategy=tuner.strategy,
143-
objective=tuner.objective, metric_name=tuner.metric_name,
171+
objective=tuner.objective_type, metric_name=tuner.objective_metric_name,
144172
max_jobs=tuner.max_jobs, max_parallel_jobs=tuner.max_parallel_jobs,
145173
parameter_ranges=tuner.hyperparameter_ranges(),
146174
static_hp=static_hyperparameters,

tests/unit/test_tuner.py

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,32 @@ def tuner(estimator):
5959
hyperparameter_ranges=HYPERPARAMETER_RANGES, metric_definitions=METRIC_DEFINTIONS)
6060

6161

62+
def test_validate_parameter_ranges_number_validation_error(sagemaker_session):
    """Building a tuner with an invalid integer range for ``num_components`` raises ValueError."""
    pca = PCA(ROLE, TRAIN_INSTANCE_COUNT, TRAIN_INSTANCE_TYPE, NUM_COMPONENTS,
              base_job_name='pca', sagemaker_session=sagemaker_session)

    invalid_hyperparameter_ranges = {'num_components': IntegerParameter(-1, 2)}

    with pytest.raises(ValueError) as e:
        HyperparameterTuner(estimator=pca, objective_metric_name=OBJECTIVE_METRIC_NAME,
                            hyperparameter_ranges=invalid_hyperparameter_ranges, metric_definitions=METRIC_DEFINTIONS)

    # ``e`` is pytest's ExceptionInfo wrapper, whose own string form is not the exception
    # message; check the message of the raised exception itself.
    assert 'Value must be an integer greater than zero' in str(e.value)
73+
74+
75+
def test_validate_parameter_ranges_string_value_validation_error(sagemaker_session):
    """Building a tuner with invalid categorical values for ``algorithm_mode`` raises ValueError."""
    pca = PCA(ROLE, TRAIN_INSTANCE_COUNT, TRAIN_INSTANCE_TYPE, NUM_COMPONENTS,
              base_job_name='pca', sagemaker_session=sagemaker_session)

    invalid_hyperparameter_ranges = {'algorithm_mode': CategoricalParameter([0, 5])}

    with pytest.raises(ValueError) as e:
        HyperparameterTuner(estimator=pca, objective_metric_name=OBJECTIVE_METRIC_NAME,
                            hyperparameter_ranges=invalid_hyperparameter_ranges, metric_definitions=METRIC_DEFINTIONS)

    # ``e`` is pytest's ExceptionInfo wrapper, whose own string form is not the exception
    # message; check the message of the raised exception itself.
    assert 'Value must be one of "regular" and "randomized"' in str(e.value)
86+
87+
6288
def test_tune_1p(sagemaker_session, tuner):
6389
pca = PCA(ROLE, TRAIN_INSTANCE_COUNT, TRAIN_INSTANCE_TYPE, NUM_COMPONENTS,
6490
base_job_name='pca', sagemaker_session=sagemaker_session)
@@ -67,18 +93,18 @@ def test_tune_1p(sagemaker_session, tuner):
6793
pca.subtract_mean = True
6894
pca.extra_components = 5
6995

70-
hyperparameter_ranges = {'num_components': IntegerParameter(2, 4)}
96+
hyperparameter_ranges = {'num_components': IntegerParameter(2, 4),
97+
'algorithm_mode': CategoricalParameter(['regular', 'randomized'])}
7198
tuner.estimator = pca
7299
tuner._hyperparameter_ranges = hyperparameter_ranges
73100

74101
tuner.fit(INPUTS)
75102

76103
_, _, tune_kwargs = sagemaker_session.tune.mock_calls[0]
77104

78-
assert len(tune_kwargs['static_hp']) == 5
105+
assert len(tune_kwargs['static_hp']) == 4
79106
assert tune_kwargs['static_hp']['sagemaker_estimator_class_name'] == pca.__class__.__name__
80-
assert tune_kwargs['static_hp']['sagemaker_estimator_module'] == pca.__module__
81-
assert tune_kwargs['static_hp']['algorithm_mode'] == 'randomized'
107+
assert tune_kwargs['static_hp']['extra_components'] == '5'
82108
assert len(tune_kwargs['parameter_ranges']['IntegerParameterRanges']) == 1
83109
assert tune_kwargs['job_name'].startswith('pca')
84110

0 commit comments

Comments
 (0)