Skip to content

Commit ce47c76

Browse files
authored
Fix model_dir adjustment for hyperparameter tuning jobs (#181)
1 parent 00a7a0b commit ce47c76

File tree

2 files changed

+51
-31
lines changed

2 files changed

+51
-31
lines changed

src/sagemaker_tensorflow_container/training.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
1+
# Copyright 2018-2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
22
#
33
# Licensed under the Apache License, Version 2.0 (the 'License'). You
44
# may not use this file except in compliance with the License. A copy of
@@ -10,7 +10,6 @@
1010
# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
13-
1413
from __future__ import absolute_import
1514

1615
import json
@@ -106,11 +105,11 @@ def _run_ps(env, cluster):
106105
threading.Thread(target=lambda: server.join()).start()
107106

108107

109-
def _run_worker(env, tf_config):
108+
def _run_worker(env, cmd_args, tf_config):
110109
env_vars = env.to_env_vars()
111110
env_vars['TF_CONFIG'] = json.dumps(tf_config)
112111

113-
framework.entry_point.run(env.module_dir, env.user_entry_point, env.to_cmd_args(), env_vars)
112+
framework.entry_point.run(env.module_dir, env.user_entry_point, cmd_args, env_vars)
114113

115114

116115
def _wait_until_master_is_down(master):
@@ -125,7 +124,7 @@ def _wait_until_master_is_down(master):
125124
return
126125

127126

128-
def train(env):
127+
def train(env, cmd_args):
129128
"""Get training job environment from env and run the training job.
130129
131130
Args:
@@ -141,7 +140,7 @@ def train(env):
141140
logger.info('Launching parameter server process')
142141
_run_ps(env, tf_config['cluster'])
143142
logger.info('Launching worker process')
144-
_run_worker(env, tf_config)
143+
_run_worker(env, cmd_args, tf_config)
145144

146145
if not _is_host_master(env.hosts, env.current_host):
147146
_wait_until_master_is_down(env.hosts[0])
@@ -155,8 +154,7 @@ def train(env):
155154
else:
156155
runner_type = framework.runner.ProcessRunnerType
157156

158-
framework.entry_point.run(env.module_dir, env.user_entry_point,
159-
env.to_cmd_args(), env.to_env_vars(),
157+
framework.entry_point.run(env.module_dir, env.user_entry_point, cmd_args, env.to_env_vars(),
160158
runner=runner_type)
161159

162160

@@ -187,18 +185,28 @@ def _log_model_missing_warning(model_dir):
187185
'https://www.tensorflow.org/guide/saved_model#structure_of_a_savedmodel_directory')
188186

189187

188+
def _model_dir_with_training_job(model_dir, job_name):
189+
if model_dir.startswith('/opt/ml'):
190+
return model_dir
191+
else:
192+
return '{}/{}/model'.format(model_dir, job_name)
193+
194+
190195
def main():
191196
"""Training entry point
192197
"""
193198
hyperparameters = framework.env.read_hyperparameters()
194199
env = framework.training_env(hyperparameters=hyperparameters)
195200

201+
user_hyperparameters = env.hyperparameters
202+
196203
# If the training job is part of the multiple training jobs for tuning, we need to append the training job name to
197204
# model_dir in case they read from/write to the same object
198205
if '_tuning_objective_metric' in hyperparameters:
199-
env.hyperparameters['model_dir'] = os.path.join(hyperparameters.get('model_dir'), env.job_name, 'checkpoints')
206+
model_dir = _model_dir_with_training_job(hyperparameters.get('model_dir'), env.job_name)
207+
logger.info('Appending the training job name to model_dir: {}'.format(model_dir))
208+
user_hyperparameters['model_dir'] = model_dir
200209

201-
s3_utils.configure(env.hyperparameters.get('model_dir'), os.environ.get('SAGEMAKER_REGION'))
202-
logger.setLevel(env.log_level)
203-
train(env)
210+
s3_utils.configure(user_hyperparameters.get('model_dir'), os.environ.get('SAGEMAKER_REGION'))
211+
train(env, framework.mapping.to_cmd_args(user_hyperparameters))
204212
_log_model_missing_warning(MODEL_DIR)

test/unit/test_training.py

Lines changed: 31 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2017-2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
1+
# Copyright 2017-2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License"). You
44
# may not use this file except in compliance with the License. A copy of
@@ -40,6 +40,7 @@
4040
PS_TASK_1 = {'index': 0, 'type': 'ps'}
4141
PS_TASK_2 = {'index': 1, 'type': 'ps'}
4242
MODEL_DIR = 's3://bucket/prefix'
43+
MODEL_DIR_CMD_LIST = ['--model_dir', MODEL_DIR]
4344
REGION = 'us-west-2'
4445
RESOURCE_PATH = os.path.join(os.path.dirname(__file__), '..', 'resources')
4546

@@ -82,9 +83,8 @@ def test_is_host_master():
8283

8384
@patch('sagemaker_containers.beta.framework.entry_point.run')
8485
def test_single_machine(run_module, single_machine_training_env):
85-
training.train(single_machine_training_env)
86-
run_module.assert_called_with(MODULE_DIR, MODULE_NAME,
87-
single_machine_training_env.to_cmd_args(),
86+
training.train(single_machine_training_env, MODEL_DIR_CMD_LIST)
87+
run_module.assert_called_with(MODULE_DIR, MODULE_NAME, MODEL_DIR_CMD_LIST,
8888
single_machine_training_env.to_env_vars(),
8989
runner=runner.ProcessRunnerType)
9090

@@ -93,9 +93,8 @@ def test_single_machine(run_module, single_machine_training_env):
9393
def test_train_horovod(run_module, single_machine_training_env):
9494
single_machine_training_env.additional_framework_parameters['sagemaker_mpi_enabled'] = True
9595

96-
training.train(single_machine_training_env)
97-
run_module.assert_called_with(MODULE_DIR, MODULE_NAME,
98-
single_machine_training_env.to_cmd_args(),
96+
training.train(single_machine_training_env, MODEL_DIR_CMD_LIST)
97+
run_module.assert_called_with(MODULE_DIR, MODULE_NAME, MODEL_DIR_CMD_LIST,
9998
single_machine_training_env.to_env_vars(),
10099
runner=runner.MPIRunnerType)
101100

@@ -108,7 +107,7 @@ def test_train_horovod(run_module, single_machine_training_env):
108107
@patch('threading.Thread', lambda target: target())
109108
@patch('time.sleep', MagicMock())
110109
def test_train_distributed_master(run, tf_server, cluster_spec, distributed_training_env):
111-
training.train(distributed_training_env)
110+
training.train(distributed_training_env, MODEL_DIR_CMD_LIST)
112111

113112
cluster_spec.assert_called_with({'worker': ['host2:2222'],
114113
'master': ['host1:2222'],
@@ -126,8 +125,7 @@ def test_train_distributed_master(run, tf_server, cluster_spec, distributed_trai
126125
'"environment": "cloud", ' \
127126
'"task": {"index": 0, "type": "master"}}'
128127

129-
run.assert_called_with('s3://my/bucket', 'script_name',
130-
distributed_training_env.to_cmd_args(),
128+
run.assert_called_with('s3://my/bucket', 'script_name', MODEL_DIR_CMD_LIST,
131129
{'TF_CONFIG': tf_config})
132130

133131

@@ -140,7 +138,7 @@ def test_train_distributed_master(run, tf_server, cluster_spec, distributed_trai
140138
def test_train_distributed_worker(run, tf_server, cluster_spec, distributed_training_env):
141139
distributed_training_env.current_host = HOST2
142140

143-
training.train(distributed_training_env)
141+
training.train(distributed_training_env, MODEL_DIR_CMD_LIST)
144142

145143
cluster_spec.assert_called_with({'worker': ['host2:2222'],
146144
'master': ['host1:2222'],
@@ -158,8 +156,7 @@ def test_train_distributed_worker(run, tf_server, cluster_spec, distributed_trai
158156
'"environment": "cloud", ' \
159157
'"task": {"index": 0, "type": "worker"}}'
160158

161-
run.assert_called_with('s3://my/bucket', 'script_name',
162-
distributed_training_env.to_cmd_args(),
159+
run.assert_called_with('s3://my/bucket', 'script_name', MODEL_DIR_CMD_LIST,
163160
{'TF_CONFIG': tf_config})
164161

165162

@@ -168,9 +165,9 @@ def test_train_distributed_no_ps(run, distributed_training_env):
168165
distributed_training_env.additional_framework_parameters[
169166
training.SAGEMAKER_PARAMETER_SERVER_ENABLED] = False
170167
distributed_training_env.current_host = HOST2
171-
training.train(distributed_training_env)
168+
training.train(distributed_training_env, MODEL_DIR_CMD_LIST)
172169

173-
run.assert_called_with(MODULE_DIR, MODULE_NAME, distributed_training_env.to_cmd_args(),
170+
run.assert_called_with(MODULE_DIR, MODULE_NAME, MODEL_DIR_CMD_LIST,
174171
distributed_training_env.to_env_vars(), runner=runner.ProcessRunnerType)
175172

176173

@@ -251,7 +248,7 @@ def test_main(configure_s3_env, read_hyperparameters, training_env,
251248
training.main()
252249
read_hyperparameters.assert_called_once_with()
253250
training_env.assert_called_once_with(hyperparameters={})
254-
train.assert_called_once_with(single_machine_training_env)
251+
train.assert_called_once_with(single_machine_training_env, MODEL_DIR_CMD_LIST)
255252
configure_s3_env.assert_called_once()
256253

257254

@@ -276,10 +273,25 @@ def test_main_simple_training_model_dir(configure_s3_env, read_hyperparameters,
276273
@patch('sagemaker_containers.beta.framework.env.read_hyperparameters', return_value={'model_dir': MODEL_DIR,
277274
'_tuning_objective_metric': 'auc'})
278275
@patch('sagemaker_tensorflow_container.s3_utils.configure')
279-
def test_main_tunning_model_dir(configure_s3_env, read_hyperparameters, training_env,
280-
set_level, train, logger, single_machine_training_env):
276+
def test_main_tuning_model_dir(configure_s3_env, read_hyperparameters, training_env,
277+
set_level, train, logger, single_machine_training_env):
281278
training_env.return_value = single_machine_training_env
282279
os.environ['SAGEMAKER_REGION'] = REGION
283280
training.main()
284-
expected_model_dir = os.path.join(MODEL_DIR, single_machine_training_env.job_name, 'checkpoints')
281+
expected_model_dir = '{}/{}/model'.format(MODEL_DIR, single_machine_training_env.job_name)
285282
configure_s3_env.assert_called_once_with(expected_model_dir, REGION)
283+
284+
285+
@patch('sagemaker_tensorflow_container.training.logger')
286+
@patch('sagemaker_tensorflow_container.training.train')
287+
@patch('logging.Logger.setLevel')
288+
@patch('sagemaker_containers.beta.framework.training_env')
289+
@patch('sagemaker_containers.beta.framework.env.read_hyperparameters', return_value={'model_dir': '/opt/ml/model',
290+
'_tuning_objective_metric': 'auc'})
291+
@patch('sagemaker_tensorflow_container.s3_utils.configure')
292+
def test_main_tuning_mpi_model_dir(configure_s3_env, read_hyperparameters, training_env,
293+
set_level, train, logger, single_machine_training_env):
294+
training_env.return_value = single_machine_training_env
295+
os.environ['SAGEMAKER_REGION'] = REGION
296+
training.main()
297+
configure_s3_env.assert_called_once_with('/opt/ml/model', REGION)

0 commit comments

Comments (0)