Skip to content

Commit c02f3fc

Browse files
icywang86ruiEliza Zhang
authored andcommitted
Set S3 environment variables (aws#112)
* Setting S3 environment variables before training starts * Remove S3 environment variable setting in test training script * Add unit tests
1 parent 15e32a4 commit c02f3fc

File tree

3 files changed

+2
-65
lines changed

3 files changed

+2
-65
lines changed

src/sagemaker_tensorflow_container/s3_utils.py

Lines changed: 1 addition & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,6 @@
2020

2121
def configure(model_dir, job_region):
2222

23-
<<<<<<< HEAD
24-
<<<<<<< HEAD
2523
os.environ['S3_REGION'] = _s3_region(job_region, model_dir)
2624

2725
# setting log level to WARNING
@@ -42,37 +40,4 @@ def _s3_region(job_region, model_dir):
4240

4341
return bucket_location or job_region
4442
else:
45-
return job_region
46-
=======
47-
if not model_dir:
48-
return
49-
=======
50-
os.environ['S3_REGION'] = _s3_region(job_region, model_dir)
51-
>>>>>>> Add Keras support (#126)
52-
53-
# setting log level to WARNING
54-
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
55-
os.environ['S3_USE_HTTPS'] = '1'
56-
57-
58-
def _s3_region(job_region, model_dir):
59-
if model_dir and model_dir.startswith('s3://'):
60-
s3 = boto3.client('s3', region_name=job_region)
61-
62-
# We get the AWS region of the checkpoint bucket, which may be different from
63-
# the region this container is currently running in.
64-
parsed_url = urlparse(model_dir)
65-
bucket_name = parsed_url.netloc
66-
67-
<<<<<<< HEAD
68-
# setting log level to WARNING
69-
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
70-
os.environ['S3_USE_HTTPS'] = '1'
71-
>>>>>>> Set S3 environment variables (#112)
72-
=======
73-
bucket_location = s3.get_bucket_location(Bucket=bucket_name)['LocationConstraint']
74-
75-
return bucket_location or job_region
76-
else:
77-
return job_region
78-
>>>>>>> Add Keras support (#126)
43+
return job_region

test/resources/mnist/distributed_mnist.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -136,11 +136,6 @@ def _parse_args():
136136
tf_logger = tf_logging._get_logger()
137137
tf_logger.handlers = [_handler]
138138

139-
if args.checkpoint_path.startswith('s3://'):
140-
os.environ['S3_REGION'] = 'us-west-2'
141-
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '0'
142-
os.environ['S3_USE_HTTPS'] = '1'
143-
144139
train_data, train_labels = _load_training_data(args.train)
145140
eval_data, eval_labels = _load_testing_data(args.train)
146141

test/unit/test_s3_utils.py

Lines changed: 1 addition & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,6 @@ def test_configure(client):
3232
client.return_value = s3
3333
loc = {'LocationConstraint': BUCKET_REGION}
3434
s3.get_bucket_location.return_value = loc
35-
<<<<<<< HEAD
36-
<<<<<<< HEAD
3735

3836
s3_utils.configure(MODEL_DIR, JOB_REGION)
3937

@@ -47,25 +45,4 @@ def test_configure_local_dir():
4745

4846
assert os.environ['S3_REGION'] == JOB_REGION
4947
assert os.environ['TF_CPP_MIN_LOG_LEVEL'] == '1'
50-
assert os.environ['S3_USE_HTTPS'] == '1'
51-
=======
52-
=======
53-
54-
>>>>>>> Add Keras support (#126)
55-
s3_utils.configure(MODEL_DIR, JOB_REGION)
56-
57-
assert os.environ['S3_REGION'] == BUCKET_REGION
58-
assert os.environ['TF_CPP_MIN_LOG_LEVEL'] == '1'
59-
assert os.environ['S3_USE_HTTPS'] == '1'
60-
<<<<<<< HEAD
61-
>>>>>>> Set S3 environment variables (#112)
62-
=======
63-
64-
65-
def test_configure_local_dir():
66-
s3_utils.configure('/opt/ml/model', JOB_REGION)
67-
68-
assert os.environ['S3_REGION'] == JOB_REGION
69-
assert os.environ['TF_CPP_MIN_LOG_LEVEL'] == '1'
70-
assert os.environ['S3_USE_HTTPS'] == '1'
71-
>>>>>>> Add Keras support (#126)
48+
assert os.environ['S3_USE_HTTPS'] == '1'

0 commit comments

Comments
 (0)