Skip to content

Commit 46fdf0b

Browse files
icywang86rui authored and Eliza Zhang committed
Fix broken test test_distributed_mnist_no_ps (aws#156)
This test shouldn't save checkpoints, since the two hosts are just running training jobs independently and their checkpoints interfere with each other. Changing the test to use the Keras mnist script here. This change also changes the saved model path to /opt/ml/model so we can just use the estimator.model_data path to assert that the model exists.
1 parent 4d18645 commit 46fdf0b

File tree

1 file changed

+1
-41
lines changed

1 file changed

+1
-41
lines changed

test/resources/mnist/mnist.py

Lines changed: 1 addition & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,7 @@
22
import argparse
33
import os
44
import numpy as np
5-
<<<<<<< HEAD
6-
<<<<<<< HEAD
75
import json
8-
=======
9-
import sys
10-
>>>>>>> Scriptmode single machine training implementation (#78)
11-
=======
12-
>>>>>>> Add Keras support (#126)
136

147

158
def _parse_args():
@@ -19,7 +12,6 @@ def _parse_args():
1912
# hyperparameters sent by the client are passed as command-line arguments to the script.
2013
parser.add_argument('--epochs', type=int, default=1)
2114
# Data, model, and output directories
22-
<<<<<<< HEAD
2315
parser.add_argument('--model-dir', type=str, default=os.environ['SM_MODEL_DIR'])
2416
parser.add_argument('--train', type=str, default=os.environ['SM_CHANNEL_TRAINING'])
2517
parser.add_argument('--hosts', type=list, default=json.loads(os.environ['SM_HOSTS']))
@@ -28,32 +20,12 @@ def _parse_args():
2820
return parser.parse_known_args()
2921

3022

31-
=======
32-
parser.add_argument('--output-data-dir', type=str, default=os.environ['SM_OUTPUT_DATA_DIR'])
33-
parser.add_argument('--model-dir', type=str, default=os.environ['SM_MODEL_DIR'])
34-
parser.add_argument('--train', type=str, default=os.environ['SM_CHANNEL_TRAINING'])
35-
36-
return parser.parse_known_args()
37-
<<<<<<< HEAD
38-
#
39-
>>>>>>> Scriptmode single machine training implementation (#78)
40-
=======
41-
42-
43-
>>>>>>> Add distributed training support (#98)
4423
def _load_training_data(base_dir):
4524
x_train = np.load(os.path.join(base_dir, 'train', 'x_train.npy'))
4625
y_train = np.load(os.path.join(base_dir, 'train', 'y_train.npy'))
4726
return x_train, y_train
4827

49-
<<<<<<< HEAD
50-
<<<<<<< HEAD
51-
52-
=======
53-
>>>>>>> Scriptmode single machine training implementation (#78)
54-
=======
5528

56-
>>>>>>> Add distributed training support (#98)
5729
def _load_testing_data(base_dir):
5830
x_test = np.load(os.path.join(base_dir, 'test', 'x_test.npy'))
5931
y_test = np.load(os.path.join(base_dir, 'test', 'y_test.npy'))
@@ -63,15 +35,7 @@ def _load_testing_data(base_dir):
6335
args, unknown = _parse_args()
6436

6537
model = tf.keras.models.Sequential([
66-
<<<<<<< HEAD
67-
<<<<<<< HEAD
68-
tf.keras.layers.Flatten(input_shape=(28, 28)),
69-
=======
70-
tf.keras.layers.Flatten(),
71-
>>>>>>> Scriptmode single machine training implementation (#78)
72-
=======
7338
tf.keras.layers.Flatten(input_shape=(28, 28)),
74-
>>>>>>> Add distributed training support (#98)
7539
tf.keras.layers.Dense(512, activation=tf.nn.relu),
7640
tf.keras.layers.Dropout(0.2),
7741
tf.keras.layers.Dense(10, activation=tf.nn.softmax)
@@ -84,9 +48,5 @@ def _load_testing_data(base_dir):
8448
x_test, y_test = _load_testing_data(args.train)
8549
model.fit(x_train, y_train, epochs=args.epochs)
8650
model.evaluate(x_test, y_test)
87-
<<<<<<< HEAD
8851
if args.current_host == args.hosts[0]:
89-
model.save(os.path.join('/opt/ml/model', 'my_model.h5'))
90-
=======
91-
model.save(os.path.join(args.model_dir, 'my_model.h5'))
92-
>>>>>>> Scriptmode single machine training implementation (#78)
52+
model.save(os.path.join('/opt/ml/model', 'my_model.h5'))

0 commit comments

Comments
 (0)