Skip to content

Commit ec07c35

Browse files
authored
Fix broken test test_distributed_mnist_no_ps (#156)
This test shouldn't save checkpoints, since the two hosts are just running training jobs independently and their checkpoints interfere with each other. This change switches the test to the Keras mnist script. It also changes the saved model path to /opt/ml/model, so we can just use the estimator.model_data path to assert that the model exists.
1 parent 48507bb commit ec07c35

File tree

2 files changed

+10
-9
lines changed

2 files changed

+10
-9
lines changed

test/integration/sagemaker/test_mnist.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,13 +37,12 @@ def test_mnist(sagemaker_session, ecr_image, instance_type):
3737
path=os.path.join(resource_path, 'mnist', 'data'),
3838
key_prefix='scriptmode/mnist')
3939
estimator.fit(inputs)
40-
model_s3_url = estimator.create_model().model_data
41-
_assert_s3_file_exists(model_s3_url)
40+
_assert_s3_file_exists(estimator.model_data)
4241

4342

4443
def test_distributed_mnist_no_ps(sagemaker_session, ecr_image, instance_type):
4544
resource_path = os.path.join(os.path.dirname(__file__), '../..', 'resources')
46-
script = os.path.join(resource_path, 'mnist', 'mnist_estimator.py')
45+
script = os.path.join(resource_path, 'mnist', 'mnist.py')
4746
estimator = TensorFlow(entry_point=script,
4847
role='SageMakerRole',
4948
train_instance_count=2,
@@ -54,10 +53,9 @@ def test_distributed_mnist_no_ps(sagemaker_session, ecr_image, instance_type):
5453
py_version='py3',
5554
base_job_name='test-tf-sm-distributed-mnist')
5655
inputs = estimator.sagemaker_session.upload_data(
57-
path=os.path.join(resource_path, 'mnist', 'data-distributed'),
58-
key_prefix='scriptmode/mnist-distributed')
56+
path=os.path.join(resource_path, 'mnist', 'data'),
57+
key_prefix='scriptmode/mnist')
5958
estimator.fit(inputs)
60-
_assert_checkpoint_exists(estimator.model_dir, 0)
6159
_assert_s3_file_exists(estimator.model_data)
6260

6361

test/resources/mnist/mnist.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import argparse
33
import os
44
import numpy as np
5+
import json
56

67

78
def _parse_args():
@@ -11,10 +12,11 @@ def _parse_args():
1112
# hyperparameters sent by the client are passed as command-line arguments to the script.
1213
parser.add_argument('--epochs', type=int, default=1)
1314
# Data, model, and output directories
14-
parser.add_argument('--output-data-dir', type=str, default=os.environ['SM_OUTPUT_DATA_DIR'])
1515
parser.add_argument('--model-dir', type=str, default=os.environ['SM_MODEL_DIR'])
1616
parser.add_argument('--train', type=str, default=os.environ['SM_CHANNEL_TRAINING'])
17-
17+
parser.add_argument('--hosts', type=list, default=json.loads(os.environ['SM_HOSTS']))
18+
parser.add_argument('--current-host', type=str, default=os.environ['SM_CURRENT_HOST'])
19+
1820
return parser.parse_known_args()
1921

2022

@@ -46,4 +48,5 @@ def _load_testing_data(base_dir):
4648
x_test, y_test = _load_testing_data(args.train)
4749
model.fit(x_train, y_train, epochs=args.epochs)
4850
model.evaluate(x_test, y_test)
49-
model.save(os.path.join(args.model_dir, 'my_model.h5'))
51+
if args.current_host == args.hosts[0]:
52+
model.save(os.path.join('/opt/ml/model', 'my_model.h5'))

0 commit comments

Comments
 (0)