Skip to content

Commit b3cb548

Browse files
authored
Specify region when creating S3 resource in integ tests (#169)
1 parent e16a936 commit b3cb548

File tree

1 file changed

+14
-12
lines changed

1 file changed

+14
-12
lines changed

test/integration/sagemaker/test_mnist.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def test_mnist(sagemaker_session, ecr_image, instance_type, framework_version):
3737
path=os.path.join(resource_path, 'mnist', 'data'),
3838
key_prefix='scriptmode/mnist')
3939
estimator.fit(inputs)
40-
_assert_s3_file_exists(estimator.model_data)
40+
_assert_s3_file_exists(sagemaker_session.boto_region_name, estimator.model_data)
4141

4242

4343
def test_distributed_mnist_no_ps(sagemaker_session, ecr_image, instance_type, framework_version):
@@ -56,7 +56,7 @@ def test_distributed_mnist_no_ps(sagemaker_session, ecr_image, instance_type, fr
5656
path=os.path.join(resource_path, 'mnist', 'data'),
5757
key_prefix='scriptmode/mnist')
5858
estimator.fit(inputs)
59-
_assert_s3_file_exists(estimator.model_data)
59+
_assert_s3_file_exists(sagemaker_session.boto_region_name, estimator.model_data)
6060

6161

6262
def test_distributed_mnist_ps(sagemaker_session, ecr_image, instance_type, framework_version):
@@ -76,8 +76,8 @@ def test_distributed_mnist_ps(sagemaker_session, ecr_image, instance_type, frame
7676
path=os.path.join(resource_path, 'mnist', 'data-distributed'),
7777
key_prefix='scriptmode/mnist-distributed')
7878
estimator.fit(inputs)
79-
_assert_checkpoint_exists(estimator.model_dir, 0)
80-
_assert_s3_file_exists(estimator.model_data)
79+
_assert_checkpoint_exists(sagemaker_session.boto_region_name, estimator.model_dir, 0)
80+
_assert_s3_file_exists(sagemaker_session.boto_region_name, estimator.model_data)
8181

8282

8383
def test_s3_plugin(sagemaker_session, ecr_image, instance_type, region, framework_version):
@@ -107,17 +107,19 @@ def test_s3_plugin(sagemaker_session, ecr_image, instance_type, region, framewor
107107
py_version='py3',
108108
base_job_name='test-tf-sm-s3-mnist')
109109
estimator.fit('s3://sagemaker-sample-data-{}/tensorflow/mnist'.format(region))
110-
_assert_s3_file_exists(estimator.model_data)
111-
_assert_checkpoint_exists(estimator.model_dir, 200)
110+
_assert_s3_file_exists(region, estimator.model_data)
111+
_assert_checkpoint_exists(region, estimator.model_dir, 200)
112112

113113

114-
def _assert_checkpoint_exists(model_dir, checkpoint_number):
115-
_assert_s3_file_exists(os.path.join(model_dir, 'graph.pbtxt'))
116-
_assert_s3_file_exists(os.path.join(model_dir, 'model.ckpt-{}.index'.format(checkpoint_number)))
117-
_assert_s3_file_exists(os.path.join(model_dir, 'model.ckpt-{}.meta'.format(checkpoint_number)))
114+
def _assert_checkpoint_exists(region, model_dir, checkpoint_number):
115+
_assert_s3_file_exists(region, os.path.join(model_dir, 'graph.pbtxt'))
116+
_assert_s3_file_exists(region,
117+
os.path.join(model_dir, 'model.ckpt-{}.index'.format(checkpoint_number)))
118+
_assert_s3_file_exists(region,
119+
os.path.join(model_dir, 'model.ckpt-{}.meta'.format(checkpoint_number)))
118120

119121

120-
def _assert_s3_file_exists(s3_url):
122+
def _assert_s3_file_exists(region, s3_url):
121123
parsed_url = urlparse(s3_url)
122-
s3 = boto3.resource('s3')
124+
s3 = boto3.resource('s3', region_name=region)
123125
s3.Object(parsed_url.netloc, parsed_url.path.lstrip('/')).load()

0 commit comments

Comments
 (0)