@@ -37,7 +37,7 @@ def test_mnist(sagemaker_session, ecr_image, instance_type, framework_version):
37
37
path = os .path .join (resource_path , 'mnist' , 'data' ),
38
38
key_prefix = 'scriptmode/mnist' )
39
39
estimator .fit (inputs )
40
- _assert_s3_file_exists (estimator .model_data )
40
+ _assert_s3_file_exists (sagemaker_session . boto_region_name , estimator .model_data )
41
41
42
42
43
43
def test_distributed_mnist_no_ps (sagemaker_session , ecr_image , instance_type , framework_version ):
@@ -56,7 +56,7 @@ def test_distributed_mnist_no_ps(sagemaker_session, ecr_image, instance_type, fr
56
56
path = os .path .join (resource_path , 'mnist' , 'data' ),
57
57
key_prefix = 'scriptmode/mnist' )
58
58
estimator .fit (inputs )
59
- _assert_s3_file_exists (estimator .model_data )
59
+ _assert_s3_file_exists (sagemaker_session . boto_region_name , estimator .model_data )
60
60
61
61
62
62
def test_distributed_mnist_ps (sagemaker_session , ecr_image , instance_type , framework_version ):
@@ -76,8 +76,8 @@ def test_distributed_mnist_ps(sagemaker_session, ecr_image, instance_type, frame
76
76
path = os .path .join (resource_path , 'mnist' , 'data-distributed' ),
77
77
key_prefix = 'scriptmode/mnist-distributed' )
78
78
estimator .fit (inputs )
79
- _assert_checkpoint_exists (estimator .model_dir , 0 )
80
- _assert_s3_file_exists (estimator .model_data )
79
+ _assert_checkpoint_exists (sagemaker_session . boto_region_name , estimator .model_dir , 0 )
80
+ _assert_s3_file_exists (sagemaker_session . boto_region_name , estimator .model_data )
81
81
82
82
83
83
def test_s3_plugin (sagemaker_session , ecr_image , instance_type , region , framework_version ):
@@ -107,17 +107,19 @@ def test_s3_plugin(sagemaker_session, ecr_image, instance_type, region, framewor
107
107
py_version = 'py3' ,
108
108
base_job_name = 'test-tf-sm-s3-mnist' )
109
109
estimator .fit ('s3://sagemaker-sample-data-{}/tensorflow/mnist' .format (region ))
110
- _assert_s3_file_exists (estimator .model_data )
111
- _assert_checkpoint_exists (estimator .model_dir , 200 )
110
+ _assert_s3_file_exists (region , estimator .model_data )
111
+ _assert_checkpoint_exists (region , estimator .model_dir , 200 )
112
112
113
113
114
- def _assert_checkpoint_exists (model_dir , checkpoint_number ):
115
- _assert_s3_file_exists (os .path .join (model_dir , 'graph.pbtxt' ))
116
- _assert_s3_file_exists (os .path .join (model_dir , 'model.ckpt-{}.index' .format (checkpoint_number )))
117
- _assert_s3_file_exists (os .path .join (model_dir , 'model.ckpt-{}.meta' .format (checkpoint_number )))
114
+ def _assert_checkpoint_exists (region , model_dir , checkpoint_number ):
115
+ _assert_s3_file_exists (region , os .path .join (model_dir , 'graph.pbtxt' ))
116
+ _assert_s3_file_exists (region ,
117
+ os .path .join (model_dir , 'model.ckpt-{}.index' .format (checkpoint_number )))
118
+ _assert_s3_file_exists (region ,
119
+ os .path .join (model_dir , 'model.ckpt-{}.meta' .format (checkpoint_number )))
118
120
119
121
120
- def _assert_s3_file_exists (s3_url ):
122
+ def _assert_s3_file_exists (region , s3_url ):
121
123
parsed_url = urlparse (s3_url )
122
- s3 = boto3 .resource ('s3' )
124
+ s3 = boto3 .resource ('s3' , region_name = region )
123
125
s3 .Object (parsed_url .netloc , parsed_url .path .lstrip ('/' )).load ()
0 commit comments