-
Notifications
You must be signed in to change notification settings - Fork 1.2k
feature: Adding support for Multi Worker Mirrored Strategy in TF estimator #3192
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
39 commits
Select commit
Hold shift + click to select a range
16d7058
feature: adding support for Multi Worker Mirrored Strategy in TF esti…
Lokiiiiii f187cd0
tests: adding unit tests targeting MWMS in TF
Lokiiiiii dd6fd51
test: adding integration tests targetting MWMS in TF
Lokiiiiii c954c79
fix: linting and removing accidental file addition
Lokiiiiii cf0740f
doc: Adding doc strings to tests
Lokiiiiii 6e3c4e2
Fixing MWMS unit test for TF2
Lokiiiiii 6e3fa48
Fixing MWMS tests for TF2
Lokiiiiii 4cb17fe
Fixing MWMS tests for TF2
Lokiiiiii aecd12c
Fixing MWMS tests for TF2
Lokiiiiii ab2ad0d
Fixing MWMS tests for TF2
Lokiiiiii f7d9f6c
Fixing MWMS tests for TF2
Lokiiiiii 3ba2c39
Finishing up MWMS tests
Lokiiiiii 55af827
Finishing up MWMS tests
Lokiiiiii b87df13
Using entire imagenet dataset instead of a subset
Lokiiiiii 2f99064
update: pruning unused fixtures in MWMS test
Lokiiiiii 6599d8d
fix: stop saving artifacts from MWMS test
Lokiiiiii 96a3224
fix: save artifacts from MWMS test to tmp directory to discard
Lokiiiiii 838f66b
Adding a new test for HF transformers
Lokiiiiii a36fa7a
Removing stale tests and fixtures
Lokiiiiii ddbb4ea
fix: fixing syntax error in MWMS test
Lokiiiiii 30cb055
Update src/sagemaker/tensorflow/estimator.py
Lokiiiiii fbed9cb
Update src/sagemaker/tensorflow/estimator.py
Lokiiiiii 7f7ad85
Fixing docstring syntax and auto-formatting
Lokiiiiii b988b8f
Adding more validation checks when using MWMS
Lokiiiiii 6185edc
fixing new test targeting mwms-smdist
Lokiiiiii dd679c8
Adding training script that leverages MultiWorkerMirroredStrategy
Lokiiiiii a37842b
Fixing merge conflict
Lokiiiiii 70e9ff3
Auto reformat with black
Lokiiiiii 3e7e9f8
Updating HF TF example to latest from repo
Lokiiiiii bc9edde
Switching to a simpler test for keras examples for MWMS
Lokiiiiii 0507792
Switching to a simpler test for keras examples for MWMS
Lokiiiiii 32da5f1
Removing stale test training scripts
Lokiiiiii aac0532
black -l 100
Lokiiiiii cf92c0d
python3.10 -m black -l 100
Lokiiiiii 5f4959e
Fixing unit tests for MWMS
Lokiiiiii 2c48152
Fixing unit tests for MWMS
Lokiiiiii c6ebb5d
Fixing unit tests for MWMS
Lokiiiiii 9a527b7
Removing unused fixtures
Lokiiiiii a838741
Merge remote-tracking branch 'aws/master' into mwms-2
Lokiiiiii File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
# https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras | ||
|
||
import json | ||
import os | ||
import tensorflow as tf | ||
import numpy as np | ||
|
||
|
||
def mnist_dataset(batch_size): | ||
(x_train, y_train), _ = tf.keras.datasets.mnist.load_data() | ||
# The `x` arrays are in uint8 and have values in the [0, 255] range. | ||
# You need to convert them to float32 with values in the [0, 1] range. | ||
x_train = x_train / np.float32(255) | ||
y_train = y_train.astype(np.int64) | ||
train_dataset = ( | ||
tf.data.Dataset.from_tensor_slices((x_train, y_train)) | ||
.shuffle(60000) | ||
.repeat() | ||
.batch(batch_size) | ||
) | ||
return train_dataset | ||
|
||
|
||
def build_and_compile_cnn_model(): | ||
model = tf.keras.Sequential( | ||
[ | ||
tf.keras.layers.InputLayer(input_shape=(28, 28)), | ||
tf.keras.layers.Reshape(target_shape=(28, 28, 1)), | ||
tf.keras.layers.Conv2D(32, 3, activation="relu"), | ||
tf.keras.layers.Flatten(), | ||
tf.keras.layers.Dense(128, activation="relu"), | ||
tf.keras.layers.Dense(10), | ||
] | ||
) | ||
model.compile( | ||
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), | ||
optimizer=tf.keras.optimizers.SGD(learning_rate=0.001), | ||
metrics=["accuracy"], | ||
) | ||
return model | ||
|
||
|
||
per_worker_batch_size = 64 | ||
tf_config = json.loads(os.environ["TF_CONFIG"]) | ||
num_workers = len(tf_config["cluster"]["worker"]) | ||
|
||
strategy = tf.distribute.MultiWorkerMirroredStrategy() | ||
|
||
global_batch_size = per_worker_batch_size * num_workers | ||
multi_worker_dataset = mnist_dataset(global_batch_size) | ||
|
||
with strategy.scope(): | ||
multi_worker_model = build_and_compile_cnn_model() | ||
|
||
multi_worker_model.fit(multi_worker_dataset, epochs=3, steps_per_epoch=70) | ||
|
||
print(f"strategy.num_replicas_in_sync={strategy.num_replicas_in_sync}") |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.