Skip to content

Commit bc9edde

Browse files
committed
Switching to a simpler test for keras examples for MWMS
1 parent 3e7e9f8 commit bc9edde

File tree

2 files changed

+57
-18
lines changed

2 files changed

+57
-18
lines changed
Lines changed: 55 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,55 @@
1+
# https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras
2+
3+
import json
4+
import os
5+
import tensorflow as tf
6+
import numpy as np
7+
8+
9+
def mnist_dataset(batch_size):
    """Return a shuffled, endlessly repeating, batched MNIST training dataset.

    Only the training split returned by ``tf.keras.datasets.mnist.load_data()``
    is used; the second (test) split is discarded.

    Args:
        batch_size: Number of examples per batch, passed to ``Dataset.batch``.

    Returns:
        A ``tf.data.Dataset`` of ``(image, label)`` batches. Because of
        ``repeat()`` the dataset never ends, so the consumer has to bound
        iteration itself.
    """
    (images, labels), _ = tf.keras.datasets.mnist.load_data()

    # The image arrays are in uint8 and have values in the [0, 255] range.
    # Convert them to float32 with values in the [0, 1] range.
    images = images / np.float32(255)
    labels = labels.astype(np.int64)

    dataset = tf.data.Dataset.from_tensor_slices((images, labels))
    dataset = dataset.shuffle(60000)
    dataset = dataset.repeat()
    return dataset.batch(batch_size)
22+
23+
24+
def build_and_compile_cnn_model():
    """Create and compile the small CNN used for the MNIST example.

    The network takes ``(28, 28)`` inputs, reshapes them to ``(28, 28, 1)``,
    applies one 3x3 convolution with 32 filters, flattens, and finishes with
    a 128-unit ReLU dense layer followed by a 10-unit dense layer that
    outputs raw logits (hence ``from_logits=True`` in the loss).

    Returns:
        A compiled ``tf.keras.Sequential`` model using sparse categorical
        cross-entropy loss, SGD with learning rate 0.001, and the
        ``"accuracy"`` metric.
    """
    layers = tf.keras.layers
    layer_stack = [
        layers.InputLayer(input_shape=(28, 28)),
        layers.Reshape(target_shape=(28, 28, 1)),
        layers.Conv2D(32, 3, activation="relu"),
        layers.Flatten(),
        layers.Dense(128, activation="relu"),
        layers.Dense(10),
    ]
    cnn = tf.keras.Sequential(layer_stack)

    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    sgd = tf.keras.optimizers.SGD(learning_rate=0.001)
    cnn.compile(loss=loss_fn, optimizer=sgd, metrics=["accuracy"])
    return cnn
41+
42+
43+
# Batch size handled by each individual worker.
per_worker_batch_size = 64

# TF_CONFIG must be set in the environment (a KeyError is raised otherwise);
# the worker count is the length of its cluster's "worker" list.
tf_config = json.loads(os.environ["TF_CONFIG"])
worker_addresses = tf_config["cluster"]["worker"]
num_workers = len(worker_addresses)

# The strategy is created before the dataset and the model are built.
strategy = tf.distribute.MultiWorkerMirroredStrategy()

# The dataset is batched with the global size: per-worker size x worker count.
global_batch_size = per_worker_batch_size * num_workers
multi_worker_dataset = mnist_dataset(global_batch_size)

# Build and compile the model inside the strategy scope.
with strategy.scope():
    multi_worker_model = build_and_compile_cnn_model()

# The dataset repeats forever, so steps_per_epoch bounds each of the 3 epochs.
multi_worker_model.fit(
    multi_worker_dataset,
    epochs=3,
    steps_per_epoch=70,
)

tests/integ/test_tf.py

Lines changed: 2 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -189,30 +189,14 @@ def test_mwms_gpu(
189189
):
190190
instance_count = 2
191191
estimator = TensorFlow(
192-
source_dir=os.path.join(RESOURCE_PATH, "huggingface", "run_mlm"),
193-
entry_point="run_mlm.py",
192+
source_dir=os.path.join(RESOURCE_PATH, "tensorflow_mnist"),
193+
entry_point="mnist_mwms.py",
194194
model_dir=False,
195195
instance_type=kwargs["instance_type"],
196196
instance_count=instance_count,
197197
framework_version=tensorflow_training_latest_version,
198198
py_version=tensorflow_training_latest_py_version,
199199
distribution=MWMS_DISTRIBUTION,
200-
hyperparameters={
201-
"model_name_or_path": "bert-base-uncased",
202-
"output_dir": "/opt/ml/model",
203-
"dataset_name": "glue",
204-
"dataset_config_name": "sst2",
205-
"do_train": True,
206-
"do_eval": False,
207-
"max_seq_length": 128,
208-
"num_train_epochs": 1,
209-
"max_steps": 16,
210-
"overwrite_output_dir": True,
211-
"save_strategy": "no",
212-
"evaluation_strategy": "no",
213-
"logging_strategy": "epoch",
214-
"per_device_train_batch_size": 16,
215-
},
216200
environment={"NCCL_DEBUG": "INFO"},
217201
max_run=60 * 60 * 1, # 1 hour
218202
role=ROLE,

0 commit comments

Comments
 (0)