Skip to content

Commit ca0ee1b

Browse files
icywang86ruiEliza Zhang
authored andcommitted
Set S3 environment variables (aws#112)
* Setting S3 environment variables before training starts * Remove S3 environment variable setting in test training script * Add unit tests
1 parent 0c96260 commit ca0ee1b

File tree

5 files changed

+53
-7
lines changed

5 files changed

+53
-7
lines changed

src/sagemaker_tensorflow_container/s3_utils.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
def configure(model_dir, job_region):
2222

23+
<<<<<<< HEAD
2324
os.environ['S3_REGION'] = _s3_region(job_region, model_dir)
2425

2526
# setting log level to WARNING
@@ -41,3 +42,24 @@ def _s3_region(job_region, model_dir):
4142
return bucket_location or job_region
4243
else:
4344
return job_region
45+
=======
46+
if not model_dir:
47+
return
48+
49+
s3 = boto3.client('s3', region_name=job_region)
50+
51+
# We get the AWS region of the checkpoint bucket, which may be different from
52+
# the region this container is currently running in.
53+
parsed_url = urlparse(model_dir)
54+
bucket_name = parsed_url.netloc
55+
56+
bucket_location = s3.get_bucket_location(Bucket=bucket_name)['LocationConstraint']
57+
58+
# Configure environment variables used by TensorFlow S3 file system
59+
if bucket_location:
60+
os.environ['S3_REGION'] = bucket_location
61+
62+
# setting log level to WARNING
63+
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
64+
os.environ['S3_USE_HTTPS'] = '1'
65+
>>>>>>> Set S3 environment variables (#112)

src/sagemaker_tensorflow_container/training.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,14 @@
1515
import json
1616
import logging
1717
<<<<<<< HEAD
18+
<<<<<<< HEAD
1819
import multiprocessing
1920
import os
2021
=======
2122
>>>>>>> Add distributed training support (#98)
23+
=======
24+
import os
25+
>>>>>>> Set S3 environment variables (#112)
2226
import subprocess
2327
import time
2428

@@ -27,6 +31,8 @@
2731

2832
from sagemaker_tensorflow_container import s3_utils
2933

34+
import sagemaker_tensorflow_container.s3_utils as s3_utils
35+
3036

3137
logger = logging.getLogger(__name__)
3238

@@ -305,6 +311,7 @@ def main():
305311
"""
306312
hyperparameters = framework.env.read_hyperparameters()
307313
env = framework.training_env(hyperparameters=hyperparameters)
314+
<<<<<<< HEAD
308315

309316
user_hyperparameters = env.hyperparameters
310317

@@ -318,3 +325,8 @@ def main():
318325
s3_utils.configure(user_hyperparameters.get('model_dir'), os.environ.get('SAGEMAKER_REGION'))
319326
train(env, framework.mapping.to_cmd_args(user_hyperparameters))
320327
_log_model_missing_warning(MODEL_DIR)
328+
=======
329+
s3_utils.configure(env.hyperparameters.get('model_dir'), os.environ.get('SAGEMAKER_REGION'))
330+
logger.setLevel(env.log_level)
331+
train(env)
332+
>>>>>>> Set S3 environment variables (#112)

test/resources/mnist/distributed_mnist.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -136,11 +136,6 @@ def _parse_args():
136136
tf_logger = tf_logging._get_logger()
137137
tf_logger.handlers = [_handler]
138138

139-
if args.checkpoint_path.startswith('s3://'):
140-
os.environ['S3_REGION'] = 'us-west-2'
141-
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '0'
142-
os.environ['S3_USE_HTTPS'] = '1'
143-
144139
train_data, train_labels = _load_training_data(args.train)
145140
eval_data, eval_labels = _load_testing_data(args.train)
146141

test/unit/test_s3_utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ def test_configure(client):
3232
client.return_value = s3
3333
loc = {'LocationConstraint': BUCKET_REGION}
3434
s3.get_bucket_location.return_value = loc
35+
<<<<<<< HEAD
3536

3637
s3_utils.configure(MODEL_DIR, JOB_REGION)
3738

@@ -46,3 +47,9 @@ def test_configure_local_dir():
4647
assert os.environ['S3_REGION'] == JOB_REGION
4748
assert os.environ['TF_CPP_MIN_LOG_LEVEL'] == '1'
4849
assert os.environ['S3_USE_HTTPS'] == '1'
50+
=======
51+
s3_utils.configure(MODEL_DIR, JOB_REGION)
52+
assert os.environ['S3_REGION'] == BUCKET_REGION
53+
assert os.environ['TF_CPP_MIN_LOG_LEVEL'] == '1'
54+
assert os.environ['S3_USE_HTTPS'] == '1'
55+
>>>>>>> Set S3 environment variables (#112)

test/unit/test_training.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,7 @@ def test_main_tuning_mpi_model_dir(configure_s3_env, read_hyperparameters, train
319319
>>>>>>> Add tox.ini and configure coverage and flake runs (#80)
320320
=======
321321
import json
322+
import os
322323

323324
>>>>>>> Add distributed training support (#98)
324325
from mock import MagicMock, patch
@@ -343,6 +344,8 @@ def test_main_tuning_mpi_model_dir(configure_s3_env, read_hyperparameters, train
343344
WORKER_TASK = {'index': 0, 'type': 'worker'}
344345
PS_TASK_1 = {'index': 0, 'type': 'ps'}
345346
PS_TASK_2 = {'index': 1, 'type': 'ps'}
347+
MODEL_DIR = 's3://bucket/prefix'
348+
REGION = 'us-west-2'
346349

347350

348351
@pytest.fixture
@@ -368,7 +371,7 @@ def single_machine_training_env():
368371

369372
env.module_dir = MODULE_DIR
370373
env.module_name = MODULE_NAME
371-
env.hyperparameters = {}
374+
env.hyperparameters = {'model_dir': MODEL_DIR}
372375
env.log_level = LOG_LEVEL
373376

374377
return env
@@ -505,11 +508,18 @@ def test_build_tf_config_error():
505508
@patch('logging.Logger.setLevel')
506509
@patch('sagemaker_containers.beta.framework.training_env')
507510
@patch('sagemaker_containers.beta.framework.env.read_hyperparameters', return_value={})
508-
def test_main(read_hyperparameters, training_env, set_level, train, single_machine_training_env):
511+
@patch('sagemaker_tensorflow_container.s3_utils.configure')
512+
def test_main(configure_s3_env, read_hyperparameters, training_env,
513+
set_level, train, single_machine_training_env):
509514
training_env.return_value = single_machine_training_env
515+
os.environ['SAGEMAKER_REGION'] = REGION
510516
training.main()
511517
read_hyperparameters.assert_called_once_with()
512518
training_env.assert_called_once_with(hyperparameters={})
513519
set_level.assert_called_once_with(LOG_LEVEL)
514520
train.assert_called_once_with(single_machine_training_env)
521+
<<<<<<< HEAD
515522
>>>>>>> Add tox.ini and configure coverage and flake runs (#80)
523+
=======
524+
configure_s3_env.assert_called_once()
525+
>>>>>>> Set S3 environment variables (#112)

0 commit comments

Comments
 (0)