Skip to content

Commit 00a7a0b

Browse files
authored
fix: change model_dir to training job name if it is for tuning. (#179)
* Append the training job name to model_dir if the training job is for tuning.
1 parent c286f01 commit 00a7a0b

File tree

2 files changed

+37
-0
lines changed

2 files changed

+37
-0
lines changed

src/sagemaker_tensorflow_container/training.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,12 @@ def main():
192192
"""
193193
hyperparameters = framework.env.read_hyperparameters()
194194
env = framework.training_env(hyperparameters=hyperparameters)
195+
196+
# If the training job is one of multiple training jobs launched for tuning, we need to append the training job name to
197+
# model_dir so that the jobs do not read from/write to the same object
198+
if '_tuning_objective_metric' in hyperparameters:
199+
env.hyperparameters['model_dir'] = os.path.join(hyperparameters.get('model_dir'), env.job_name, 'checkpoints')
200+
195201
s3_utils.configure(env.hyperparameters.get('model_dir'), os.environ.get('SAGEMAKER_REGION'))
196202
logger.setLevel(env.log_level)
197203
train(env)

test/unit/test_training.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ def simple_training_env():
7070
env.hosts = CURRENT_HOST
7171
env.current_host = CURRENT_HOST
7272
env.to_env_vars = lambda: {}
73+
env.job_name = 'test-training-job'
7374
return env
7475

7576

@@ -252,3 +253,33 @@ def test_main(configure_s3_env, read_hyperparameters, training_env,
252253
training_env.assert_called_once_with(hyperparameters={})
253254
train.assert_called_once_with(single_machine_training_env)
254255
configure_s3_env.assert_called_once()
256+
257+
258+
@patch('sagemaker_tensorflow_container.training.logger')
259+
@patch('sagemaker_tensorflow_container.training.train')
260+
@patch('logging.Logger.setLevel')
261+
@patch('sagemaker_containers.beta.framework.training_env')
262+
@patch('sagemaker_containers.beta.framework.env.read_hyperparameters', return_value={'model_dir': MODEL_DIR})
263+
@patch('sagemaker_tensorflow_container.s3_utils.configure')
264+
def test_main_simple_training_model_dir(configure_s3_env, read_hyperparameters, training_env,
265+
set_level, train, logger, single_machine_training_env):
266+
training_env.return_value = single_machine_training_env
267+
os.environ['SAGEMAKER_REGION'] = REGION
268+
training.main()
269+
configure_s3_env.assert_called_once_with(MODEL_DIR, REGION)
270+
271+
272+
@patch('sagemaker_tensorflow_container.training.logger')
273+
@patch('sagemaker_tensorflow_container.training.train')
274+
@patch('logging.Logger.setLevel')
275+
@patch('sagemaker_containers.beta.framework.training_env')
276+
@patch('sagemaker_containers.beta.framework.env.read_hyperparameters', return_value={'model_dir': MODEL_DIR,
277+
'_tuning_objective_metric': 'auc'})
278+
@patch('sagemaker_tensorflow_container.s3_utils.configure')
279+
def test_main_tunning_model_dir(configure_s3_env, read_hyperparameters, training_env,
280+
set_level, train, logger, single_machine_training_env):
281+
training_env.return_value = single_machine_training_env
282+
os.environ['SAGEMAKER_REGION'] = REGION
283+
training.main()
284+
expected_model_dir = os.path.join(MODEL_DIR, single_machine_training_env.job_name, 'checkpoints')
285+
configure_s3_env.assert_called_once_with(expected_model_dir, REGION)

0 commit comments

Comments
 (0)