Skip to content

Commit a897135

Browse files
authored
Set S3 environment variables (#112)
* Setting S3 environment variables before training starts * Remove S3 environment variable setting in test training script * Add unit tests
1 parent 177773d commit a897135

File tree

5 files changed

+92
-7
lines changed

5 files changed

+92
-7
lines changed
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# Copyright 2017-2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the 'license' file accompanying this file. This file is
10+
# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import os
16+
17+
import boto3
18+
from six.moves.urllib.parse import urlparse
19+
20+
21+
def configure(model_dir, job_region):
22+
23+
if not model_dir:
24+
return
25+
26+
s3 = boto3.client('s3', region_name=job_region)
27+
28+
# We get the AWS region of the checkpoint bucket, which may be different from
29+
# the region this container is currently running in.
30+
parsed_url = urlparse(model_dir)
31+
bucket_name = parsed_url.netloc
32+
33+
bucket_location = s3.get_bucket_location(Bucket=bucket_name)['LocationConstraint']
34+
35+
# Configure environment variables used by TensorFlow S3 file system
36+
if bucket_location:
37+
os.environ['S3_REGION'] = bucket_location
38+
39+
# setting log level to WARNING
40+
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
41+
os.environ['S3_USE_HTTPS'] = '1'

src/sagemaker_tensorflow_container/training.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,14 @@
1515

1616
import json
1717
import logging
18+
import os
1819
import subprocess
1920
import time
2021

2122
import sagemaker_containers.beta.framework as framework
2223

24+
import sagemaker_tensorflow_container.s3_utils as s3_utils
25+
2326

2427
logger = logging.getLogger(__name__)
2528

@@ -151,5 +154,6 @@ def main():
151154
"""
152155
hyperparameters = framework.env.read_hyperparameters()
153156
env = framework.training_env(hyperparameters=hyperparameters)
157+
s3_utils.configure(env.hyperparameters.get('model_dir'), os.environ.get('SAGEMAKER_REGION'))
154158
logger.setLevel(env.log_level)
155159
train(env)

test/resources/mnist/distributed_mnist.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -136,11 +136,6 @@ def _parse_args():
136136
tf_logger = tf_logging._get_logger()
137137
tf_logger.handlers = [_handler]
138138

139-
if args.checkpoint_path.startswith('s3://'):
140-
os.environ['S3_REGION'] = 'us-west-2'
141-
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '0'
142-
os.environ['S3_USE_HTTPS'] = '1'
143-
144139
train_data, train_labels = _load_training_data(args.train)
145140
eval_data, eval_labels = _load_testing_data(args.train)
146141

test/unit/test_s3_utils.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# Copyright 2017-2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the 'license' file accompanying this file. This file is
10+
# distributed on an 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import os
16+
17+
from mock import MagicMock, patch
18+
19+
from sagemaker_tensorflow_container import s3_utils
20+
21+
22+
BUCKET_REGION = 'us-west-2'
23+
JOB_REGION = 'us-west-1'
24+
JOB_BUKCET = 'sagemaker-us-west-2-000-00-1'
25+
PREFIX = 'sagemaker/something'
26+
MODEL_DIR = 's3://{}/{}'.format(JOB_BUKCET, PREFIX)
27+
28+
29+
@patch('boto3.client')
30+
def test_configure(client):
31+
s3 = MagicMock()
32+
client.return_value = s3
33+
loc = {'LocationConstraint': BUCKET_REGION}
34+
s3.get_bucket_location.return_value = loc
35+
s3_utils.configure(MODEL_DIR, JOB_REGION)
36+
assert os.environ['S3_REGION'] == BUCKET_REGION
37+
assert os.environ['TF_CPP_MIN_LOG_LEVEL'] == '1'
38+
assert os.environ['S3_USE_HTTPS'] == '1'

test/unit/test_training.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from __future__ import absolute_import
1414

1515
import json
16+
import os
1617

1718
from mock import MagicMock, patch
1819
import pytest
@@ -36,6 +37,8 @@
3637
WORKER_TASK = {'index': 0, 'type': 'worker'}
3738
PS_TASK_1 = {'index': 0, 'type': 'ps'}
3839
PS_TASK_2 = {'index': 1, 'type': 'ps'}
40+
MODEL_DIR = 's3://bucket/prefix'
41+
REGION = 'us-west-2'
3942

4043

4144
@pytest.fixture
@@ -61,7 +64,7 @@ def single_machine_training_env():
6164

6265
env.module_dir = MODULE_DIR
6366
env.module_name = MODULE_NAME
64-
env.hyperparameters = {}
67+
env.hyperparameters = {'model_dir': MODEL_DIR}
6568
env.log_level = LOG_LEVEL
6669

6770
return env
@@ -195,10 +198,14 @@ def test_build_tf_config_error():
195198
@patch('logging.Logger.setLevel')
196199
@patch('sagemaker_containers.beta.framework.training_env')
197200
@patch('sagemaker_containers.beta.framework.env.read_hyperparameters', return_value={})
198-
def test_main(read_hyperparameters, training_env, set_level, train, single_machine_training_env):
201+
@patch('sagemaker_tensorflow_container.s3_utils.configure')
202+
def test_main(configure_s3_env, read_hyperparameters, training_env,
203+
set_level, train, single_machine_training_env):
199204
training_env.return_value = single_machine_training_env
205+
os.environ['SAGEMAKER_REGION'] = REGION
200206
training.main()
201207
read_hyperparameters.assert_called_once_with()
202208
training_env.assert_called_once_with(hyperparameters={})
203209
set_level.assert_called_once_with(LOG_LEVEL)
204210
train.assert_called_once_with(single_machine_training_env)
211+
configure_s3_env.assert_called_once()

0 commit comments

Comments
 (0)