Skip to content

Commit 22d3d07

Browse files
authored
Define and use HpoParameter for hyperparameters in 1P algorithms. (aws#5)
1 parent e2334a9 commit 22d3d07

File tree

4 files changed

+256
-5
lines changed

4 files changed

+256
-5
lines changed

src/sagemaker/amazon/hyperparameter.py

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
13+
from sagemaker.hpo import _HpoParameter
1314

1415

1516
class Hyperparameter(object):
@@ -45,13 +46,35 @@ def validate(self, value):
4546
raise ValueError(error_message)
4647

4748
def __get__(self, obj, objtype):
48-
"""Return the value of this hyperparameter"""
49+
"""Return the value of this hyperparameter, whether it be a range or one discrete value"""
50+
if '_hpo_parameters' in dir(obj) and self.name in obj._hpo_parameters:
51+
return obj._hpo_parameters[self.name]
4952
if '_hyperparameters' not in dir(obj) or self.name not in obj._hyperparameters:
5053
raise AttributeError()
5154
return obj._hyperparameters[self.name]
5255

5356
def __set__(self, obj, value):
57+
"""Assign values for hyperparameters"""
58+
if isinstance(value, _HpoParameter):
59+
self._set_hpo_parameter(obj, value)
60+
else:
61+
self._set_hyperparameter(obj, value)
62+
63+
def _set_hpo_parameter(self, obj, value):
64+
# remove from the hyperparameters if it's there
65+
if '_hyperparameters' in dir(obj) and self.name in obj._hyperparameters:
66+
del obj._hyperparameters[self.name]
67+
68+
if '_hpo_parameters' not in dir(obj):
69+
obj._hpo_parameters = dict()
70+
obj._hpo_parameters[self.name] = value
71+
72+
def _set_hyperparameter(self, obj, value):
5473
"""Validate the supplied value and set this hyperparameter to value"""
74+
# remove from the hpo_parameters if it's there
75+
if '_hpo_parameters' in dir(obj) and self.name in obj._hpo_parameters:
76+
del obj._hpo_parameters[self.name]
77+
5578
value = None if value is None else self.data_type(value)
5679
self.validate(value)
5780
if '_hyperparameters' not in dir(obj):
@@ -60,11 +83,34 @@ def __set__(self, obj, value):
6083

6184
def __delete__(self, obj):
6285
"""Delete this hyperparameter"""
63-
del obj._hyperparameters[self.name]
86+
if '_hyperparameters' in dir(obj) and self.name in obj._hyperparameters:
87+
del obj._hyperparameters[self.name]
88+
if '_hpo_parameters' in dir(obj) and self.name in obj._hpo_parameters:
89+
del obj._hpo_parameters[self.name]
6490

6591
@staticmethod
6692
def serialize_all(obj):
6793
"""Return all non-None ``hyperparameter`` values on ``obj`` as a ``dict[str,str].``"""
6894
if '_hyperparameters' not in dir(obj):
6995
return {}
7096
return {k: str(v) for k, v in obj._hyperparameters.items() if v is not None}
97+
98+
@staticmethod
99+
def serialize_all_hpo(obj):
100+
"""Return collections of ``ParameterRanges``
101+
102+
Args:
103+
obj (object): 1P estimator that has hyperparameters defined
104+
105+
Returns:
106+
Dictionary of ParameterRanges suitable for HPO tuning job.
107+
"""
108+
if '_hpo_parameters' not in dir(obj):
109+
obj._hpo_parameters = dict()
110+
111+
parameter_ranges = dict()
112+
for range_type in _HpoParameter.__all_types__:
113+
parameter_range = [param.as_hpo_range(p_name)
114+
for p_name, param in obj._hpo_parameters.items() if param.__name__ == range_type]
115+
parameter_ranges[range_type+'ParameterRange'] = parameter_range
116+
return parameter_ranges

src/sagemaker/hpo.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
# Copyright 2017-2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
14+
15+
class _HpoParameter(object):
16+
__all_types__ = ['Continuous', 'Categorical', 'Integer']
17+
18+
def __init__(self, min_value, max_value):
19+
self.min_value = min_value
20+
self.max_value = max_value
21+
22+
def as_hpo_range(self, name):
23+
return {'Name': name,
24+
'Type': self.__name__,
25+
'MinValue': str(self.min_value),
26+
'MaxValue': str(self.max_value)}
27+
28+
29+
class ContinuousParameter(_HpoParameter):
30+
__name__ = 'Continuous'
31+
32+
def __init__(self, min_value, max_value):
33+
super(ContinuousParameter, self).__init__(min_value, max_value)
34+
35+
36+
class CategoricalParameter(_HpoParameter):
37+
__name__ = 'Categorical'
38+
39+
def __init__(self, values):
40+
if isinstance(values, list):
41+
self.values = [str(v) for v in values]
42+
else:
43+
self.values = [str(values)]
44+
45+
def as_hpo_range(self, name):
46+
return {'Name': name,
47+
'Type': self.__name__,
48+
'Values': self.values}
49+
50+
51+
class IntegerParameter(_HpoParameter):
52+
__name__ = 'Integer'
53+
54+
def __init__(self, min_value, max_value):
55+
super(IntegerParameter, self).__init__(min_value, max_value)

tests/unit/test_hpo.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from sagemaker.hpo import _HpoParameter, ContinuousParameter, IntegerParameter, CategoricalParameter
14+
15+
16+
def test_continuous_parameter():
17+
cont_param = ContinuousParameter(0.1, 1e-2)
18+
assert isinstance(cont_param, _HpoParameter)
19+
assert cont_param.__name__ is 'Continuous'
20+
21+
22+
def test_continuous_parameter_hpo():
23+
cont_param = ContinuousParameter(0.1, 1e-2)
24+
hpo = cont_param.as_hpo_range('some')
25+
assert len(hpo.keys()) == 4
26+
assert hpo['Name'] == 'some'
27+
assert hpo['Type'] == cont_param.__name__
28+
assert hpo['MinValue'] == '0.1'
29+
assert hpo['MaxValue'] == '0.01' # min < max verification is performed by HPO
30+
31+
32+
def test_integer_parameter():
33+
int_param = IntegerParameter(1, 2)
34+
assert isinstance(int_param, _HpoParameter)
35+
assert int_param.__name__ is 'Integer'
36+
37+
38+
def test_integer_parameter_hpo():
39+
int_param = IntegerParameter(1, 2)
40+
hpo = int_param.as_hpo_range('some')
41+
assert len(hpo.keys()) == 4
42+
assert hpo['Name'] == 'some'
43+
assert hpo['Type'] == int_param.__name__
44+
assert hpo['MinValue'] == '1'
45+
assert hpo['MaxValue'] == '2'
46+
47+
48+
def test_categorical_parameter_list():
49+
cat_param = CategoricalParameter(['a', 'z'])
50+
assert isinstance(cat_param, _HpoParameter)
51+
assert cat_param.__name__ is 'Categorical'
52+
53+
54+
def test_categorical_parameter__list_hpo():
55+
cat_param = CategoricalParameter([1, 10])
56+
hpo = cat_param.as_hpo_range('some')
57+
assert len(hpo.keys()) == 3
58+
assert hpo['Name'] == 'some'
59+
assert hpo['Type'] == cat_param.__name__
60+
assert hpo['Values'] == ['1', '10']
61+
62+
63+
def test_categorical_parameter_value():
64+
cat_param = CategoricalParameter('a')
65+
assert isinstance(cat_param, _HpoParameter)
66+
assert cat_param.__name__ is 'Categorical'
67+
68+
69+
def test_categorical_parameter_value_hpo():
70+
cat_param = CategoricalParameter('a')
71+
hpo = cat_param.as_hpo_range('some')
72+
assert len(hpo.keys()) == 3
73+
assert hpo['Name'] == 'some'
74+
assert hpo['Type'] == cat_param.__name__
75+
assert hpo['Values'] == ['a']

tests/unit/test_hyperparameter.py

Lines changed: 78 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,14 @@
1212
# language governing permissions and limitations under the License.
1313
import pytest
1414
from sagemaker.amazon.hyperparameter import Hyperparameter
15+
from sagemaker.hpo import ContinuousParameter, IntegerParameter, CategoricalParameter
1516

1617

1718
class Test(object):
1819

19-
blank = Hyperparameter(name="some-name", data_type=int)
20+
blank = Hyperparameter(name='some-name', data_type=int)
2021
elizabeth = Hyperparameter(name='elizabeth')
21-
validated = Hyperparameter(name="validated", validate=lambda value: value > 55, data_type=int)
22+
validated = Hyperparameter(name='validated', validate=lambda value: value > 55, data_type=int)
2223

2324

2425
def test_blank_access():
@@ -60,7 +61,7 @@ def test_validated():
6061
def test_data_type():
6162
x = Test()
6263
x.validated = 66
63-
assert type(x.validated) == Test.__dict__["validated"].data_type
64+
assert type(x.validated) == Test.__dict__['validated'].data_type
6465

6566

6667
def test_from_string():
@@ -72,3 +73,77 @@ def test_from_string():
7273

7374
x.validated = from_api
7475
assert x.validated == value
76+
77+
78+
def test_serialize_hpo():
79+
x = Test()
80+
81+
x.validated = ContinuousParameter(0, 5)
82+
x.elizabeth = IntegerParameter(0, 5)
83+
x.blank = CategoricalParameter([0, 5])
84+
85+
assert isinstance(x.validated, ContinuousParameter)
86+
assert isinstance(x.elizabeth, IntegerParameter)
87+
assert isinstance(x.blank, CategoricalParameter)
88+
89+
hpo_ranges = Hyperparameter.serialize_all_hpo(x)
90+
assert hpo_ranges[x.validated.__name__+'ParameterRange'][0]['Name'] == 'validated'
91+
assert hpo_ranges[x.elizabeth.__name__+'ParameterRange'][0]['Name'] == 'elizabeth'
92+
assert hpo_ranges[x.blank.__name__+'ParameterRange'][0]['Name'] == 'some-name'
93+
94+
95+
def test_assign_to_hpo():
96+
x = Test()
97+
x.validated = 66
98+
99+
assert Hyperparameter.serialize_all(x)['validated']
100+
hpo_ranges = Hyperparameter.serialize_all_hpo(x)
101+
assert len(hpo_ranges.keys()) == 3
102+
assert [len(hpo_ranges[r]) for r in hpo_ranges.keys()] == [0, 0, 0]
103+
104+
x.validated = ContinuousParameter(0, 5)
105+
assert isinstance(x.validated, ContinuousParameter)
106+
assert Hyperparameter.serialize_all(x) == {}
107+
hpo_ranges = Hyperparameter.serialize_all_hpo(x)
108+
assert len(hpo_ranges) == 3
109+
assert hpo_ranges[x.validated.__name__+'ParameterRange'][0]['Name'] == 'validated'
110+
111+
# hp is preserved, necessary when attaching to a training job
112+
assert int == Test.__dict__['validated'].data_type
113+
114+
115+
def test_assign_from_hpo():
116+
x = Test()
117+
118+
x.validated = IntegerParameter(0, 5)
119+
x.validated = 67
120+
121+
assert Hyperparameter.serialize_all(x)['validated']
122+
hpo_ranges = Hyperparameter.serialize_all_hpo(x)
123+
assert len(hpo_ranges.keys()) == 3
124+
assert [len(hpo_ranges[r]) for r in hpo_ranges.keys()] == [0, 0, 0]
125+
126+
assert type(x.validated) == Test.__dict__['validated'].data_type
127+
128+
129+
def test_hpo_assign_none():
130+
x = Test()
131+
132+
x.validated = ContinuousParameter(0, 5)
133+
x.validated = None
134+
135+
assert Hyperparameter.serialize_all(x) == {}
136+
hpo_ranges = Hyperparameter.serialize_all_hpo(x)
137+
assert len(hpo_ranges.keys()) == 3
138+
assert [len(hpo_ranges[r]) for r in hpo_ranges.keys()] == [0, 0, 0]
139+
140+
assert int == Test.__dict__['validated'].data_type
141+
142+
143+
def test_delete_in_hpo():
144+
x = Test()
145+
x.validated = ContinuousParameter(0, 5)
146+
del(x.validated)
147+
148+
with pytest.raises(AttributeError):
149+
x.validated

0 commit comments

Comments
 (0)