Skip to content

Commit 453b6a8

Browse files
authored
Refactor EstimatorBase and Framework to have a prepare_for_training() method (aws#15)
* Refactor EstimatorBase and Framework to have a prepare_for_training() method
* Specify argument directly instead of using **kwargs
1 parent 6ee0d78 commit 453b6a8

File tree

2 files changed

+77
-92
lines changed

2 files changed

+77
-92
lines changed

src/sagemaker/estimator.py

Lines changed: 29 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,29 @@ def hyperparameters(self):
120120
"""
121121
pass
122122

123+
def prepare_for_training(self, job_name=None):
124+
"""Set any values in the estimator that need to be set before training.
125+
126+
Args:
127+
* job_name (str): Name of the training job to be created. If not specified, one is generated,
128+
using the base name given to the constructor if applicable.
129+
"""
130+
if job_name is not None:
131+
self._current_job_name = job_name
132+
else:
133+
# honor supplied base_job_name or generate it
134+
base_name = self.base_job_name or base_name_from_image(self.train_image())
135+
self._current_job_name = name_from_base(base_name)
136+
137+
# if output_path was specified we use it otherwise initialize here.
138+
# For Local Mode with local_code=True we don't need an explicit output_path
139+
if self.output_path is None:
140+
local_code = get_config_value('local.local_code', self.sagemaker_session.config)
141+
if self.sagemaker_session.local_mode and local_code:
142+
self.output_path = ''
143+
else:
144+
self.output_path = 's3://{}/'.format(self.sagemaker_session.default_bucket())
145+
123146
def fit(self, inputs, wait=True, logs=True, job_name=None):
124147
"""Train a model using the input training dataset.
125148
@@ -148,22 +171,7 @@ def fit(self, inputs, wait=True, logs=True, job_name=None):
148171
job_name (str): Training job name. If not specified, the estimator generates a default job name,
149172
based on the training image name and current timestamp.
150173
"""
151-
152-
if job_name is not None:
153-
self._current_job_name = job_name
154-
else:
155-
# make sure the job name is unique for each invocation, honor supplied base_job_name or generate it
156-
base_name = self.base_job_name or base_name_from_image(self.train_image())
157-
self._current_job_name = name_from_base(base_name)
158-
159-
# if output_path was specified we use it otherwise initialize here.
160-
# For Local Mode with local_code=True we don't need an explicit output_path
161-
if self.output_path is None:
162-
local_code = get_config_value('local.local_code', self.sagemaker_session.config)
163-
if self.sagemaker_session.local_mode and local_code:
164-
self.output_path = ''
165-
else:
166-
self.output_path = 's3://{}/'.format(self.sagemaker_session.default_bucket())
174+
self.prepare_for_training(job_name=job_name)
167175

168176
self.latest_training_job = _TrainingJob.start_new(self, inputs)
169177
if wait:
@@ -505,39 +513,14 @@ def __init__(self, entry_point, source_dir=None, hyperparameters=None, enable_cl
505513
self._hyperparameters = hyperparameters or {}
506514
self.code_location = code_location
507515

508-
def fit(self, inputs, wait=True, logs=True, job_name=None):
509-
"""Train a model using the input training dataset.
510-
511-
The API calls the Amazon SageMaker CreateTrainingJob API to start model training.
512-
The API uses configuration you provided to create the estimator and the
513-
specified input training data to send the CreateTrainingJob request to Amazon SageMaker.
514-
515-
This is a synchronous operation. After the model training successfully completes,
516-
you can call the ``deploy()`` method to host the model using the Amazon SageMaker hosting services.
516+
def prepare_for_training(self, job_name=None):
517+
"""Set hyperparameters needed for training. This method will also validate ``source_dir``.
517518
518519
Args:
519-
inputs (str or dict or sagemaker.session.s3_input): Information about the training data.
520-
This can be one of three types:
521-
(str) - the S3 location where training data is saved.
522-
(dict[str, str] or dict[str, sagemaker.session.s3_input]) - If using multiple channels for
523-
training data, you can specify a dict mapping channel names
524-
to strings or :func:`~sagemaker.session.s3_input` objects.
525-
(sagemaker.session.s3_input) - channel configuration for S3 data sources that can provide
526-
additional information about the training dataset. See :func:`sagemaker.session.s3_input`
527-
for full details.
528-
wait (bool): Whether the call should wait until the job completes (default: True).
529-
logs (bool): Whether to show the logs produced by the job.
530-
Only meaningful when wait is True (default: True).
531-
job_name (str): Training job name. If not specified, the estimator generates a default job name,
532-
based on the training image name and current timestamp.
520+
* job_name (str): Name of the training job to be created. If not specified, one is generated,
521+
using the base name given to the constructor if applicable.
533522
"""
534-
# always determine new job name _here_ because it is used before base is called
535-
if job_name is not None:
536-
self._current_job_name = job_name
537-
else:
538-
# honor supplied base_job_name or generate it
539-
base_name = self.base_job_name or base_name_from_image(self.train_image())
540-
self._current_job_name = name_from_base(base_name)
523+
super(Framework, self).prepare_for_training(job_name=job_name)
541524

542525
# validate source dir will raise a ValueError if there is something wrong with the
543526
# source directory. We are intentionally not handling it because this is a critical error.
@@ -567,7 +550,6 @@ def fit(self, inputs, wait=True, logs=True, job_name=None):
567550
self._hyperparameters[CONTAINER_LOG_LEVEL_PARAM_NAME] = self.container_log_level
568551
self._hyperparameters[JOB_NAME_PARAM_NAME] = self._current_job_name
569552
self._hyperparameters[SAGEMAKER_REGION_PARAM_NAME] = self.sagemaker_session.boto_region_name
570-
super(Framework, self).fit(inputs, wait, logs, self._current_job_name)
571553

572554
def _stage_user_code_in_s3(self):
573555
""" Upload the user training script to s3 and return the location.

tests/unit/test_estimator.py

Lines changed: 48 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import logging
1616
import json
1717
import os
18+
1819
import pytest
1920
from mock import Mock, patch
2021

@@ -39,20 +40,22 @@
3940
REGION = 'us-west-2'
4041
JOB_NAME = '{}-{}'.format(IMAGE_NAME, TIMESTAMP)
4142

42-
COMMON_TRAIN_ARGS = {'volume_size': 30,
43-
'hyperparameters': {
44-
'sagemaker_program': 'dummy_script.py',
45-
'sagemaker_enable_cloudwatch_metrics': False,
46-
'sagemaker_container_log_level': logging.INFO,
47-
},
48-
'input_mode': 'File',
49-
'instance_type': 'c4.4xlarge',
50-
'inputs': 's3://mybucket/train',
51-
'instance_count': 1,
52-
'role': 'DummyRole',
53-
'kms_key_id': None,
54-
'max_run': 24,
55-
'wait': True}
43+
COMMON_TRAIN_ARGS = {
44+
'volume_size': 30,
45+
'hyperparameters': {
46+
'sagemaker_program': 'dummy_script.py',
47+
'sagemaker_enable_cloudwatch_metrics': False,
48+
'sagemaker_container_log_level': logging.INFO,
49+
},
50+
'input_mode': 'File',
51+
'instance_type': 'c4.4xlarge',
52+
'inputs': 's3://mybucket/train',
53+
'instance_count': 1,
54+
'role': 'DummyRole',
55+
'kms_key_id': None,
56+
'max_run': 24,
57+
'wait': True,
58+
}
5659

5760
DESCRIBE_TRAINING_JOB_RESULT = {
5861
'ModelArtifacts': {
@@ -275,19 +278,6 @@ def test_attach_framework(sagemaker_session):
275278
assert framework_estimator.entry_point == 'iris-dnn-classifier.py'
276279

277280

278-
def test_fit_then_fit_again(sagemaker_session):
279-
fw = DummyFramework(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session,
280-
train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE,
281-
enable_cloudwatch_metrics=True)
282-
fw.fit(inputs=s3_input('s3://mybucket/train'))
283-
first_job_name = fw.latest_training_job.name
284-
285-
fw.fit(inputs=s3_input('s3://mybucket/train2'))
286-
second_job_name = fw.latest_training_job.name
287-
288-
assert first_job_name != second_job_name
289-
290-
291281
@patch('time.strftime', return_value=TIMESTAMP)
292282
def test_fit_verify_job_name(strftime, sagemaker_session):
293283
fw = DummyFramework(entry_point=SCRIPT_PATH, role='DummyRole', sagemaker_session=sagemaker_session,
@@ -304,42 +294,55 @@ def test_fit_verify_job_name(strftime, sagemaker_session):
304294
assert fw.latest_training_job.name == JOB_NAME
305295

306296

307-
def test_fit_force_name(sagemaker_session):
297+
def test_prepare_for_training_unique_job_name_generation(sagemaker_session):
298+
fw = DummyFramework(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session,
299+
train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE,
300+
enable_cloudwatch_metrics=True)
301+
fw.prepare_for_training()
302+
first_job_name = fw._current_job_name
303+
304+
fw.prepare_for_training()
305+
second_job_name = fw._current_job_name
306+
307+
assert first_job_name != second_job_name
308+
309+
310+
def test_prepare_for_training_force_name(sagemaker_session):
308311
fw = DummyFramework(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session,
309312
train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE,
310313
base_job_name='some', enable_cloudwatch_metrics=True)
311-
fw.fit(inputs=s3_input('s3://mybucket/train'), job_name='use_it')
312-
assert 'use_it' == fw.latest_training_job.name
314+
fw.prepare_for_training(job_name='use_it')
315+
assert 'use_it' == fw._current_job_name
313316

314317

315318
@patch('time.strftime', return_value=TIMESTAMP)
316-
def test_fit_force_generation(strftime, sagemaker_session):
319+
def test_prepare_for_training_force_name_generation(strftime, sagemaker_session):
317320
fw = DummyFramework(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session,
318321
train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE,
319322
base_job_name='some', enable_cloudwatch_metrics=True)
320323
fw.base_job_name = None
321-
fw.fit(inputs=s3_input('s3://mybucket/train'))
322-
assert JOB_NAME == fw.latest_training_job.name
324+
fw.prepare_for_training()
325+
assert JOB_NAME == fw._current_job_name
323326

324327

325328
@patch('time.strftime', return_value=TIMESTAMP)
326329
def test_init_with_source_dir_s3(strftime, sagemaker_session):
327-
uri = 'bucket/mydata'
328-
329330
fw = DummyFramework(entry_point=SCRIPT_PATH, source_dir='s3://location', role=ROLE,
330331
sagemaker_session=sagemaker_session,
331332
train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE,
332333
enable_cloudwatch_metrics=False)
333-
fw.fit('s3://{}'.format(uri))
334-
335-
expected_hyperparameters = BASE_HP.copy()
336-
expected_hyperparameters['sagemaker_enable_cloudwatch_metrics'] = 'false'
337-
expected_hyperparameters['sagemaker_container_log_level'] = str(logging.INFO)
338-
expected_hyperparameters['sagemaker_submit_directory'] = json.dumps("s3://location")
339-
expected_hyperparameters['sagemaker_region'] = '"us-west-2"'
340-
341-
actual_hyperparameter = sagemaker_session.method_calls[1][2]['hyperparameters']
342-
assert actual_hyperparameter == expected_hyperparameters
334+
fw.prepare_for_training()
335+
336+
expected_hyperparameters = {
337+
'sagemaker_program': SCRIPT_NAME,
338+
'sagemaker_submit_directory': 's3://mybucket/{}/source/sourcedir.tar.gz'.format(JOB_NAME),
339+
'sagemaker_job_name': JOB_NAME,
340+
'sagemaker_enable_cloudwatch_metrics': False,
341+
'sagemaker_container_log_level': logging.INFO,
342+
'sagemaker_submit_directory': 's3://location',
343+
'sagemaker_region': 'us-west-2',
344+
}
345+
assert fw._hyperparameters == expected_hyperparameters
343346

344347

345348
# _TrainingJob 'utils'

0 commit comments

Comments
 (0)