Skip to content

Commit 8c68620

Browse files
authored
Fix tuning job name generation (aws#35)
1 parent 7d72221 commit 8c68620

File tree

3 files changed

+33
-13
lines changed

3 files changed

+33
-13
lines changed

src/sagemaker/tuner.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,8 @@ def _validate_parameter_ranges(self):
189189

190190

191191
class _TuningJob(_Job):
192+
TUNING_JOB_NAME_MAX_LENGTH = 32
193+
192194
def __init__(self, sagemaker_session, tuning_job_name):
193195
super(_TuningJob, self).__init__(sagemaker_session, tuning_job_name)
194196

@@ -205,12 +207,8 @@ def start_new(cls, tuner, inputs):
205207
"""
206208
config = _Job._load_config(inputs, tuner.estimator)
207209

208-
base_name = tuner.estimator.base_job_name or base_name_from_image(tuner.estimator.train_image())
209-
tuning_job_name = name_from_base(base_name)
210-
211-
# TODO: Update name generation so that the base name isn't limited to so few characters
212-
if len(tuning_job_name) > 32:
213-
raise ValueError('Tuning job name too long - must be 32 characters or fewer: {}'.format(tuning_job_name))
210+
base_name = tuner.base_tuning_job_name or base_name_from_image(tuner.estimator.train_image())
211+
tuning_job_name = name_from_base(base_name, max_length=cls.TUNING_JOB_NAME_MAX_LENGTH, short=True)
214212

215213
tuner.estimator.sagemaker_session.tune(job_name=tuning_job_name, strategy=tuner.strategy,
216214
objective_type=tuner.objective_type,

src/sagemaker/utils.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,20 +31,21 @@ def name_from_image(image):
3131
return name_from_base(base_name_from_image(image))
3232

3333

34-
def name_from_base(base):
34+
def name_from_base(base, max_length=63, short=False):
3535
"""Append a timestamp to the provided string.
3636
37-
The appended timestamp is precise to the millisecond. This function assures that the total length of the resulting
38-
string is not longer that 63, trimming the input parameter if necessary.
37+
This function assures that the total length of the resulting string is not
38+
longer than the specified max length, trimming the input parameter if necessary.
3939
4040
Args:
4141
base (str): String used as prefix to generate the unique name.
42+
max_length (int): Maximum length for the resulting string.
43+
short (bool): Whether or not to use a truncated timestamp.
4244
4345
Returns:
44-
str: Input parameter with appended timestamp (no longer than 63 characters).
46+
str: Input parameter with appended timestamp.
4547
"""
46-
max_length = 63
47-
timestamp = sagemaker_timestamp()
48+
timestamp = sagemaker_short_timestamp() if short else sagemaker_timestamp()
4849
trimmed_base = base[:max_length - len(timestamp) - 1]
4950
return '{}-{}'.format(trimmed_base, timestamp)
5051

@@ -70,6 +71,11 @@ def sagemaker_timestamp():
7071
return time.strftime("%Y-%m-%d-%H-%M-%S-{}".format(moment_ms), time.gmtime(moment))
7172

7273

74+
def sagemaker_short_timestamp():
75+
"""Return a timestamp that is relatively short in length"""
76+
return time.strftime('%y%m%d-%H%M')
77+
78+
7379
def debug(func):
7480
"""Print the function name and arguments for debugging."""
7581
@wraps(func)

tests/unit/test_utils.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,11 @@
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
1414

15-
from sagemaker.utils import get_config_value
15+
from mock import patch
16+
17+
from sagemaker.utils import get_config_value, name_from_base
18+
19+
NAME = 'base_name'
1620

1721

1822
def test_get_config_value():
@@ -32,3 +36,15 @@ def test_get_config_value():
3236

3337
assert get_config_value('does_not.exist', config) is None
3438
assert get_config_value('other.key', None) is None
39+
40+
41+
@patch('sagemaker.utils.sagemaker_timestamp')
42+
def test_name_from_base(sagemaker_timestamp):
43+
name_from_base(NAME, short=False)
44+
assert sagemaker_timestamp.called_once
45+
46+
47+
@patch('sagemaker.utils.sagemaker_short_timestamp')
48+
def test_name_from_base_short(sagemaker_short_timestamp):
49+
name_from_base(NAME, short=True)
50+
assert sagemaker_short_timestamp.called_once

0 commit comments

Comments
 (0)